# Stacked Analysis Plotting with Auto-Analysis

This notebook provides functionality to:
1. Check if results exist for a specified shot number
2. Run analysis automatically if required data (density, cwt, stft, etc.) is missing
3. Append new analysis results to existing HDF5 files
4. Create stacked plots with unified x-axis option

## Features
- **Auto-Analysis**: Automatically runs missing analysis components
- **Append Mode**: Adds new results to an existing HDF5 file; entries that are saved again are replaced, while everything else already in the file is kept
- **Stacked Plotting**: Stack multiple plot types (density, stft, cwt, signals) with shared x-axis
- **Flexible Configuration**: Select which plot types to display


In [None]:
import sys
from pathlib import Path

# Add project root to path
from ifi import IFI_ROOT
project_root = IFI_ROOT
sys.path.insert(0, str(project_root))

import numpy as np
import pandas as pd
import h5py
import matplotlib.pyplot as plt
from argparse import Namespace

from ifi.analysis.plots import Plotter
from ifi.utils.file_io import load_results_from_hdf5
from ifi.db_controller.nas_db import NAS_DB
from ifi.db_controller.vest_db import VEST_DB
from ifi.utils.common import LogManager
from ifi.analysis.main_analysis import run_analysis
from ifi.analysis.interactive_analysis import create_mock_args

# Initialize logging.
# NOTE(review): LogManager is constructed twice — once with the level and once
# to fetch the logger. This presumably relies on LogManager sharing state
# between instances (singleton-style); confirm against ifi.utils.common.
LogManager(level="INFO")
logger = LogManager().get_logger(__name__)

# Echo the resolved project root and the first sys.path entry (the project
# root was inserted at index 0 above) so a wrong environment is obvious early.
print(f"Project root: {project_root}")
print(f"Python path: {sys.path[0]}")


## Configuration: Shot Number and Required Data Types

Specify the shot number and which data types you need for plotting.


In [None]:
# --- User configuration ------------------------------------------------------
# Shot to process; edit this value to target a different shot.
SHOT_NUM = 45821

# Data types that must be present in the results before plotting.
# Recognised names: 'density', 'stft', 'cwt', 'signals', 'vest'.
REQUIRED_DATA_TYPES = ['density', 'stft', 'cwt']

# Settings consulted when missing data has to be generated.
#   density / stft / cwt : request that analysis whenever an analysis run happens
#   plot                 : show figures while the analysis runs
#   save_data            : persist results to HDF5 (the analysis cell below
#                          switches the built-in save off and appends the
#                          results itself)
#   save_plots           : write figures to disk while the analysis runs
#   scheduler            : parallel scheduler name handed to the analysis
ANALYSIS_OPTIONS = dict(
    density=True,
    stft=True,
    cwt=True,
    plot=False,
    save_data=True,
    save_plots=False,
    scheduler='threads',
)

print(f"Target shot number: {SHOT_NUM}")
print(f"Required data types: {REQUIRED_DATA_TYPES}")


## Step 1: Check Existing Results

Check if results exist for the specified shot number and which data types are available.


