In [None]:
#!/usr/bin/env python3
"""
simple_rqa_processor.py

A simplified version of the RQA analysis pipeline for use in Jupyter notebooks.
Processes a single EEG file, single channel, and single window size.
"""

import numpy as np
import matplotlib.pyplot as plt
from typing import Dict, Any, Tuple, List, Optional
import time
import gc

# Import the RQA processing function
# Make sure you have this module available in your environment
from pyddeeg.signal_processing.rqa_toolbox.rqe_parallelizable import process_single_channel_band

# Define EEG channel names for reference
# Channel labels, in order. Position i in this list is used as the index
# into axis 1 (the channel axis) of the data arrays processed below; make
# sure this order matches the recording montage of your dataset.
EEG_CHANNELS = (
    "Fp1 Fp2 F7 F3 Fz F4 F8 FC5 FC1 FC2 FC6 "
    "T7 C3 C4 T8 TP9 CP5 CP1 CP2 CP6 TP10 "
    "P7 P3 Pz P4 P8 PO9 O1 Oz O2 PO10 Cz"
).split()

def get_channel_index(channel_name: str, channels: Optional[List[str]] = None) -> int:
    """
    Return the zero-based position of a channel label.

    Parameters:
    -----------
    channel_name : str
        Channel label to look up (exact, case-sensitive match).
    channels : list of str, optional
        Labels to search. Defaults to the module-level ``EEG_CHANNELS``.

    Returns:
    --------
    int
        Index of ``channel_name`` within ``channels``.

    Raises:
    -------
    ValueError
        If ``channel_name`` is not present; the message lists the valid labels.
    """
    if channels is None:
        channels = EEG_CHANNELS
    try:
        return channels.index(channel_name)
    except ValueError:
        # ``from None`` drops the uninformative "'x' is not in list" context
        # so the user only sees the actionable message below.
        raise ValueError(
            f"Channel '{channel_name}' not found. Available channels: {', '.join(channels)}"
        ) from None

def process_eeg_data(
    data: np.ndarray,
    channel_name: str = "Pz",
    band_idx: int = 0,
    window_size: int = 100,
    metric: str = "RR"
) -> np.ndarray:
    """
    Compute one RQA metric over sliding windows for a single channel and band.

    Each patient's 1-D signal is passed to ``process_single_channel_band``
    with a window of ``window_size`` and a stride of ``window_size // 2``
    (i.e. 50% overlap between consecutive windows).

    Parameters:
    -----------
    data : np.ndarray
        EEG data with shape (n_patients, n_channels, n_samples, n_bands).
        Axis 1 is indexed by the channel's position in ``EEG_CHANNELS``.
    channel_name : str
        Name of the channel to process (resolved with ``get_channel_index``).
    band_idx : int
        Index into the last (band) axis of ``data``. Negative values count
        from the end, as in NumPy indexing.
    window_size : int
        Window length for the RQA computation. It is used directly as a
        number of samples along axis 2; it equals milliseconds only if the
        data is sampled at 1 kHz (TODO confirm the sampling rate). Must be
        at least 2 (so the stride is positive) and at most n_samples.
    metric : str
        RQA metric to compute (e.g., "RR", "DET", "ENT", "TT"); forwarded
        to the backend as the only entry of ``metrics_to_use``.

    Returns:
    --------
    np.ndarray
        Float array of shape (n_patients, n_windows), where
        n_windows = (n_samples - window_size) // (window_size // 2) + 1.

    Raises:
    -------
    ValueError
        If ``data`` is not 4-D, the channel is unknown or lies beyond the
        data's channel axis, ``band_idx`` is out of range, ``window_size``
        is < 2 or longer than the signal, or the backend returns a result
        that does not have exactly n_windows rows.
    """
    print(f"Processing channel {channel_name} (band {band_idx}) with window size {window_size}ms")

    # Resolve the channel name first (raises ValueError for unknown names)
    channel_idx = get_channel_index(channel_name)

    data = np.asarray(data)
    if data.ndim != 4:
        raise ValueError(
            "Expected 4-D data (n_patients, n_channels, n_samples, n_bands), "
            f"got shape {data.shape}"
        )
    n_patients, n_channels, n_samples, n_bands = data.shape

    # EEG_CHANNELS may list more channels than this dataset actually has.
    if channel_idx >= n_channels:
        raise ValueError(
            f"Channel '{channel_name}' maps to index {channel_idx}, "
            f"but data has only {n_channels} channels"
        )

    # Accept NumPy-style negative indices; anything outside [-n_bands, n_bands)
    # would otherwise surface later as an opaque IndexError.
    if not -n_bands <= band_idx < n_bands:
        raise ValueError(f"Band index {band_idx} out of range (only {n_bands} bands available)")

    # The stride is window_size // 2, so window_size < 2 would give a zero
    # stride (ZeroDivisionError in the window count below).
    if window_size < 2:
        raise ValueError(f"window_size must be at least 2 samples, got {window_size}")
    # A window longer than the signal yields no complete window.
    if window_size > n_samples:
        raise ValueError(
            f"window_size ({window_size}) exceeds the number of samples ({n_samples})"
        )

    # Extract band data for the specific channel -> (n_patients, n_samples)
    band_data = data[:, channel_idx, :, band_idx]

    # Basic RQA parameters
    rqa_params = {
        "raw_signal_window_size": window_size,  # Window length (samples)
        "stride": window_size // 2,             # Half the window size
        "embedding_dimension": 3,               # Dimension for phase space reconstruction
        "time_delay": 1,                        # Delay for phase space reconstruction
        "radius": 0.2,                          # Threshold for recurrence
        "metrics_to_use": [metric]              # Only compute one metric
    }

    # Number of full windows that fit in the signal with this stride
    stride = rqa_params["stride"]
    num_windows = (n_samples - window_size) // stride + 1

    print(f"Will process {n_patients} patients with {num_windows} windows per patient")

    results = np.zeros((n_patients, num_windows))

    start_time = time.perf_counter()
    for patient_idx in range(n_patients):
        # Contiguous copy of this patient's 1-D signal (slice is a view)
        signal = band_data[patient_idx, :].copy()

        rqa_result, _, _ = process_single_channel_band(
            signal=signal,
            rqa_params=rqa_params,
            normalize_metrics=True,
            return_rqe=False
        )
        rqa_result = np.asarray(rqa_result)

        # Expected layout: (n_windows, n_metrics); column 0 is the single
        # requested metric. Check the row count explicitly, because a 1-row
        # result would otherwise be broadcast silently across every window.
        if rqa_result.ndim != 2 or rqa_result.shape[0] != num_windows:
            raise ValueError(
                f"RQA backend returned shape {rqa_result.shape} for patient "
                f"{patient_idx}; expected ({num_windows}, n_metrics)"
            )
        results[patient_idx, :] = rqa_result[:, 0]

        # Release per-patient intermediates before the next patient
        # (presumably to keep peak memory down for large backends).
        gc.collect()

        # Print progress every 5 patients and after the last one
        if (patient_idx + 1) % 5 == 0 or patient_idx == n_patients - 1:
            elapsed = time.perf_counter() - start_time
            print(f"Processed {patient_idx + 1}/{n_patients} patients in {elapsed:.2f}s")

    return results

