# VEST Database Integration Test

This notebook tests the integration of `VEST_DB` with the MySQL database and exercises interactive plotting with `plt.ion()`.

## Test Objectives:
1. **VEST_DB Integration**: Load data from MySQL database using VEST_DB
2. **main_analysis.py Testing**: Verify main analysis workflow functionality
3. **interactive_analysis.py Testing**: Test interactive analysis pipeline
4. **Interactive Plotting**: Comprehensive evaluation of `plt.ion` functionality

## Features:
- VEST data loading from MySQL
- Combined visualization of NAS data and VEST data
- Interactive plotting with `plt.ion` and `interactive_plotting` context manager
- Testing of `main_analysis.py` and `interactive_analysis.py` workflows


## 1. Setup and Imports


In [None]:
# Setup and imports
import sys
from pathlib import Path
import os
import re

# Configure Numba threading layer for parallel execution
os.environ['NUMBA_THREADING_LAYER'] = 'tbb'

# Add project root to path
current_dir = Path.cwd()
ifi_root = current_dir.parent if current_dir.name == "analysis" else current_dir
sys.path.insert(0, str(ifi_root))

from ifi.utils.cache_setup import setup_project_cache
cache_config = setup_project_cache()
print(f"Cache configured: {cache_config['cache_dir']}")

import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt

# Import Numba config after setting environment variable
try:
    import numba
    try:
        numba.config.THREADING_LAYER = 'tbb'
        print(f"Numba threading layer: {numba.config.THREADING_LAYER}")
    except Exception as e:
        print(f"Warning: Could not set Numba threading layer: {e}")
        print("Falling back to default threading layer")
except ImportError:
    print("Warning: Numba not available")

# Import IFI modules
from ifi.db_controller.nas_db import NAS_DB
from ifi.db_controller.vest_db import VEST_DB
from ifi.analysis import processing, plots
from ifi.analysis.main_analysis import run_analysis
from ifi.analysis.interactive_analysis import create_mock_args
from ifi.utils.file_io import load_results_from_hdf5
from ifi.analysis.phi2ne import get_interferometry_params

print("✓ All imports successful")


## 2. Configuration


In [None]:
# Notebook-wide settings read by the later cells
shot_num = 45821  # Shot under test; edit to analyse a different shot
config_path = "ifi/config.ini"  # Config file handed to the database controllers
results_base_dir = "results"  # Root directory searched for HDF5 results

# VEST field IDs to request (common fields: 109=Ip, 101=ne, etc.).
# An empty list [] loads all available fields.
vest_fields = [109, 101]

# Echo the active settings: one print call, one item per output line
print(
    "Configuration:",
    f"  Shot number: {shot_num}",
    f"  Config path: {config_path}",
    f"  VEST fields: {vest_fields or 'All available'}",
    sep="\n",
)


## 3. Initialize Database Controllers


In [None]:
# Section banner (three lines, emitted with a single call)
print("=" * 80, "Initializing Database Controllers", "=" * 80, sep="\n")

# Both controllers are built from the same config file. A failure is reported
# and re-raised so the notebook stops here rather than in a later cell.
try:
    nas_db = NAS_DB(config_path=config_path)
    vest_db = VEST_DB(config_path=config_path)
    print("✓ Database controllers initialized")
except Exception as init_error:
    print(f"✗ Failed to initialize database controllers: {init_error}")
    raise


## 4. Load Data from NAS


In [None]:
print("=" * 80)
print(f"Loading NAS data for shot {shot_num}")
print("=" * 80)

# Helper function: Extract basename with extension from path (handles UNC paths)
def extract_basename(file_path: str) -> str:
    """Return the last component of *file_path*, extension included.

    Backslashes are converted to forward slashes first, so Windows-style and
    UNC paths are split the same way as POSIX paths. A path that ends with a
    separator yields an empty string.
    """
    # rpartition gives (head, sep, tail); tail is the whole string when no
    # separator is present.
    _head, _sep, tail = file_path.replace("\\", "/").rpartition("/")
    return tail

# Cache policy shared by every NAS request in this cell. It is defined before
# the first request so that find_files() and get_shot_data() cannot disagree.
force_remote = False  # Set to True to bypass cache