In [None]:
def check_existing_results(shot_num: int, base_dir: str = None) -> dict:
    """
    Report which data types are already stored for a shot.

    Every ``*.h5`` file directly inside ``<base_dir>/<shot_num>`` is opened
    read-only and inspected. A data type counts as available as soon as one
    file contains it.

    Args:
        shot_num: Shot number to check.
        base_dir: Base directory for results (default: ``<project_root>/ifi/results``).

    Returns:
        dict: ``'file_exists'`` (bool), ``'h5_files'`` (sorted list of file
        paths as strings) and one bool per data type: ``'density'``,
        ``'stft'``, ``'cwt'``, ``'signals'``, ``'vest'``.
    """
    if base_dir is None:
        base_dir = str(project_root / "ifi" / "results")

    results_dir = Path(base_dir) / str(shot_num)
    # glob() yields entries in filesystem order; sort so the reported file
    # list (and the order files are inspected in) is reproducible.
    h5_files = sorted(results_dir.glob("*.h5")) if results_dir.exists() else []

    availability = {
        'file_exists': len(h5_files) > 0,
        'h5_files': [str(f) for f in h5_files],
        'density': False,
        'stft': False,
        'cwt': False,
        'signals': False,
        'vest': False,
    }

    if not h5_files:
        return availability

    # (availability key, HDF5 group name) pairs that count as present when the
    # group exists and has at least one member.
    non_empty_groups = (
        ('density', 'density_data'),
        ('stft', 'stft_results'),
        ('cwt', 'cwt_results'),
        ('vest', 'vest_data'),
    )

    for h5_file in h5_files:
        try:
            with h5py.File(h5_file, "r") as f:
                for data_type, group_name in non_empty_groups:
                    if group_name in f and len(f[group_name].keys()) > 0:
                        availability[data_type] = True
                # The signals group counts unless it carries a truthy "empty" attribute.
                if "signals" in f and not f["signals"].attrs.get("empty", False):
                    availability['signals'] = True
        except Exception as e:
            # Best effort: an unreadable file must not hide the other files.
            logger.warning(f"Error checking {h5_file}: {e}")

    return availability

# Inspect what is already on disk for the configured shot and print a summary.
availability = check_existing_results(SHOT_NUM)

print(f"\nResults check for shot {SHOT_NUM}:")
print(f"  HDF5 file exists: {availability['file_exists']}")
if availability['file_exists']:
    print(f"  HDF5 files: {availability['h5_files']}")

# One line per known data type: a check mark or cross, plus a marker when the
# type is listed in REQUIRED_DATA_TYPES.
print("\nAvailable data types:")
for data_type in ('density', 'stft', 'cwt', 'signals', 'vest'):
    status = "✓" if availability[data_type] else "✗"
    required = " (REQUIRED)" if data_type in REQUIRED_DATA_TYPES else ""
    print("  " + status + " " + data_type + required)


In [None]:
def _clear_group(group, clear_attrs=False):
    """Delete every member of an HDF5 group and, optionally, its attributes."""
    for key in list(group.keys()):
        del group[key]
    if clear_attrs:
        for key in list(group.attrs.keys()):
            del group.attrs[key]


def _write_frame_columns(group, frame):
    """Store each DataFrame column as one dataset named after the column."""
    # NOTE(review): only column values are written; the DataFrame index is not
    # persisted here — confirm the loader does not need it to rebuild the time axis.
    for col in frame.columns:
        # Dataset names must be strings; str() keeps non-string labels usable.
        group.create_dataset(str(col), data=frame[col].values)


def _write_transform_results(root, group_name, results_by_signal):
    """Write per-signal transform results: arrays as datasets, scalars as attributes."""
    top_group = root.require_group(group_name)
    for signal_name, payload in results_by_signal.items():
        if not isinstance(payload, dict):
            continue
        target = top_group.require_group(signal_name)
        # Replace whatever an earlier run stored for this signal.
        _clear_group(target, clear_attrs=True)
        for key, value in payload.items():
            if isinstance(value, np.ndarray):
                target.create_dataset(key, data=value)
            elif isinstance(value, (int, float, str, np.generic)):
                # np.generic covers NumPy scalars such as np.int64 / np.bool_,
                # which are not instances of the built-in int / float.
                target.attrs[key] = value