def plot_rqa_results(results: np.ndarray, title: str = "RQA Results", max_patients: int = 10):
    """
    Plot per-window RQA values: the across-patient mean plus individual traces.

    Parameters:
    -----------
    results : np.ndarray
        Array of RQA metrics with shape (n_patients, n_windows). A 1-D array
        is treated as a single patient.
    title : str
        Title for the plot
    max_patients : int
        Maximum number of individual patient traces to draw (the first
        ``max_patients`` rows). The mean always uses every row. Default 10.
    """
    # Promote 1-D input to (1, n_windows) so row indexing below works.
    results = np.atleast_2d(results)
    n_shown = min(results.shape[0], max(max_patients, 0))

    plt.figure(figsize=(12, 8))

    # NaN-aware mean over all patients; zorder=3 keeps it above the
    # individual traces (default line zorder is 2).
    mean_values = np.nanmean(results, axis=0)
    plt.plot(mean_values, 'k-', label='Mean', linewidth=2, zorder=3)

    # Individual patients, drawn faintly
    for row in results[:n_shown]:
        plt.plot(row, alpha=0.3)

    plt.title(title)
    plt.xlabel('Window Index')
    plt.ylabel('RQA Metric Value')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()

def load_and_analyze_eeg(
    filepath: str,
    channel_name: str = "T7",
    band_idx: int = 4,
    window_size: int = 100,
    metric: str = "RR"
) -> np.ndarray:
    """
    Load EEG data from an .npz archive, run RQA analysis and plot the result.

    Parameters:
    -----------
    filepath : str
        Path to the .npz file; it must contain an array stored under the key
        "data" with shape (n_patients, n_channels, n_samples, n_bands).
    channel_name : str
        Name of the channel to process
    band_idx : int
        Index of the frequency band (last axis of the data) to analyze
    window_size : int
        Window length for RQA computation, in samples (see
        ``process_eeg_data``)
    metric : str
        RQA metric to compute

    Returns:
    --------
    np.ndarray
        Array of RQA metrics with shape (n_patients, n_windows)

    Raises:
    -------
    KeyError
        If the archive has no "data" array.
    """
    print(f"Loading data from {filepath}")
    # np.load on an .npz returns an NpzFile that keeps the file handle open;
    # the context manager closes it once the array has been read into memory.
    # allow_pickle stays at its safe default (False).
    with np.load(filepath) as archive:
        if "data" not in archive.files:
            raise KeyError(
                f"'data' array not found in {filepath}; "
                f"available arrays: {', '.join(archive.files)}"
            )
        data = archive["data"]
    print(f"Data loaded with shape {data.shape}")

    results = process_eeg_data(
        data=data,
        channel_name=channel_name,
        band_idx=band_idx,
        window_size=window_size,
        metric=metric
    )

    plot_rqa_results(
        results,
        title=f"RQA {metric} for channel {channel_name}, band {band_idx}, window {window_size}ms"
    )

    return results

# Run the example analysis only when executed directly. Jupyter/IPython
# executes cells with __name__ == "__main__", so notebook behavior is the
# same, but importing this file as a module no longer triggers file I/O.
if __name__ == "__main__":
    # Set parameters
    file_path = './data/my_eeg_dataset.npz'  # Path to your data file (needs a "data" array)
    channel = 'Pz'                          # Channel to analyze (must be in EEG_CHANNELS)
    band = 0                                # Band index; band order depends on how the dataset was built
    window_size = 100                       # Samples per window (equals ms only at 1 kHz sampling)
    metric = 'RR'                           # RQA metric (RR = Recurrence Rate)

    # Run the analysis
    results = load_and_analyze_eeg(
        filepath=file_path,
        channel_name=channel,
        band_idx=band,
        window_size=window_size,
        metric=metric
    )

    # Save the results in the current working directory
    np.save(f'rqa_{metric}_{channel}_band{band}_window{window_size}.npy', results)