# Find files for the shot
target_files = nas_db.find_files(
    query=[shot_num],
    data_folders=None,
    add_path=False,
    force_remote=force_remote,
)

nas_signals = {}  # file basename -> processed DataFrame
fs = 50e6  # Default sampling frequency; overwritten from TIME data or HDF5 metadata

if not target_files:
    # Nothing on the NAS: fall back to previously saved HDF5 results
    print(f"⚠ No files found for shot {shot_num} in NAS")
    print("Trying to load from existing HDF5 results...")
    h5_results = load_results_from_hdf5(shot_num, base_dir=results_base_dir)
    if h5_results and "signals" in h5_results:
        nas_signals = h5_results["signals"]
        metadata = h5_results.get("metadata", {})
        fs = metadata.get("sampling_frequency", 50e6)
        print(f"✓ Loaded data from existing HDF5 file")
else:
    print(f"✓ Found {len(target_files)} file(s)")
    for f in target_files:
        print(f"  - {extract_basename(f)}")

    # Load and process each file; a failure on one file never aborts the others
    loaded_count = 0
    for file_path in target_files:
        file_name = extract_basename(file_path)
        print(f"\nProcessing: {file_name}")

        df_raw = None

        # Handle UNC paths (won't work via SSH) - use basename for search
        if file_path.startswith(("//", "\\\\")):
            print(f"  ⚠ UNC path detected - using basename for search")
            try:
                data_dict = nas_db.get_shot_data(
                    query=[file_name],
                    data_folders=None,
                    add_path=False,
                    force_remote=force_remote,
                )
                # Prefer the entry whose basename matches exactly
                for key, frame in data_dict.items():
                    if extract_basename(key) == file_name:
                        df_raw = frame
                        print(f"  ✓ Found matching file")
                        break
                # Otherwise fall back to the first entry returned, if any
                if df_raw is None and data_dict:
                    first_key, df_raw = next(iter(data_dict.items()))
                    print(f"  ⚠ Using first available file: {extract_basename(first_key)}")
            except Exception as e:
                print(f"  ⚠ Error: {type(e).__name__}: {e}")
                continue
        else:
            # Normal path - direct loading
            try:
                data_dict = nas_db.get_shot_data(file_path, force_remote=force_remote)
                if data_dict and file_path in data_dict:
                    df_raw = data_dict[file_path]
                    print(f"  ✓ Loaded from NAS")
            except Exception as e:
                print(f"  ⚠ Error: {type(e).__name__}: {e}")
                continue

        # Report a miss for both branches (the UNC branch used to skip silently)
        if df_raw is None:
            print(f"  ⚠ File not found in results, skipping...")
            continue

        # Process loaded data
        try:
            df_refined = processing.refine_data(df_raw)
            df_processed = processing.remove_offset(df_refined, window_size=2001)
            nas_signals[file_name] = df_processed
            loaded_count += 1

            # Derive the sampling frequency from the mean TIME step when
            # possible; the value from the last processed file wins.
            if "TIME" in df_processed.columns:
                time_diff = df_processed["TIME"].diff().mean()
                if pd.notna(time_diff) and time_diff > 0:
                    fs = 1 / time_diff
            print(f"  ✓ Processed: shape {df_processed.shape}, fs={fs/1e6:.1f} MHz")
        except Exception as e:
            print(f"  ⚠ Processing error: {type(e).__name__}: {e}")

    # Fallback to HDF5 if no files loaded
    if not nas_signals:
        print(f"\n⚠ No signals loaded ({loaded_count}/{len(target_files)} files processed)")
        print(f"  Attempting HDF5 fallback...")
        h5_results = load_results_from_hdf5(shot_num, base_dir=results_base_dir)
        if h5_results and "signals" in h5_results:
            nas_signals = h5_results["signals"]
            metadata = h5_results.get("metadata", {})
            fs = metadata.get("sampling_frequency", 50e6)
            print(f"✓ Loaded from HDF5 fallback")