def append_results_to_hdf5(
    output_dir: str,
    shot_num: int,
    signals: dict,
    stft_results: dict,
    cwt_results: dict,
    density_data: pd.DataFrame,
    vest_data: pd.DataFrame,
) -> str:
    """
    Merge analysis results into the shot's HDF5 file, creating it if needed.

    The file is opened in append mode. Groups that are not touched by this
    call are left as they are. Entries that are written again are replaced:
    a signal / STFT / CWT entry with the same name is cleared and rewritten,
    and the density and VEST groups are cleared before the new columns are
    stored.

    Args:
        output_dir: Directory that holds (or will hold) the HDF5 file.
        shot_num: Shot number; also the file stem. When it is 0 and
            ``signals`` is non-empty, the stem of the first signal key is
            used as the file stem instead.
        signals: Mapping of signal name to DataFrame; non-DataFrame values
            are skipped.
        stft_results: Mapping of signal name to a dict of STFT outputs.
        cwt_results: Mapping of signal name to a dict of CWT outputs.
        density_data: Density DataFrame; skipped when None or empty.
        vest_data: VEST DataFrame; skipped when None or empty.

    Returns:
        Path of the written file as a string, or None when writing failed
        (the error is logged, not raised).
    """
    from ifi.utils.common import ensure_dir_exists

    # Determine filename
    if shot_num == 0 and signals:
        filename = f"{Path(next(iter(signals))).stem}.h5"
    else:
        filename = f"{shot_num}.h5"

    filepath = Path(output_dir) / filename
    ensure_dir_exists(str(output_dir))

    try:
        # 'a' opens an existing file read/write or creates a new one.
        with h5py.File(filepath, "a") as hf:
            # One timestamp for both fields so a fresh file gets identical
            # created_at / updated_at values.
            now = pd.Timestamp.now().isoformat()
            metadata = hf.require_group("metadata")
            metadata.attrs["shot_number"] = shot_num
            metadata.attrs["updated_at"] = now
            if "created_at" not in metadata.attrs:
                metadata.attrs["created_at"] = now
            metadata.attrs["ifi_version"] = "1.0"

            if signals:
                signals_group = hf.require_group("signals")
                # Real data is about to be stored, so drop the "empty" marker.
                if "empty" in signals_group.attrs:
                    del signals_group.attrs["empty"]
                for signal_name, signal_data in signals.items():
                    if isinstance(signal_data, pd.DataFrame):
                        signal_group = signals_group.require_group(signal_name)
                        _clear_group(signal_group)
                        _write_frame_columns(signal_group, signal_data)

            if stft_results:
                _write_transform_results(hf, "stft_results", stft_results)

            if cwt_results:
                _write_transform_results(hf, "cwt_results", cwt_results)

            if density_data is not None and not density_data.empty:
                density_group = hf.require_group("density_data")
                _clear_group(density_group)
                _write_frame_columns(density_group, density_data)

            if vest_data is not None and not vest_data.empty:
                vest_group = hf.require_group("vest_data")
                _clear_group(vest_group)
                _write_frame_columns(vest_group, vest_data)

    except Exception as e:
        logger.error(f"Error appending results to HDF5: {e}")
        return None

    print(f"Results appended to: {filepath}")
    return str(filepath)

# Determine which data types need to be generated
missing_data_types = [dt for dt in REQUIRED_DATA_TYPES if not availability[dt]]

if missing_data_types:
    print(f"\nMissing data types: {missing_data_types}")
    print("Running analysis to generate missing data...")

    try:
        # Build the config path from the project root so the cell does not
        # depend on the kernel's current working directory.
        config_path = str(project_root / "ifi" / "config.ini")
        nas_db = NAS_DB(config_path=config_path)
        vest_db = VEST_DB(config_path=config_path)

        # Create args for analysis: a step runs when it is missing or when
        # ANALYSIS_OPTIONS asks for it.
        args = create_mock_args()
        args.query = [str(SHOT_NUM)]
        args.density = 'density' in missing_data_types or ANALYSIS_OPTIONS.get('density', False)
        args.stft = 'stft' in missing_data_types or ANALYSIS_OPTIONS.get('stft', False)
        args.cwt = 'cwt' in missing_data_types or ANALYSIS_OPTIONS.get('cwt', False)
        args.plot = ANALYSIS_OPTIONS.get('plot', False)
        args.save_data = False  # We'll handle saving manually with append
        args.save_plots = ANALYSIS_OPTIONS.get('save_plots', False)
        args.scheduler = ANALYSIS_OPTIONS.get('scheduler', 'threads')

        # Run analysis
        results = run_analysis(
            query=args.query,
            args=args,
            nas_db=nas_db,
            vest_db=vest_db,
        )

        # Extract results and append to HDF5
        if results and str(SHOT_NUM) in results:
            shot_results = results[str(SHOT_NUM)]
            analysis_bundle = shot_results.get('analysis_results', {})

            signals_dict = analysis_bundle.get('signals', {})
            stft_results = analysis_bundle.get('stft_results', {})
            cwt_results = analysis_bundle.get('cwt_results', {})

            # Density may arrive as a dict of DataFrames keyed by frequency.
            # Flatten it into one frame: the first frequency keeps its column
            # names, every other frequency is aligned to that index (nearest
            # sample) and gets "<freq>GHz_" prefixed columns.
            density_data = analysis_bundle.get('density_data', pd.DataFrame())
            if isinstance(density_data, dict):
                if density_data:
                    first_freq = next(iter(density_data))
                    combined_density = density_data[first_freq].copy()
                    for freq_key, freq_df in density_data.items():
                        if freq_key != first_freq and not freq_df.empty:
                            freq_df_reindexed = freq_df.reindex(
                                combined_density.index, method="nearest", limit=1
                            )
                            for col in freq_df_reindexed.columns:
                                combined_density[f"{freq_key}GHz_{col}"] = freq_df_reindexed[col]
                    density_data = combined_density
                else:
                    density_data = pd.DataFrame()

            vest_data = analysis_bundle.get('vest_data', pd.DataFrame())

            # Append to HDF5
            output_dir = str(project_root / "ifi" / "results" / str(SHOT_NUM))
            saved_path = append_results_to_hdf5(
                output_dir,
                SHOT_NUM,
                signals_dict,
                stft_results,
                cwt_results,
                density_data,
                vest_data,
            )

            # append_results_to_hdf5 returns None when writing failed; do not
            # claim success in that case.
            if saved_path is None:
                print("\nWarning: Analysis completed but results could not be written to the HDF5 file.")
            else:
                print("\nAnalysis completed and results appended to HDF5 file.")

            # Refresh availability check
            availability = check_existing_results(SHOT_NUM)
            still_missing = [dt for dt in REQUIRED_DATA_TYPES if not availability[dt]]
            if still_missing:
                print(f"Warning: still missing after analysis: {still_missing}")
        else:
            print("\nWarning: Analysis did not return expected results.")

    except Exception as e:
        logger.error(f"Failed to run analysis: {e}")
        raise
else:
    print("\nAll required data types are available. Skipping analysis.")


In [None]:
# Read the stored results for the configured shot back from disk.
base_dir = str(project_root / "ifi" / "results")
results = load_results_from_hdf5(SHOT_NUM, base_dir=base_dir)

if not results:
    # Nothing on disk: fall back to empty placeholders so later cells can
    # still reference every variable.
    print(f"\nNo results found for shot {SHOT_NUM}")
    density_data = vest_data = None
    stft_results, cwt_results, signals = {}, {}, {}
else:
    print(f"\nLoaded results for shot {SHOT_NUM}:")
    print(f"  Available keys: {list(results.keys())}")

    # Pull each data type out of the loaded dictionary.
    density_data = results.get('density_data')
    stft_results = results.get('stft_results', {})
    cwt_results = results.get('cwt_results', {})
    signals = results.get('signals', {})
    vest_data = results.get('vest_data')

    # Short summary of what was found.
    if density_data is not None:
        print(f"  Density data: shape {density_data.shape}, columns {list(density_data.columns)[:3]}...")
    if stft_results:
        print(f"  STFT results: {list(stft_results.keys())}")
    if cwt_results:
        print(f"  CWT results: {list(cwt_results.keys())}")
    if signals:
        print(f"  Signals: {list(signals.keys())}")
    if vest_data is not None:
        print(f"  VEST data: shape {vest_data.shape}")


## Step 4: Create Stacked Plots

Create stacked plots with optional unified x-axis. Select which plot types to display.


In [None]:
# Stacked-plot settings.
#   plot_density  : include the density panel
#   plot_stft     : include the STFT spectrogram panel
#   plot_cwt      : include the CWT spectrogram panel
#   plot_signals  : include the raw-signal panel
#   plot_vest     : include the VEST panel
#   unified_xaxis : share one x-axis across all panels
#   figsize       : figure size passed to matplotlib
PLOT_CONFIG = dict(
    plot_density=True,
    plot_stft=True,
    plot_cwt=False,
    plot_signals=False,
    plot_vest=False,
    unified_xaxis=True,
    figsize=(14, 10),
)