# Summary
if nas_signals:
    print(f"\n✓ NAS data loading complete")
    print(f"  Signals: {list(nas_signals.keys())}")
    for signal_name, signal_df in nas_signals.items():
        print(f"    - {signal_name}: shape {signal_df.shape}, columns: {list(signal_df.columns)}")
    print(f"  Sampling frequency: {fs/1e6:.1f} MHz")
else:
    print(f"\n⚠ No NAS signals available")
    print(f"  Notebook will continue with VEST data only")


## 5. Load VEST Data from MySQL


In [None]:
print("=" * 80)
print(f"Loading VEST data for shot {shot_num} from MySQL")
print("=" * 80)

# Always defined (and reset on re-run), even if connecting or loading fails
vest_data = {}

# Connect to VEST database
if not vest_db.connect():
    print("✗ Failed to connect to VEST database")
else:
    print("✓ Connected to VEST database")

    # try/finally guarantees the connection is released even when load_shot()
    # or the reporting below raises.
    try:
        # Load VEST data
        vest_data_dict = vest_db.load_shot(shot=shot_num, fields=vest_fields)

        if vest_data_dict:
            print(f"\n✓ Loaded VEST data from MySQL")
            print(f"  Available sampling rates: {list(vest_data_dict.keys())}")

            # Optional mapping of column -> label; absent on some controllers
            field_labels = getattr(vest_db, "field_labels", None)

            for rate, df in vest_data_dict.items():
                print(f"\n  Sampling rate: {rate}")
                print(f"    Shape: {df.shape}")
                print(f"    Columns: {list(df.columns)}")
                print(f"    Time range: {df.index.min():.6f} to {df.index.max():.6f} s")

                # Show field labels if available
                if field_labels:
                    print(f"    Field labels:")
                    for col in df.columns:
                        if col in field_labels:
                            print(f"      {col}: {field_labels[col]}")

            vest_data = vest_data_dict
        else:
            print(f"⚠ No VEST data found for shot {shot_num}")
    finally:
        # Disconnect from database
        vest_db.disconnect()
        print("\n✓ Disconnected from VEST database")

# Only claim success when data was actually loaded
if vest_data:
    print(f"\n✓ VEST data loading complete")
else:
    print(f"\n⚠ VEST data loading finished without data")


## 6. Test main_analysis.py Functionality


In [None]:
print("=" * 80)
print("Testing main_analysis.py Functionality")
print("=" * 80)

# Create mock args for main_analysis
from argparse import Namespace

# Attribute names mirror what run_analysis reads from its `args` object.
# NOTE(review): results_dir here ("ifi/results") differs from results_base_dir
# in the configuration cell ("results") - confirm which one is intended.
analysis_args = Namespace(
    query=[str(shot_num)],
    data_folders=None,
    add_path=False,
    force_remote=False,
    results_dir="ifi/results",
    no_offset_removal=False,
    offset_window=2001,
    stft=False,  # Set to True to test STFT
    stft_cols=[],
    cwt=False,  # Set to True to test CWT
    cwt_cols=[],
    plot=False,  # Set to True to show plots
    no_plot_raw=False,
    no_plot_ft=False,
    downsample=10,
    trigger_time=0.290,
    density=False,  # Set to True to test density calculation
    vest_fields=vest_fields,
    baseline=None,
    save_plots=False,
    save_data=False,
    scheduler="threads",
)

print("\nAnalysis arguments configured:")
print(f"  Query: {analysis_args.query}")
print(f"  STFT: {analysis_args.stft}")
print(f"  CWT: {analysis_args.cwt}")
print(f"  Density: {analysis_args.density}")
print(f"  Plot: {analysis_args.plot}")
print(f"  VEST fields: {analysis_args.vest_fields}")

# Test run_analysis function
print("\n" + "=" * 80)
print("Running run_analysis function...")
print("=" * 80)

# Reset first: if this cell is re-run and run_analysis raises, a value left
# over from an earlier run must not be reported as a success by the summary.
analysis_results = None

try:
    analysis_results = run_analysis(
        query=analysis_args.query,
        args=analysis_args,
        nas_db=nas_db,
        vest_db=vest_db,
    )

    if analysis_results:
        print("\n✓ Analysis completed successfully")
        print(f"  Processed shots: {list(analysis_results.keys())}")

        # Report the shapes of whatever each shot bundle contains
        for shot_num_result, bundle in analysis_results.items():
            print(f"\n  Shot {shot_num_result}:")
            if "processed_data" in bundle:
                signals = bundle["processed_data"].get("signals")
                density = bundle["processed_data"].get("density")
                if signals is not None:
                    print(f"    Signals shape: {signals.shape}")
                if density is not None and not density.empty:
                    print(f"    Density shape: {density.shape}")
    else:
        print("⚠ Analysis returned no results")

except Exception as e:
    # Best-effort test cell: show the failure but let the notebook continue
    print(f"✗ Error during analysis: {e}")
    import traceback
    traceback.print_exc()

## 7. Test interactive_analysis.py Functionality


In [None]:
print("=" * 80)
print("Testing interactive_analysis.py Functionality")
print("=" * 80)

# Create mock args using interactive_analysis function
mock_args = create_mock_args()

# Modify for current shot
mock_args.query = [str(shot_num)]
mock_args.vest_fields = vest_fields
mock_args.stft = False  # Set to True to test STFT
mock_args.cwt = False  # Set to True to test CWT
mock_args.density = False  # Set to True to test density
mock_args.plot = False  # Set to True to show plots
mock_args.scheduler = "threads"

print("\nMock arguments from interactive_analysis:")
print(f"  Query: {mock_args.query}")
print(f"  STFT: {mock_args.stft}")
print(f"  CWT: {mock_args.cwt}")
print(f"  Density: {mock_args.density}")
print(f"  Plot: {mock_args.plot}")
print(f"  VEST fields: {mock_args.vest_fields}")

# Test run_analysis with mock args
print("\n" + "=" * 80)
print("Running run_analysis with interactive_analysis mock args...")
print("=" * 80)

# Reset first: if this cell is re-run and run_analysis raises, a value left
# over from an earlier run must not be reported as a success by the summary.
interactive_results = None

try:
    interactive_results = run_analysis(
        query=mock_args.query,
        args=mock_args,
        nas_db=nas_db,
        vest_db=vest_db,
    )

    if interactive_results:
        print("\n✓ Interactive analysis completed successfully")
        print(f"  Processed shots: {list(interactive_results.keys())}")
    else:
        print("⚠ Interactive analysis returned no results")

except Exception as e:
    # Best-effort test cell: show the failure but let the notebook continue
    print(f"✗ Error during interactive analysis: {e}")
    import traceback
    traceback.print_exc()


In [None]:
def load_real_torch():
    """Return a usable ``torch`` module, evicting placeholder modules first.

    A module is treated as "real" when it exposes ``__version__``. If such a
    module is already registered in ``sys.modules`` it is returned as is.
    Otherwise every ``torch`` / ``torch.*`` entry is dropped from
    ``sys.modules`` and ``torch`` is imported again.

    Returns:
        The ``torch`` module.

    Raises:
        ImportError, RuntimeError, AttributeError: Re-raised after a hint is
            printed when the fresh import fails or yields a module without
            ``__version__``.
    """
    import importlib
    import sys

    # Fast path: an already-imported module that reports a version
    cached = sys.modules.get("torch")
    if cached is not None and hasattr(cached, "__version__"):
        try:
            version = cached.__version__
        except (AttributeError, RuntimeError):
            pass  # Not trustworthy; continue with the reload below
        else:
            print(f"✓ Real torch already loaded: version {version}")
            return cached

    # Slow path: forget every torch entry so the import machinery starts over
    print("Removing dummy/stub torch modules...")
    stale_names = [
        name
        for name in list(sys.modules)
        if name == "torch" or name.startswith("torch.")
    ]
    for name in stale_names:
        sys.modules.pop(name, None)

    importlib.invalidate_caches()

    try:
        module = importlib.import_module("torch")
        if not hasattr(module, "__version__"):
            raise ImportError("Loaded torch module does not have __version__ attribute")
        print(f"✓ Successfully loaded real torch: version {module.__version__}")
        return module
    except (ImportError, RuntimeError, AttributeError) as e:
        print(f"⚠ Error loading torch: {e}")
        print("  Note: Torch may already be partially loaded. Try restarting the kernel.")
        raise