def _first_available(mapping, *keys):
    """Return the first non-None value stored in ``mapping`` under any of ``keys``."""
    for key in keys:
        value = mapping.get(key)
        if value is not None:
            return value
    return None


def _common_time_axis(results):
    """Pick a shared time vector: the density index, else the first signal's TIME column."""
    density_df = results.get('density_data')
    if density_df is not None and hasattr(density_df, 'index') and len(density_df.index) > 0:
        return density_df.index.values
    signals = results.get('signals')
    if signals:
        first_signal = next(iter(signals.values()))
        if 'TIME' in first_signal.columns:
            return first_signal['TIME'].values
    return None


def _draw_lines(ax, x, frame, columns, title, ylabel):
    """Plot the given columns of ``frame`` against ``x`` with title, label, legend and grid."""
    columns = list(columns)
    for col in columns:
        ax.plot(x, frame[col].values, label=col)
    ax.set_title(title)
    ax.set_ylabel(ylabel)
    if columns:  # legend() warns when there is nothing to list
        ax.legend()
    ax.grid(True, alpha=0.3)


def _draw_spectrogram(fig, ax, entries, common_time, time_keys, y_keys, matrix_keys,
                      title_prefix, ylabel):
    """Draw the magnitude of the first entry in ``entries`` as a colour mesh on ``ax``."""
    name = next(iter(entries))
    data = entries[name]

    time_axis = _first_available(data, *time_keys)
    y_axis = _first_available(data, *y_keys)
    matrix = _first_available(data, *matrix_keys)

    # No stored time vector: borrow the shared one, cut to the matrix width.
    if time_axis is None and matrix is not None and common_time is not None:
        time_axis = common_time[:matrix.shape[1]]

    if matrix is None or time_axis is None or y_axis is None:
        # Say why the panel is blank instead of leaving it silently empty.
        print(f"{title_prefix} ({name}): time, y-axis or matrix data missing; panel left empty.")
        return

    mesh = ax.pcolormesh(
        time_axis, y_axis, np.abs(matrix),
        shading='gouraud', cmap='plasma'
    )
    ax.set_title(f"{title_prefix} ({name})")
    ax.set_ylabel(ylabel)
    # Attach the colourbar through the owning figure rather than pyplot's
    # "current figure" state.
    fig.colorbar(mesh, ax=ax, label="Magnitude")
    ax.grid(True, alpha=0.3)