# Load torch
# Binds the notebook-level name `torch` to the module returned by
# load_real_torch(), or to None when loading fails for any reason, so the
# name `torch` is always defined after this cell.
try:
    torch = load_real_torch()
    print(f"torch.__version__: {torch.__version__}")
except Exception as e:
    # Deliberately broad: any failure is reported and the notebook continues
    print(f"✗ Failed to load torch: {e}")
    print("  If torch is already loaded elsewhere, you may need to restart the kernel.")
    torch = None

## 8. Comprehensive Interactive Plotting Test with plt.ion


In [None]:
print("=" * 80)
print("Comprehensive Interactive Plotting Test")
print("=" * 80)

# Helper function for downsampling data for plotting
def downsample_for_plot(time_data, signal_data, max_points=50000):
    """
    Downsample data so that at most max_points samples are plotted.

    Every downsample_factor-th sample is kept (plain stride slicing; no
    averaging or filtering is applied), and both inputs are sliced with the
    same stride so they stay aligned.

    Args:
        time_data: Time array (anything supporting len() and step slicing)
        signal_data: Signal array, sliced with the same stride as time_data
        max_points: Maximum number of points to plot (default: 50000)

    Returns:
        Tuple of (downsampled_time, downsampled_signal, downsample_factor).
        The inputs are returned unchanged with factor 1 when they already fit.
    """
    n_points = len(time_data)
    if n_points <= max_points:
        return time_data, signal_data, 1

    # Ceiling division. Floor division would let up to ~2x max_points samples
    # through (e.g. 99_999 points with max_points=50_000 gave factor 1).
    downsample_factor = max(1, -(-n_points // max_points))
    downsampled_time = time_data[::downsample_factor]
    downsampled_signal = signal_data[::downsample_factor]

    return downsampled_time, downsampled_signal, downsample_factor

# Setup interactive mode
plots.setup_interactive_mode(backend="auto", style="default")
plt.ion()  # Turn on interactive mode
print("✓ Interactive mode enabled (plt.ion())")
print(f"  Backend: {matplotlib.get_backend()}")
print(f"  Interactive: {plt.isinteractive()}")

# Test 1: Using interactive_plotting context manager
print("\n" + "-" * 80)
print("Test 1: Using interactive_plotting context manager")
print("-" * 80)

# --- Figure 1 ---
if nas_signals:
    with plots.interactive_plotting(show_plots=True, block=False):
        # Plot the first loaded NAS signal only
        signal_name = next(iter(nas_signals))
        signal_df = nas_signals[signal_name]

        fig, ax = plt.subplots(figsize=(12, 6))

        # The time axis comes from the TIME column when present, else the index
        if "TIME" in signal_df.columns:
            time_axis = signal_df["TIME"].values
            signal_cols = [col for col in signal_df.columns if col != "TIME"]
        else:
            time_axis = signal_df.index.values
            signal_cols = list(signal_df.columns)

        # Scaled once for all channels; the x axis is labelled in ms
        time_ms = time_axis * 1000

        for col in signal_cols[:3]:  # Plot first 3 channels
            signal_data = signal_df[col].values

            # Downsample if needed
            time_plot, signal_plot, ds_factor = downsample_for_plot(time_ms, signal_data)

            if ds_factor > 1:
                print(f"  Note: Downsampling {col} by factor {ds_factor} for plotting")

            ax.plot(time_plot, signal_plot, label=col, alpha=0.7)

        ax.set_xlabel("Time [ms]")
        ax.set_ylabel("Amplitude [V]")
        ax.set_title(f"NAS Signals - Shot {shot_num} - {signal_name}")
        ax.legend()
        ax.grid(True, alpha=0.3)
        plt.tight_layout()

        print(f"✓ Created plot: NAS Signals")
else:
    print("⚠ Skipping NAS signal plot - no NAS data available")

# Test 2: Direct plt.ion() usage
print("\n" + "-" * 80)
print("Test 2: Direct plt.ion() usage")
print("-" * 80)

plt.ion()  # Ensure interactive mode

# Plot VEST data if available
# --- Figure 2 ---
if vest_data:
    # Only plot first sampling rate for brevity
    rate, df = next(iter(vest_data.items()))

    fig, ax = plt.subplots(figsize=(12, 6))

    # Index scaled by 1000 to match the "Time [ms]" axis label
    time_base = df.index.values * 1000

    for col in df.columns:
        signal_data = df[col].values

        # Downsample if needed
        time_plot, signal_plot, ds_factor = downsample_for_plot(time_base, signal_data)

        if ds_factor > 1:
            print(f"  Note: Downsampling {col} by factor {ds_factor} for plotting")

        ax.plot(time_plot, signal_plot, label=col, alpha=0.7)

    ax.set_xlabel("Time [ms]")
    ax.set_ylabel("Amplitude")
    ax.set_title(f"VEST Data - Shot {shot_num} - {rate} sampling rate")
    ax.legend()
    ax.grid(True, alpha=0.3)
    plt.tight_layout()

    print(f"✓ Created plot: VEST Data ({rate})")
else:
    # Report the skip explicitly, as the other plotting tests do
    print("⚠ Skipping VEST data plot - no VEST data available")

# Test 3: Combined NAS and VEST data visualization
# --- Figure 3 ---
print("\n" + "-" * 80)
print("Test 3: Combined NAS and VEST Data Visualization")
print("-" * 80)

if nas_signals and vest_data:
    # Two stacked panels sharing the time axis: NAS on top, VEST below
    fig, axes = plt.subplots(2, 1, figsize=(14, 10), sharex=True)

    # Plot NAS signals (top)
    signal_name = next(iter(nas_signals))
    signal_df = nas_signals[signal_name]

    if "TIME" in signal_df.columns:
        time_nas = signal_df["TIME"].values
        signal_cols = [col for col in signal_df.columns if col != "TIME"]
    else:
        time_nas = signal_df.index.values
        signal_cols = list(signal_df.columns)

    # Scaled once for all channels; the x axis is labelled in ms
    time_ms = time_nas * 1000

    for col in signal_cols[:2]:  # Plot first 2 channels
        signal_data = signal_df[col].values

        # Downsample if needed
        time_plot, signal_plot, ds_factor = downsample_for_plot(time_ms, signal_data)

        if ds_factor > 1:
            print(f"  Note: Downsampling NAS {col} by factor {ds_factor} for plotting")

        axes[0].plot(time_plot, signal_plot, label=f"NAS: {col}", alpha=0.7)

    axes[0].set_ylabel("Amplitude [V]")
    axes[0].set_title(f"Combined Analysis - Shot {shot_num}")
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)

    # Plot VEST data (bottom) - only the first sampling rate
    rate, df = next(iter(vest_data.items()))
    time_base = df.index.values * 1000
    for col in df.columns:
        signal_data = df[col].values

        # Downsample if needed
        time_plot, signal_plot, ds_factor = downsample_for_plot(time_base, signal_data)

        if ds_factor > 1:
            print(f"  Note: Downsampling VEST {col} by factor {ds_factor} for plotting")

        axes[1].plot(time_plot, signal_plot, label=f"VEST ({rate}): {col}", alpha=0.7)

    axes[1].set_xlabel("Time [ms]")
    axes[1].set_ylabel("Amplitude")
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)

    plt.tight_layout()
    print(f"✓ Created combined plot: NAS + VEST Data")