def create_stacked_plots(
    results: dict,
    plot_config: dict,
    shot_num: int,
) -> tuple:
    """
    Create vertically stacked plots with an optional shared x-axis.

    A panel is drawn for each data type that is both enabled in
    ``plot_config`` and present in ``results``, in the fixed order
    density, STFT, CWT, signals, VEST. For STFT, CWT and signals only the
    first entry is drawn; for VEST only the first column.

    Args:
        results: Dictionary containing loaded results (keys read:
            'density_data', 'stft_results', 'cwt_results', 'signals',
            'vest_data').
        plot_config: Configuration dictionary for plotting (keys read:
            'plot_density', 'plot_stft', 'plot_cwt', 'plot_signals',
            'plot_vest', 'unified_xaxis', 'figsize').
        shot_num: Shot number for the panel titles.

    Returns:
        tuple: ``(figure, axes)`` where ``axes`` is a 1-D array with one
        Axes per panel, or ``(None, None)`` when there is nothing to plot.
    """
    # (panel name, config flag, data present?) in drawing order.
    candidates = (
        ('density', 'plot_density', results.get('density_data') is not None),
        ('stft', 'plot_stft', bool(results.get('stft_results'))),
        ('cwt', 'plot_cwt', bool(results.get('cwt_results'))),
        ('signals', 'plot_signals', bool(results.get('signals'))),
        ('vest', 'plot_vest', results.get('vest_data') is not None),
    )
    plot_list = [
        name for name, flag, present in candidates
        if plot_config.get(flag, False) and present
    ]

    if not plot_list:
        print("No plots to create based on configuration and available data.")
        return None, None

    # NOTE(review): Plotter is instantiated only for any global matplotlib
    # styling its constructor may apply; none of its methods are used here.
    # Confirm and remove this line if the constructor has no side effects.
    Plotter()

    n_plots = len(plot_list)
    unified = bool(plot_config.get('unified_xaxis', True))
    figsize = plot_config.get('figsize', (14, 4 * n_plots))

    # squeeze=False always yields a 2-D array, so a single panel needs no
    # special case; take the only column to get a 1-D array of Axes.
    fig, axes_grid = plt.subplots(
        n_plots, 1, figsize=figsize, sharex=unified, squeeze=False
    )
    axes = axes_grid[:, 0]

    common_time = _common_time_axis(results) if unified else None

    # Pairing panels with axes directly avoids manual index bookkeeping.
    for kind, ax in zip(plot_list, axes):
        if kind == 'density':
            density_data = results['density_data']
            _draw_lines(
                ax, density_data.index.values, density_data, density_data.columns,
                f"Shot {shot_num}: Density", "Density [m^-3]",
            )
        elif kind == 'stft':
            _draw_spectrogram(
                fig, ax, results['stft_results'], common_time,
                ('time', 'time_STFT'), ('freq', 'freq_STFT'),
                ('stft_matrix', 'STFT_matrix'),
                f"Shot {shot_num}: STFT", "Frequency [Hz]",
            )
        elif kind == 'cwt':
            _draw_spectrogram(
                fig, ax, results['cwt_results'], common_time,
                ('time', 'time_CWT'), ('scales', 'freq_CWT'),
                ('cwt_matrix', 'CWT_matrix'),
                f"Shot {shot_num}: CWT", "Scale",
            )
        elif kind == 'signals':
            signals = results['signals']
            first_signal_key = next(iter(signals))
            signal_df = signals[first_signal_key]
            if 'TIME' in signal_df.columns:
                value_columns = [c for c in signal_df.columns if c != 'TIME']
                _draw_lines(
                    ax, signal_df['TIME'].values, signal_df, value_columns,
                    f"Shot {shot_num}: Signals ({first_signal_key})", "Amplitude [V]",
                )
            else:
                print(f"Signal '{first_signal_key}' has no TIME column; panel left empty.")
        elif kind == 'vest':
            vest_data = results['vest_data']
            if len(vest_data.columns) > 0:
                first_col = vest_data.columns[0]
                _draw_lines(
                    ax, vest_data.index.values, vest_data, [first_col],
                    f"Shot {shot_num}: VEST Data", first_col,
                )
            else:
                print("VEST data has no columns; panel left empty.")

    # Shared axis: label the bottom panel only. Independent axes: label each.
    for ax in (axes[-1:] if unified else axes):
        ax.set_xlabel("Time [s]")

    fig.tight_layout()

    return fig, axes

# Build and display the stacked figure, reporting why nothing is shown when
# no figure is produced.
if not results:
    print("\nNo results available for plotting.")
else:
    fig, axes = create_stacked_plots(results, PLOT_CONFIG, SHOT_NUM)
    if fig is None:
        print("\nCould not create plots. Check data availability and plot configuration.")
    else:
        plt.show()
        print("\nStacked plots created successfully.")


## Summary

This notebook provides:
1. ✓ Automatic checking of existing results for specified shot number
2. ✓ Automatic analysis execution for missing data types
3. ✓ Append mode for adding new results to existing HDF5 files
4. ✓ Stacked plotting with unified x-axis option

### Usage Tips
- Modify `SHOT_NUM` to analyze different shots
- Adjust `REQUIRED_DATA_TYPES` to specify which analyses are needed
- Configure `PLOT_CONFIG` to select which plots to display
- Set `'unified_xaxis': True` in `PLOT_CONFIG` to align all plots on the same time axis