elif vest_data:
    print("⚠ Skipping combined plot - NAS data not available")
    print("   VEST data is available and can be plotted separately (see Test 2)")
elif nas_signals:
    print("⚠ Skipping combined plot - VEST data not available")
else:
    print("⚠ Skipping combined plot - neither NAS nor VEST data available")

# Test 4: Using Plotter class with interactive mode
# --- Figure 4 ---
print("\n" + "-" * 80)
print("Test 4: Using Plotter class")
print("-" * 80)

if nas_signals:
    plotter = plots.Plotter()
    signal_name = next(iter(nas_signals))
    signal_df = nas_signals[signal_name]

    # Downsample data if needed before passing to Plotter. The row count is
    # the same whether or not a TIME column exists, so one check is enough.
    max_plotter_points = 10000
    n_points = len(signal_df)
    if n_points > max_plotter_points:
        # Ceiling division keeps the result at or below the point budget
        ds_factor = max(1, -(-n_points // max_plotter_points))
        signal_df_plot = signal_df.iloc[::ds_factor].copy()
        print(f"  Note: Downsampling data by factor {ds_factor} for Plotter class")
    else:
        # Hand the Plotter a copy so the stored signal is never modified
        signal_df_plot = signal_df.copy()

    with plots.interactive_plotting(show_plots=True, block=False):
        fig, ax = plotter.plot_waveforms(
            signal_df_plot,
            title=f"Plotter Class - Shot {shot_num} - {signal_name}",
            show_plot=True,
        )
        print(f"✓ Created plot using Plotter class")
else:
    print("⚠ Skipping Plotter class test - no NAS data available")

print("\n" + "=" * 80)
print("Interactive Plotting Tests Complete")
print("=" * 80)
print(f"\nTotal figures created: {len(plt.get_fignums())}")
print("\nNote: All plots are in interactive mode.")
print("Close plot windows or run plt.close('all') to clear them.")


## 9. Summary and Evaluation


In [None]:
print("=" * 80)
print("Test Summary and Evaluation")
print("=" * 80)

print("\n1. VEST_DB Integration:")
if vest_data:
    print("   ✓ VEST_DB successfully loaded data from MySQL")
    print(f"   ✓ Loaded {len(vest_data)} sampling rate group(s)")
    total_fields = sum(len(df.columns) for df in vest_data.values())
    print(f"   ✓ Total fields loaded: {total_fields}")
else:
    print("   ⚠ No VEST data loaded")

# The result variables may be missing if their cells were never run, so they
# are looked up by name instead of being referenced directly.
print("\n2. main_analysis.py Functionality:")
if globals().get("analysis_results"):
    print("   ✓ run_analysis function executed successfully")
    print(f"   ✓ Processed {len(analysis_results)} shot(s)")
else:
    print("   ⚠ run_analysis did not return results")

print("\n3. interactive_analysis.py Functionality:")
if globals().get("interactive_results"):
    print("   ✓ create_mock_args function works correctly")
    print("   ✓ Interactive analysis pipeline executed successfully")
else:
    print("   ⚠ Interactive analysis did not return results")

print("\n4. Interactive Plotting (plt.ion):")
interactive_on = plt.isinteractive()
n_figures = len(plt.get_fignums())
print(f"   {'✓' if interactive_on else '⚠'} Interactive mode enabled: {interactive_on}")
print(f"   ✓ Backend: {matplotlib.get_backend()}")
print(f"   {'✓' if n_figures else '⚠'} Figures created: {n_figures}")
# The context-manager and Plotter tests only run when NAS data is available,
# so success is reported only in that case.
# NOTE(review): this still assumes the plotting cell ran without raising.
if nas_signals:
    print("   ✓ interactive_plotting context manager works")
    print("   ✓ Plotter class works with interactive mode")
else:
    print("   ⚠ interactive_plotting context manager not exercised (no NAS data)")
    print("   ⚠ Plotter class not exercised (no NAS data)")

print("\n5. Data Integration:")
if nas_signals and vest_data:
    print("   ✓ Successfully combined NAS and VEST data")
    print("   ✓ Combined visualization works correctly")
elif nas_signals:
    print("   ⚠ NAS data available but VEST data not loaded")
    print("   ✓ NAS data can be used independently")
elif vest_data:
    print("   ⚠ VEST data available but NAS data not loaded")
    print("   ✓ VEST data can be used independently")
    print("   Note: NAS data may be unavailable due to network/VPN/SSH tunnel issues")
else:
    print("   ⚠ Neither NAS nor VEST data available")
    print("   Note: For NAS data, ensure VPN/SSH tunnel is established")
    print("   Note: For VEST data, ensure database connection is available")

print("\n" + "=" * 80)
print("All Tests Complete")
print("=" * 80)

# Cleanup: Optionally close all figures
# Uncomment the line below to close all figures
# plt.close('all')
# plt.ioff()  # Turn off interactive mode
