In [30]:
"""
Phase 1: Preprocessing Pipeline for PhysioNet Motor Movement/Imagery Database
Focuses on Task 1 (left hand) and Task 2 (right hand) motor imagery in beta band.
"""

import numpy as np
import mne
from mne.io import read_raw_edf
from scipy import signal
from scipy.signal import butter, filtfilt, iirnotch
import os
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')


class PhysioNetPreprocessor:
    """
    Preprocessing pipeline for PhysioNet Motor Imagery Database.
    
    Pipeline Steps:
    1. Channel Selection (9 sensorimotor channels)
    2. Temporal Filtering (13-30 Hz beta band - PD primary biomarker)
    3. Spatial Filtering (Common Average Reference)
    4. Artifact Removal (±100 μV threshold)
    5. Epoching (2-second windows before movement onset)
    6. Baseline Correction
    """
    
    def __init__(self, data_dir=r"C:\Users\fibof\Downloads\physionet"):
        """
        Initialize preprocessor.
        
        Parameters:
        -----------
        data_dir : str
            Path to PhysioNet dataset directory
        """
        self.data_dir = Path(data_dir)
        
        # Channel configuration - sensorimotor channels over motor cortex
        # PhysioNet uses standard 10-20 system with these available channels
        # We'll select C3, Cz, C4 (primary motor cortex) and surrounding channels
        self.sensorimotor_channels = ['FC3..', 'FC1..', 'FCz..', 'FC2..', 'FC4..',
                                      'C5..', 'C3..', 'C1..', 'Cz..', 'C2..', 'C4..', 'C6..',
                                      'CP3..', 'CP1..', 'CPz..', 'CP2..', 'CP4..']
        # We'll let MNE find which ones exist and select the best available
        
        # Beta band filtering (13-30 Hz) - PD primary biomarker frequency
        self.sfreq = 160  # PhysioNet sampling frequency
        self.lowcut = 0.05   # Beta band lower bound
        self.highcut = 45.0  # Beta band upper bound
        self.notch_freq = 60.0  # Remove electrical line noise
        
        # Epoching parameters
        self.epoch_tmin = -2.0  # 2 seconds before movement onset
        self.epoch_tmax = 0.0   # Movement onset
        self.baseline_tmin = -2.0  # Baseline must be within epoch window
        self.baseline_tmax = -1.5  # Use first 0.5 seconds as baseline
        
        # Artifact rejection threshold
        self.reject_threshold = 100e-6  # 100 μV
        
    def get_subject_files(self, subject_id):
        """
        Get file paths for a specific subject's task 1 and task 2.
        
        Task 1 (Motor Imagery):
        - Runs 3, 7, 11: Left vs Right hand motor imagery
        
        Task 2 (Motor Execution):
        - Runs 4, 8, 12: Left vs Right hand actual movement
        
        Parameters:
        -----------
        subject_id : int
            Subject ID (1-50)
            
        Returns:
        --------
        files : list
            List of (edf_path, event_path) tuples
        """
        subject_str = f'S{subject_id:03d}'
        
        # Task 1: Motor imagery runs (3, 7, 11)
        # Task 2: Motor execution runs (4, 8, 12)
        runs = [3, 7, 11, 4, 8, 12]  # Both tasks
        
        files = []
        for run in runs:
            edf_file = self.data_dir / f'{subject_str}R{run:02d}.edf'
            event_file = self.data_dir / f'{subject_str}R{run:02d}.edf.event'
            
            if edf_file.exists() and event_file.exists():
                files.append((edf_file, event_file))
            else:
                print(f"Warning: Missing files for {subject_str}R{run:02d}")
        
        return files
    
    def load_subject_data(self, subject_id):
        """
        Load EEG data for a specific subject (tasks 1 and 2).
        
        Task 1: Motor imagery (runs 3, 7, 11)
        Task 2: Motor execution (runs 4, 8, 12)
        
        Parameters:
        -----------
        subject_id : int
            Subject ID (1-50)
            
        Returns:
        --------
        raw_list : list
            List of (raw, event_file_path) tuples
        """
        files = self.get_subject_files(subject_id)
        
        if not files:
            raise FileNotFoundError(f"No data found for subject {subject_id}")
        
        raw_list = []
        for edf_path, event_path in files:
            # Load EDF file
            raw = read_raw_edf(edf_path, preload=True, verbose=False)
            raw_list.append((raw, event_path))
        
        return raw_list
    
    def select_channels(self, raw):
        """
        Select sensorimotor channels covering motor cortex.
        
        PhysioNet dataset has 64 channels with names like "FC3..", "C3..", etc.
        We want channels over the motor cortex (C3, Cz, C4 and surrounding).
        
        Parameters:
        -----------
        raw : mne.io.Raw
            Raw EEG data
            
        Returns:
        --------
        raw : mne.io.Raw
            Data with selected channels only
        """
        # Get all available channel names
        available_channels = raw.ch_names
        print(f"Available channels: {available_channels[:10]}... (showing first 10)")
        
        # Define desired sensorimotor channels (in priority order)
        # Primary motor cortex and surrounding areas
        desired_channels = [
            'C3..', 'Cz..', 'C4..',      # Primary motor cortex (most important)
            'C1..', 'C2..', 'C5..', 'C6..',  # Additional central channels
            'FC3..', 'FCz..', 'FC4..',   # Frontal-central
            'FC1..', 'FC2..',
            'CP3..', 'CPz..', 'CP4..',   # Central-parietal
            'CP1..', 'CP2..'
        ]
        
        # Find which desired channels are actually available
        selected = []
        for ch in desired_channels:
            if ch in available_channels:
                selected.append(ch)
        
        # If we have fewer than 9 channels, add more central/motor channels
        if len(selected) < 9:
            # Add any other C or CP channels we might have missed
            for ch in available_channels:
                if (ch.startswith('C') or ch.startswith('FC') or ch.startswith('CP')) and ch not in selected:
                    selected.append(ch)
                    if len(selected) >= 9:
                        break
        
        # Take the first 9 channels (or all if less than 9)
        selected = selected[:9]
        
        if len(selected) == 0:
            raise ValueError("No sensorimotor channels found in the data!")
        
        # Select these channels
        raw.pick_channels(selected, ordered=True)
        
        print(f"Selected {len(raw.ch_names)} channels: {raw.ch_names}")
        return raw
    
    def apply_temporal_filtering(self, raw):
        """
        Apply beta band filter (13-30 Hz) and notch filter (60 Hz).
        
        Beta band is the primary biomarker for Parkinson's disease.
        
        Parameters:
        -----------
        raw : mne.io.Raw
            Raw EEG data
            
        Returns:
        --------
        raw : mne.io.Raw
            Filtered data
        """
        # Bandpass filter: 13-30 Hz (beta band - PD primary biomarker)
        raw.filter(self.lowcut, self.highcut, fir_design='firwin', 
                   verbose=False)
        
        # Notch filter: 60 Hz (remove electrical line noise)
        raw.notch_filter(self.notch_freq, verbose=False)
        
        print(f"Applied beta band filter: {self.lowcut}-{self.highcut} Hz")
        print(f"Applied notch filter: {self.notch_freq} Hz")
        
        return raw
    
    def apply_spatial_filtering(self, raw):
        """
        Apply Common Average Reference (CAR) to reduce common noise.
        
        Parameters:
        -----------
        raw : mne.io.Raw
            Raw EEG data
            
        Returns:
        --------
        raw : mne.io.Raw
            Spatially filtered data
        """
        raw.set_eeg_reference('average', projection=False, verbose=False)
        print("Applied Common Average Reference (CAR)")
        return raw
    
    def read_event_file(self, event_file_path):
        """
        Read events from .edf.event file.
        
        PhysioNet .edf.event files are in EDF+ Annotations format (binary).
        We need to parse the annotations from the EDF file itself.
        
        Parameters:
        -----------
        event_file_path : Path
            Path to .edf.event file (actually the EDF file with annotations)
            
        Returns:
        --------
        events_list : list
            List of (time, duration, event_type) tuples
        """
        events_list = []
        
        try:
            # The .edf.event file is actually an EDF file with annotations
            # Let's read it using MNE
            raw_event = read_raw_edf(str(event_file_path), preload=False, verbose=False)
            
            # Extract annotations
            annotations = raw_event.annotations
            
            for annot in annotations:
                onset = annot['onset']
                duration = annot['duration']
                description = annot['description']
                
                # The description should be T0, T1, or T2
                if description in ['T0', 'T1', 'T2']:
                    events_list.append((onset, duration, description))
            
        except Exception as e:
            print(f"Warning: Could not read event file as EDF: {e}")
            print("Attempting to parse as text file...")
            
            # Fallback: try to read as text
            try:
                with open(event_file_path, 'r', encoding='latin-1', errors='ignore') as f:
                    content = f.read()
                    # Look for patterns like "T0", "T1", "T2"
                    import re
                    # This is a simplified parser - may need adjustment
                    lines = content.split('\n')
                    for line in lines:
                        if 'T0' in line or 'T1' in line or 'T2' in line:
                            # Try to extract the event type
                            if 'T0' in line:
                                events_list.append((0, 0, 'T0'))
                            elif 'T1' in line:
                                events_list.append((0, 0, 'T1'))
                            elif 'T2' in line:
                                events_list.append((0, 0, 'T2'))
            except:
                pass
        
        return events_list
    
    def events_to_mne_format(self, events_list, sfreq):
        """
        Convert event list to MNE events array format.
        
        Parameters:
        -----------
        events_list : list
            List of (time, duration, event_type) tuples
        sfreq : float
            Sampling frequency
            
        Returns:
        --------
        events : np.ndarray
            MNE events array (n_events, 3) with [sample, 0, event_id]
        event_id : dict
            Mapping of event names to IDs
        """
        event_id = {
            'rest': 1,
            'left': 2,
            'right': 3
        }
        
        # Map event types to IDs
        event_type_to_id = {
            'T0': 1,  # rest
            'T1': 2,  # left
            'T2': 3   # right
        }
        
        events = []
        for time_offset, duration, event_type in events_list:
            # Convert time to sample number
            sample = int(time_offset * sfreq)
            
            # Get event ID
            if event_type in event_type_to_id:
                event_code = event_type_to_id[event_type]
                events.append([sample, 0, event_code])
        
        return np.array(events, dtype=int), event_id
    def extract_epochs(self, raw, event_file_path):
        """
        Extract 2-second epochs before movement onset.
        
        Try multiple methods to extract events:
        1. From annotations in the main EDF file
        2. From the separate .edf.event file
        
        PhysioNet event codes:
        - T0: Rest (baseline)
        - T1: Left fist motor imagery/execution
        - T2: Right fist motor imagery/execution
        
        Parameters:
        -----------
        raw : mne.io.Raw
            Raw EEG data
        event_file_path : Path
            Path to corresponding .edf.event file
            
        Returns:
        --------
        epochs : mne.Epochs
            Epoched data with labels
        """
        # Method 1: Try to get events from annotations in the main EDF file
        print(f"Attempting to extract events from EDF annotations...")
        
        if raw.annotations is not None and len(raw.annotations) > 0:
            print(f"  Found {len(raw.annotations)} annotations in EDF file")
            events, event_id = mne.events_from_annotations(raw, verbose=False)
            
            # Map the event IDs to our labels
            # MNE creates numeric IDs, we need to map them back
            new_event_id = {}
            event_id_rev = {v: k for k, v in event_id.items()}
            
            for eid, name in event_id_rev.items():
                if 'T0' in name:
                    new_event_id['rest'] = eid
                elif 'T1' in name:
                    new_event_id['left'] = eid
                elif 'T2' in name:
                    new_event_id['right'] = eid
            
            event_id = new_event_id
            
            if len(event_id) > 0:
                print(f"  Successfully extracted events from annotations")
                print(f"  Event mapping: {event_id}")
            else:
                print(f"  No T0/T1/T2 events found in annotations")
                events = None
        else:
            print(f"  No annotations found in EDF file")
            events = None
        
        # Method 2: If no events from annotations, try the event file
        if events is None or len(events) == 0:
            print(f"Reading events from separate event file: {event_file_path.name}")
            events_list = self.read_event_file(event_file_path)
            
            if len(events_list) > 0:
                events, event_id = self.events_to_mne_format(events_list, raw.info['sfreq'])
                print(f"  Successfully read {len(events)} events from event file")
            else:
                raise ValueError("Could not extract events from either EDF annotations or event file")
        
        print(f"Found {len(events)} events")
        print(f"Event mapping: {event_id}")
        
        # Count events by type
        unique_events, counts = np.unique(events[:, 2], return_counts=True)
        event_id_rev = {v: k for k, v in event_id.items()}
        for event_code, count in zip(unique_events, counts):
            event_name = event_id_rev.get(event_code, f'unknown_{event_code}')
            print(f"  {event_name}: {count} events")
        
        # Create epochs
        epochs = mne.Epochs(
            raw, 
            events, 
            event_id,
            tmin=self.epoch_tmin,
            tmax=self.epoch_tmax,
            baseline=(self.baseline_tmin, self.baseline_tmax),
            preload=True,
            reject=None,  # We'll do manual artifact rejection
            verbose=False
        )
        
        print(f"Extracted {len(epochs)} epochs")
        print(f"Epoch window: {self.epoch_tmin} to {self.epoch_tmax} s")
        print(f"Baseline: {self.baseline_tmin} to {self.baseline_tmax} s")
        
        return epochs
    
    def reject_artifacts(self, epochs):
        """
        Remove trials exceeding ±100 μV threshold.
        
        Parameters:
        -----------
        epochs : mne.Epochs
            Epoched data
            
        Returns:
        --------
        epochs : mne.Epochs
            Clean epochs
        n_rejected : int
            Number of rejected epochs
        """
        n_before = len(epochs)
        
        # Get data
        data = epochs.get_data()  # Shape: (n_epochs, n_channels, n_times)
        
        # Find epochs exceeding threshold
        max_vals = np.abs(data).max(axis=(1, 2))
        good_epochs = max_vals < self.reject_threshold
        
        # Keep only good epochs
        epochs = epochs[good_epochs]
        
        n_rejected = n_before - len(epochs)
        rejection_rate = (n_rejected / n_before) * 100
        
        print(f"Artifact rejection: {n_rejected}/{n_before} epochs rejected "
              f"({rejection_rate:.1f}%)")
        print(f"Remaining epochs: {len(epochs)}")
        
        return epochs, n_rejected
    
    def process_subject(self, subject_id, return_covariances=False):
        """
        Complete preprocessing pipeline for one subject.
        
        Parameters:
        -----------
        subject_id : int
            Subject ID (1-50)
        return_covariances : bool
            If True, compute and return covariance matrices
            
        Returns:
        --------
        epochs : mne.Epochs
            Preprocessed epochs
        info : dict
            Processing information
        covariances : np.ndarray, optional
            Covariance matrices if requested (n_epochs, n_channels, n_channels)
        labels : np.ndarray, optional
            Event labels for each epoch
        """
        print(f"\n{'='*60}")
        print(f"Processing Subject {subject_id}")
        print(f"{'='*60}")
        
        # Load data
        print("\n1. Loading data (Task 1: motor imagery + Task 2: motor execution)...")
        raw_list = self.load_subject_data(subject_id)
        
        all_epochs_list = []
        
        # Process each run separately (since each has its own event file)
        for idx, (raw, event_file_path) in enumerate(raw_list):
            print(f"\n  Processing run {idx+1}/{len(raw_list)}: {event_file_path.name}")
            
            # Channel selection
            print("  2. Selecting channels...")
            raw_copy = raw.copy()  # Work on a copy
            raw_copy = self.select_channels(raw_copy)
            
            # Temporal filtering (beta band)
            print("  3. Applying beta band filter...")
            raw_copy = self.apply_temporal_filtering(raw_copy)
            
            # Spatial filtering
            print("  4. Applying spatial filtering...")
            raw_copy = self.apply_spatial_filtering(raw_copy)
            
            # Extract epochs using event file
            print("  5. Extracting epochs...")
            epochs_run = self.extract_epochs(raw_copy, event_file_path)
            
            all_epochs_list.append(epochs_run)
        
        # Concatenate epochs from all runs
        print("\n6. Combining epochs from all runs...")
        epochs = mne.concatenate_epochs(all_epochs_list, verbose=False)
        print(f"Total epochs after combining: {len(epochs)}")
        
        # Artifact rejection
        print("\n7. Rejecting artifacts...")
        epochs, n_rejected = self.reject_artifacts(epochs)
        
        # Get labels - create reverse mapping manually
        labels = epochs.events[:, -1]  # Last column contains event codes
        event_id_rev = {v: k for k, v in epochs.event_id.items()}
        label_names = np.array([event_id_rev[code] for code in labels])
        
        # Prepare info
        info = {
            'subject_id': subject_id,
            'n_epochs': len(epochs),
            'n_rejected': n_rejected,
            'event_counts': {key: len(epochs[key]) for key in epochs.event_id.keys()}
        }
        
        print(f"\nEvent distribution:")
        for event, count in info['event_counts'].items():
            print(f"  {event}: {count} epochs")
        
        if return_covariances:
            print("\n8. Computing covariance matrices...")
            covariances = self.compute_covariances(epochs)
            print(f"Computed {covariances.shape[0]} covariance matrices "
                  f"({covariances.shape[1]}x{covariances.shape[2]})")
            return epochs, info, covariances, label_names
        
        return epochs, info, label_names
    
    def compute_covariances(self, epochs):
        """
        Compute covariance matrices for each epoch with Ledoit-Wolf shrinkage.
        
        Parameters:
        -----------
        epochs : mne.Epochs
            Preprocessed epochs
            
        Returns:
        --------
        covariances : np.ndarray
            Covariance matrices (n_epochs, n_channels, n_channels)
        """
        from sklearn.covariance import LedoitWolf
        
        data = epochs.get_data()  # Shape: (n_epochs, n_channels, n_times)
        n_epochs, n_channels, n_times = data.shape
        
        covariances = np.zeros((n_epochs, n_channels, n_channels))
        
        # Apply Ledoit-Wolf shrinkage for each epoch
        lw = LedoitWolf()
        
        for i in range(n_epochs):
            # Transpose to (n_times, n_channels) for sklearn
            epoch_data = data[i].T
            
            # Compute regularized covariance
            cov = lw.fit(epoch_data).covariance_
            covariances[i] = cov
        
        return covariances
    
    def process_all_subjects(self, subject_ids=None, save_path=None):
        """
        Process multiple subjects and optionally save results.
        
        Parameters:
        -----------
        subject_ids : list, optional
            List of subject IDs to process (1-50). If None, processes all 50.
        save_path : str, optional
            Path to save preprocessed data
            
        Returns:
        --------
        all_data : dict
            Dictionary containing all processed data
        """
        if subject_ids is None:
            subject_ids = range(1, 51)  # All 50 subjects
        
        all_covariances = []
        all_labels = []
        all_info = []
        
        for subject_id in subject_ids:
            try:
                epochs, info, covariances, labels = self.process_subject(
                    subject_id, 
                    return_covariances=True
                )
                
                all_covariances.append(covariances)
                all_labels.append(labels)
                all_info.append(info)
                
            except Exception as e:
                print(f"\nError processing subject {subject_id}: {str(e)}")
                continue
        
        # Combine all data
        all_covariances = np.vstack(all_covariances)
        all_labels = np.concatenate(all_labels)
        
        print(f"\n{'='*60}")
        print(f"Processing Complete")
        print(f"{'='*60}")
        print(f"Successfully processed {len(all_info)}/{len(subject_ids)} subjects")
        print(f"Total epochs: {all_covariances.shape[0]}")
        print(f"\nLabel distribution:")
        unique, counts = np.unique(all_labels, return_counts=True)
        for label, count in zip(unique, counts):
            print(f"  {label}: {count} epochs")
        
        all_data = {
            'covariances': all_covariances,
            'labels': all_labels,
            'info': all_info
        }
        
        if save_path:
            self.save_preprocessed_data(all_data, save_path)
        
        return all_data
    
    def save_preprocessed_data(self, all_data, save_path):
        """
        Save preprocessed data to disk.
        
        Parameters:
        -----------
        all_data : dict
            Dictionary containing covariances, labels, and info
        save_path : str
            Path to save data
        """
        save_path = Path(save_path)
        save_path.mkdir(parents=True, exist_ok=True)
        
        # Save covariances and labels
        np.save(save_path / 'covariances.npy', all_data['covariances'])
        np.save(save_path / 'labels.npy', all_data['labels'])
        
        # Save info as JSON
        import json
        with open(save_path / 'preprocessing_info.json', 'w') as f:
            json.dump(all_data['info'], f, indent=2)
        
        print(f"\nData saved to {save_path}")
        print(f"  - covariances.npy: {all_data['covariances'].shape}")
        print(f"  - labels.npy: {all_data['labels'].shape}")
        print(f"  - preprocessing_info.json")


# Example usage
if __name__ == "__main__":
    # Initialize preprocessor with your data directory
    preprocessor = PhysioNetPreprocessor(
        data_dir=r"C:\Users\fibof\Downloads\physionet"
    )
    
    # Option 1: Process a single subject
    print("Processing single subject example...")
    epochs, info, covariances, labels = preprocessor.process_subject(
        subject_id=1, 
        return_covariances=True
    )
    
    print("\n" + "="*60)
    print("Single Subject Results:")
    print("="*60)
    print(f"Covariance matrices shape: {covariances.shape}")
    print(f"Labels shape: {labels.shape}")
    print(f"Unique labels: {np.unique(labels)}")
    
    # Option 2: Process all 50 subjects
    # Uncomment to run:
    """
    print("\n\nProcessing all subjects...")
    all_data = preprocessor.process_all_subjects(
        subject_ids=range(1, 51),  # All 50 subjects
        save_path='./preprocessed_data'
    )
    
    print("\n" + "="*60)
    print("All Subjects Results:")
    print("="*60)
    print(f"Total covariance matrices: {all_data['covariances'].shape}")
    print(f"Total labels: {all_data['labels'].shape}")
    """

Processing single subject example...

Processing Subject 1

1. Loading data (Task 1: motor imagery + Task 2: motor execution)...

  Processing run 1/6: S001R03.edf.event
  2. Selecting channels...
Available channels: ['Fc5.', 'Fc3.', 'Fc1.', 'Fcz.', 'Fc2.', 'Fc4.', 'Fc6.', 'C5..', 'C3..', 'C1..']... (showing first 10)
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
Selected 9 channels: ['C3..', 'Cz..', 'C4..', 'C1..', 'C2..', 'C5..', 'C6..', 'Cp5.', 'Cp3.']
  3. Applying beta band filter...
Applied beta band filter: 0.05-45.0 Hz
Applied notch filter: 60.0 Hz
  4. Applying spatial filtering...
Applied Common Average Reference (CAR)
  5. Extracting epochs...
Attempting to extract events from EDF annotations...
  Found 30 annotations in EDF file
  Successfully extracted events from annotations
  Event mapping: {'rest': 1, 'left': 2, 'right': 3}
Found 30 events
Event mapping: {'rest': 1, 'left': 2, 'right': 3}
  rest: 15 events
  left: 8 events
  right: 7 ev

In [19]:
"""
Phase 2: Train Minimum Distance to Mean (MDM) Classifier on Healthy Subjects
Uses preprocessed data from Phase 1 to learn Riemannian class means.

This file is self-contained with all necessary Riemannian geometry classes.
Only requires PhysioNetPreprocessor from Phase 1.
"""

import numpy as np
from sklearn.model_selection import LeaveOneGroupOut
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
from sklearn.covariance import LedoitWolf
import matplotlib.pyplot as plt
import seaborn as sns


# ============================================================================
# RIEMANNIAN GEOMETRY CLASSES (from Phase 2-3 artifact)
# ============================================================================

class RiemannianGeometry:
    """
    Riemannian geometry operations for SPD (Symmetric Positive Definite) matrices.
    """
    
    @staticmethod
    def sqrtm_psd(matrix):
        """Compute matrix square root for positive semi-definite matrix."""
        eigenvalues, eigenvectors = np.linalg.eigh(matrix)
        eigenvalues = np.maximum(eigenvalues, 1e-10)
        sqrt_eigenvalues = np.sqrt(eigenvalues)
        return eigenvectors @ np.diag(sqrt_eigenvalues) @ eigenvectors.T
    
    @staticmethod
    def invsqrtm_psd(matrix):
        """Compute inverse matrix square root for positive semi-definite matrix."""
        eigenvalues, eigenvectors = np.linalg.eigh(matrix)
        eigenvalues = np.maximum(eigenvalues, 1e-10)
        invsqrt_eigenvalues = 1.0 / np.sqrt(eigenvalues)
        return eigenvectors @ np.diag(invsqrt_eigenvalues) @ eigenvectors.T
    
    @staticmethod
    def distance_riemann(A, B):
        """
        Compute Riemannian distance between two SPD matrices.
        
        d(A, B) = ||log(A^(-1/2) * B * A^(-1/2))||_F
        """
        A_invsqrt = RiemannianGeometry.invsqrtm_psd(A)
        C = A_invsqrt @ B @ A_invsqrt
        
        eigenvalues = np.linalg.eigvalsh(C)
        eigenvalues = np.maximum(eigenvalues, 1e-10)
        log_eigenvalues = np.log(eigenvalues)
        
        distance = np.sqrt(np.sum(log_eigenvalues**2))
        return distance
    
    @staticmethod
    def mean_riemann(covmats, tol=1e-6, max_iter=50):
        """
        Compute Riemannian mean (Fréchet mean) of SPD matrices.
        
        The Riemannian mean minimizes sum of squared Riemannian distances.
        """
        n_matrices, n_channels, _ = covmats.shape
        
        # Initialize with Euclidean mean
        mean_cov = np.mean(covmats, axis=0)
        
        for iteration in range(max_iter):
            mean_invsqrt = RiemannianGeometry.invsqrtm_psd(mean_cov)
            mean_sqrt = RiemannianGeometry.sqrtm_psd(mean_cov)
            
            log_sum = np.zeros((n_channels, n_channels))
            
            for i in range(n_matrices):
                C_transported = mean_invsqrt @ covmats[i] @ mean_invsqrt
                
                eigenvalues, eigenvectors = np.linalg.eigh(C_transported)
                eigenvalues = np.maximum(eigenvalues, 1e-10)
                log_eigenvalues = np.log(eigenvalues)
                log_C = eigenvectors @ np.diag(log_eigenvalues) @ eigenvectors.T
                
                log_sum += log_C
            
            log_mean = log_sum / n_matrices
            
            eigenvalues, eigenvectors = np.linalg.eigh(log_mean)
            exp_eigenvalues = np.exp(eigenvalues)
            exp_log_mean = eigenvectors @ np.diag(exp_eigenvalues) @ eigenvectors.T
            
            new_mean = mean_sqrt @ exp_log_mean @ mean_sqrt
            
            change = np.linalg.norm(new_mean - mean_cov, 'fro')
            mean_cov = new_mean
            
            if change < tol:
                break
        
        return mean_cov


class MinimumDistanceToMean:
    """
    Minimum Distance to Mean (MDM) classifier in Riemannian space.
    
    Classification is based on Riemannian distance to class means.
    """
    
    def __init__(self):
        self.class_means_ = {}
        self.classes_ = None
    
    def fit(self, covariances, labels):
        """
        Fit MDM classifier by computing Riemannian mean for each class.
        """
        self.classes_ = np.unique(labels)
        
        print(f"Computing Riemannian means for {len(self.classes_)} classes...")
        
        for cls in self.classes_:
            class_mask = labels == cls
            class_covs = covariances[class_mask]
            
            class_mean = RiemannianGeometry.mean_riemann(class_covs)
            self.class_means_[cls] = class_mean
            
            print(f"  Class '{cls}': {np.sum(class_mask)} samples, "
                  f"mean shape {class_mean.shape}")
        
        return self
    
    def predict(self, covariances):
        """
        Predict class labels based on nearest Riemannian mean.
        """
        n_samples = covariances.shape[0]
        predictions = np.zeros(n_samples, dtype=self.classes_.dtype)
        
        for i in range(n_samples):
            distances = {}
            
            for cls in self.classes_:
                dist = RiemannianGeometry.distance_riemann(
                    covariances[i], 
                    self.class_means_[cls]
                )
                distances[cls] = dist
            
            predictions[i] = min(distances, key=distances.get)
        
        return predictions
    
    def predict_proba(self, covariances):
        """
        Predict class probabilities based on Riemannian distances.
        
        Uses softmax of negative distances.
        """
        n_samples = covariances.shape[0]
        n_classes = len(self.classes_)
        distances = np.zeros((n_samples, n_classes))
        
        for i in range(n_samples):
            for j, cls in enumerate(self.classes_):
                distances[i, j] = RiemannianGeometry.distance_riemann(
                    covariances[i],
                    self.class_means_[cls]
                )
        
        # Convert distances to probabilities using softmax
        exp_neg_dist = np.exp(-distances)
        probabilities = exp_neg_dist / np.sum(exp_neg_dist, axis=1, keepdims=True)
        
        return probabilities
    
    def get_class_means(self):
        """Get the learned Riemannian means."""
        return self.class_means_


# ============================================================================
# TRAINING CLASS
# ============================================================================

class HealthySubjectTrainer:
    """
    Train MDM classifier on healthy subject data.
    
    This learns the Riemannian means for rest and movement classes
    which will later be used for transfer learning to PD patients.
    """
    
    def __init__(self):
        self.mdm_classifier = None
        self.class_means = None
        self.training_subjects = []
        self.training_accuracy = None
        
    def prepare_data(self, preprocessor, subject_ids, exclude_rest=False):
        """
        Process multiple subjects and prepare training data.
        
        Parameters:
        -----------
        preprocessor : PhysioNetPreprocessor
            Preprocessor instance from Phase 1
        subject_ids : list
            List of subject IDs to include in training
        exclude_rest : bool
            If True, only use left and right classes (exclude rest/T0)
            Recommended: True, since T0 rest periods don't have MRCPs
            
        Returns:
        --------
        all_covs : np.ndarray
            All covariance matrices (n_total_epochs, 9, 9)
        all_labels : np.ndarray
            All labels (n_total_epochs,)
        subject_groups : np.ndarray
            Subject ID for each epoch (for cross-validation)
        """
        all_covs = []
        all_labels = []
        subject_groups = []
        
        print("="*60)
        print("PHASE 2: COLLECTING TRAINING DATA")
        print("="*60)
        if exclude_rest:
            print("Mode: 2-class classification (left vs right) - EXCLUDING REST")
            print("Rationale: MRCPs only occur before movement, not during rest")
        else:
            print("Mode: 3-class classification (rest, left, right)")
        
        for subject_id in subject_ids:
            try:
                print(f"\nProcessing subject {subject_id}...")
                epochs, info, covs, labels = preprocessor.process_subject(
                    subject_id, 
                    return_covariances=True
                )
                
                # Exclude rest if requested
                if exclude_rest:
                    mask = labels != 'rest'
                    covs = covs[mask]
                    labels = labels[mask]
                    print(f"  Excluded {np.sum(~mask)} rest epochs")
                
                all_covs.append(covs)
                all_labels.append(labels)
                subject_groups.append(np.full(len(labels), subject_id))
                
                self.training_subjects.append(subject_id)
                
                print(f"  ✓ Subject {subject_id}: {len(labels)} epochs collected")
                
            except Exception as e:
                print(f"  ✗ Error with subject {subject_id}: {str(e)}")
                continue
        
        # Combine all data
        all_covs = np.vstack(all_covs)
        all_labels = np.concatenate(all_labels)
        subject_groups = np.concatenate(subject_groups)
        
        print("\n" + "="*60)
        print("DATA COLLECTION COMPLETE")
        print("="*60)
        print(f"Total subjects: {len(self.training_subjects)}")
        print(f"Total epochs: {len(all_labels)}")
        print(f"Covariance shape: {all_covs.shape}")
        print("\nLabel distribution:")
        unique, counts = np.unique(all_labels, return_counts=True)
        for label, count in zip(unique, counts):
            print(f"  {label}: {count} epochs ({count/len(all_labels)*100:.1f}%)")
        
        return all_covs, all_labels, subject_groups
    
    def train_mdm(self, covs, labels):
        """
        Train MDM classifier on covariance matrices.
        
        This computes the Riemannian mean for each class:
        - M_rest_healthy: Mean covariance for rest class
        - M_left_healthy: Mean covariance for left hand class
        - M_right_healthy: Mean covariance for right hand class
        
        Parameters:
        -----------
        covs : np.ndarray
            Covariance matrices (n_epochs, 9, 9)
        labels : np.ndarray
            Class labels (n_epochs,)
            
        Returns:
        --------
        mdm : MinimumDistanceToMean
            Trained classifier
        """
        print("\n" + "="*60)
        print("TRAINING MDM CLASSIFIER")
        print("="*60)
        
        # DEBUG: Check if covariances are valid
        print("\nDEBUG: Checking covariance matrices...")
        print(f"  Covariance shape: {covs.shape}")
        print(f"  Min value: {np.min(covs):.6e}")
        print(f"  Max value: {np.max(covs):.6e}")
        print(f"  Mean value: {np.mean(covs):.6e}")
        print(f"  Any NaN: {np.any(np.isnan(covs))}")
        print(f"  Any Inf: {np.any(np.isinf(covs))}")
        
        # Check a sample covariance matrix
        sample_cov = covs[0]
        print(f"\n  Sample covariance matrix [0]:")
        print(f"    Trace: {np.trace(sample_cov):.6e}")
        print(f"    Determinant: {np.linalg.det(sample_cov):.6e}")
        print(f"    Condition number: {np.linalg.cond(sample_cov):.2f}")
        print(f"    Is symmetric: {np.allclose(sample_cov, sample_cov.T)}")
        
        # Initialize and train MDM classifier
        mdm = MinimumDistanceToMean()
        mdm.fit(covs, labels)
        
        # Get the learned class means
        self.class_means = mdm.get_class_means()
        self.mdm_classifier = mdm
        
        print("\n" + "="*60)
        print("LEARNED CLASS MEANS")
        print("="*60)
        for class_name, mean_matrix in self.class_means.items():
            print(f"\nClass: {class_name}")
            print(f"  Shape: {mean_matrix.shape}")
            print(f"  Condition number: {np.linalg.cond(mean_matrix):.2f}")
            print(f"  Trace (sum of variances): {np.trace(mean_matrix):.6e}")
            print(f"  Determinant: {np.linalg.det(mean_matrix):.6e}")
            
            # Check distances between class means
            print(f"\n  Distances from {class_name} mean to other means:")
            for other_name, other_mean in self.class_means.items():
                if other_name != class_name:
                    dist = RiemannianGeometry.distance_riemann(mean_matrix, other_mean)
                    print(f"    to {other_name}: {dist:.6f}")
        
        return mdm
    
    def evaluate_within_subject(self, covs, labels, subject_groups):
        """
        Evaluate classifier using Leave-One-Subject-Out cross-validation.
        
        This tests how well the model generalizes to new subjects
        (important for transfer learning).
        
        Parameters:
        -----------
        covs : np.ndarray
            Covariance matrices
        labels : np.ndarray
            Class labels
        subject_groups : np.ndarray
            Subject ID for each epoch
            
        Returns:
        --------
        cv_scores : dict
            Cross-validation results
        """
        print("\n" + "="*60)
        print("CROSS-VALIDATION (Leave-One-Subject-Out)")
        print("="*60)
        
        logo = LeaveOneGroupOut()
        cv_accuracies = []
        
        for train_idx, test_idx in logo.split(covs, labels, subject_groups):
            # Split data
            train_covs, test_covs = covs[train_idx], covs[test_idx]
            train_labels, test_labels = labels[train_idx], labels[test_idx]
            
            # Train on training subjects
            mdm_cv = MinimumDistanceToMean()
            mdm_cv.fit(train_covs, train_labels)
            
            # Test on held-out subject
            predictions = mdm_cv.predict(test_covs)
            accuracy = accuracy_score(test_labels, predictions)
            cv_accuracies.append(accuracy)
        
        cv_scores = {
            'mean_accuracy': np.mean(cv_accuracies),
            'std_accuracy': np.std(cv_accuracies),
            'all_scores': cv_accuracies
        }
        
        print(f"Cross-validation results ({len(cv_accuracies)} folds):")
        print(f"  Mean accuracy: {cv_scores['mean_accuracy']:.3f} ± {cv_scores['std_accuracy']:.3f}")
        print(f"  Min accuracy: {np.min(cv_accuracies):.3f}")
        print(f"  Max accuracy: {np.max(cv_accuracies):.3f}")
        
        return cv_scores
    
    def evaluate_final_model(self, covs, labels):
        """
        Evaluate the final trained model on all training data.
        
        This shows training accuracy and confusion matrix.
        
        Parameters:
        -----------
        covs : np.ndarray
            Covariance matrices
        labels : np.ndarray
            Class labels
        """
        print("\n" + "="*60)
        print("FINAL MODEL EVALUATION")
        print("="*60)
        
        # Predict on training data
        predictions = self.mdm_classifier.predict(covs)
        probabilities = self.mdm_classifier.predict_proba(covs)
        
        # Calculate accuracy
        accuracy = accuracy_score(labels, predictions)
        self.training_accuracy = accuracy
        
        print(f"\nTraining accuracy: {accuracy:.3f} ({accuracy*100:.1f}%)")
        
        # Confusion matrix
        cm = confusion_matrix(labels, predictions)
        
        print("\nConfusion Matrix:")
        unique_labels = np.unique(labels)
        print(f"           Predicted")
        print(f"           ", end="")
        for label in unique_labels:
            print(f"{label:>8}", end="")
        print()
        for i, true_label in enumerate(unique_labels):
            print(f"Actual {true_label:>6} ", end="")
            for j in range(len(unique_labels)):
                print(f"{cm[i, j]:>8}", end="")
            print()
        
        # Classification report
        print("\nClassification Report:")
        print(classification_report(labels, predictions, digits=3))
        
        return {
            'accuracy': accuracy,
            'confusion_matrix': cm,
            'predictions': predictions,
            'probabilities': probabilities
        }
    
    def visualize_class_means(self):
        """
        Visualize the learned Riemannian class means.
        
        Shows covariance matrices as heatmaps.
        """
        n_classes = len(self.class_means)
        fig, axes = plt.subplots(1, n_classes, figsize=(5*n_classes, 4))
        
        if n_classes == 1:
            axes = [axes]
        
        for idx, (class_name, mean_matrix) in enumerate(self.class_means.items()):
            ax = axes[idx]
            
            # Plot heatmap
            sns.heatmap(mean_matrix, 
                       annot=False, 
                       cmap='RdBu_r', 
                       center=0,
                       square=True,
                       ax=ax,
                       cbar_kws={'label': 'Covariance'})
            
            ax.set_title(f'Class: {class_name}\n(Riemannian Mean)', 
                        fontsize=12, fontweight='bold')
            ax.set_xlabel('Channel')
            ax.set_ylabel('Channel')
        
        plt.tight_layout()
        plt.savefig('riemannian_class_means.png', dpi=150, bbox_inches='tight')
        print("\nSaved visualization: riemannian_class_means.png")
        plt.show()
    
    def save_model(self, filepath='mdm_healthy_model.npz'):
        """
        Save trained model and class means.
        
        Parameters:
        -----------
        filepath : str
            Path to save model
        """
        # Convert class means dict to saveable format
        class_means_dict = {}
        for class_name, mean_matrix in self.class_means.items():
            class_means_dict[f'mean_{class_name}'] = mean_matrix
        
        np.savez(
            filepath,
            training_subjects=self.training_subjects,
            training_accuracy=self.training_accuracy,
            **class_means_dict
        )
        print(f"\nModel saved to: {filepath}")
        print(f"Saved class means: {list(self.class_means.keys())}")


# ============================================================================
# EXAMPLE USAGE
# ============================================================================

if __name__ == "__main__":
    # Assumes you have PhysioNetPreprocessor available
    # from your Phase 1 code
    
    print("="*60)
    print("PHASE 2: TRAIN MDM CLASSIFIER ON HEALTHY SUBJECTS")
    print("="*60)
    
    # Initialize preprocessor (assumes it's already defined)
    # preprocessor = PhysioNetPreprocessor()
    
    # Initialize trainer
    trainer = HealthySubjectTrainer()
    
    # Example: Prepare training data (use first 10 subjects for demo)
    # Uncomment and modify the following lines:
    
    """
    subject_ids = range(1, 11)  # Use subjects 1-10 for demo
    # Or use all 50: subject_ids = range(1, 51)
    
    covs, labels, subject_groups = trainer.prepare_data(
        preprocessor, 
        subject_ids
    )
    
    # Train MDM classifier
    mdm = trainer.train_mdm(covs, labels)
    
    # Cross-validation
    cv_scores = trainer.evaluate_within_subject(covs, labels, subject_groups)
    
    # Evaluate final model
    results = trainer.evaluate_final_model(covs, labels)
    
    # Visualize class means
    trainer.visualize_class_means()
    
    # Save model (will save M_rest_healthy, M_left_healthy, M_right_healthy)
    trainer.save_model('mdm_healthy_model.npz')
    
    print("\n" + "="*60)
    print("PHASE 2 COMPLETE!")
    print("="*60)
    print(f"Trained on {len(trainer.training_subjects)} subjects")
    print(f"Cross-validation accuracy: {cv_scores['mean_accuracy']:.3f}")
    print(f"Training accuracy: {results['accuracy']:.3f}")
    print("\nLearned class means (135 parameters total):")
    print("  M_rest_healthy: 9x9 matrix (45 unique values)")
    print("  M_left_healthy: 9x9 matrix (45 unique values)")
    print("  M_right_healthy: 9x9 matrix (45 unique values)")
    print("\nReady for Phase 4 (Transfer Learning)")
    """
    
    print("\nReady to use! Uncomment the example code above to train.")

PHASE 2: TRAIN MDM CLASSIFIER ON HEALTHY SUBJECTS

Ready to use! Uncomment the example code above to train.


In [20]:
covs, labels, subject_groups = trainer.prepare_data(
    preprocessor, 
    range(1, 11),
    exclude_rest=True  # Only train on left vs right
)

PHASE 2: COLLECTING TRAINING DATA
Mode: 2-class classification (left vs right) - EXCLUDING REST
Rationale: MRCPs only occur before movement, not during rest

Processing subject 1...

Processing Subject 1

1. Loading data (Task 1: motor imagery + Task 2: motor execution)...

  Processing run 1/6: S001R03.edf.event
  2. Selecting channels...
Available channels: ['Fc5.', 'Fc3.', 'Fc1.', 'Fcz.', 'Fc2.', 'Fc4.', 'Fc6.', 'C5..', 'C3..', 'C1..']... (showing first 10)
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
Selected 9 channels: ['C3..', 'Cz..', 'C4..', 'C1..', 'C2..', 'C5..', 'C6..', 'Cp5.', 'Cp3.']
  3. Applying beta band filter...
Applied beta band filter: 13.0-30.0 Hz
Applied notch filter: 60.0 Hz
  4. Applying spatial filtering...
Applied Common Average Reference (CAR)
  5. Extracting epochs...
Attempting to extract events from EDF annotations...
  Found 30 annotations in EDF file
  Successfully extracted events from annotations
  Event mapping: {'re

In [28]:
"""
Phase 2.5: Comprehensive Analysis and Visualization Suite

This script provides detailed analysis and visualization of the preprocessing
and training pipeline to support the research question:

"Can transfer learning trained on healthy subjects enable accurate movement 
intention detection for Parkinson's patients with minimal patient-specific 
calibration, despite systematic neurophysiological differences between populations?"

Outputs:
- Excel/CSV files with detailed metrics
- Visualizations of each processing step
- Statistical analysis for research validation
"""

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
from sklearn.model_selection import LeaveOneGroupOut
import warnings
warnings.filterwarnings('ignore')

# Set publication-quality plotting style
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")


class ComprehensiveAnalyzer:
    """
    Comprehensive analysis suite for BCI transfer learning research.
    
    Analyzes and visualizes:
    1. Raw EEG signal quality
    2. Covariance matrix properties
    3. Class separability in Riemannian space
    4. Classifier performance metrics
    5. Cross-subject generalization (key for transfer learning)
    """
    
    def __init__(self, output_dir='analysis_results'):
        """
        Initialize analyzer.
        
        Parameters:
        -----------
        output_dir : str
            Directory to save all outputs
        """
        from pathlib import Path
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        
        self.results = {}
        self.figures = []
        
    # ========================================================================
    # 1. SIGNAL QUALITY ANALYSIS
    # ========================================================================
    
    def analyze_signal_quality(self, epochs_list, subject_ids):
        """
        Analyze raw EEG signal quality after preprocessing.
        
        Research relevance: Validates that preprocessing preserves 
        discriminative information needed for transfer learning.
        
        Parameters:
        -----------
        epochs_list : list of mne.Epochs
            Preprocessed epochs for each subject
        subject_ids : list
            Subject IDs
            
        Returns:
        --------
        df : pd.DataFrame
            Signal quality metrics
        """
        print("\n" + "="*60)
        print("1. SIGNAL QUALITY ANALYSIS")
        print("="*60)
        
        metrics = []
        
        for subject_id, epochs in zip(subject_ids, epochs_list):
            data = epochs.get_data()  # (n_epochs, n_channels, n_times)
            
            # Create reverse mapping for event IDs
            event_id_rev = {v: k for k, v in epochs.event_id.items()}
            
            for label_code in np.unique(epochs.events[:, -1]):
                label_name = event_id_rev.get(label_code, str(label_code))
                label_mask = epochs.events[:, -1] == label_code
                label_data = data[label_mask]
                
                if len(label_data) == 0:
                    continue
                
                # Compute metrics
                mean_signal = np.mean(label_data**2)
                var_signal = np.var(label_data)
                
                metrics.append({
                    'subject_id': subject_id,
                    'class': label_name,
                    'n_epochs': len(label_data),
                    'mean_amplitude_uV': np.mean(np.abs(label_data)) * 1e6,
                    'std_amplitude_uV': np.std(label_data) * 1e6,
                    'snr_db': 10 * np.log10(mean_signal / var_signal) if var_signal > 0 else 0,
                    'peak_amplitude_uV': np.max(np.abs(label_data)) * 1e6,
                })
        
        df = pd.DataFrame(metrics)
        
        # Save to Excel
        excel_path = self.output_dir / 'signal_quality_metrics.xlsx'
        df.to_excel(excel_path, index=False)
        print(f"✓ Saved signal quality metrics to: {excel_path}")
        
        # Summary statistics
        if len(df) > 0:
            print("\nSignal Quality Summary:")
            print(df.groupby('class')[['mean_amplitude_uV', 'snr_db']].describe())
        else:
            print("\nWarning: No signal quality data collected")
        
        # Visualize
        if len(df) > 0:
            self._plot_signal_quality(df)
        
        self.results['signal_quality'] = df
        return df
    
    def _plot_signal_quality(self, df):
        """Plot signal quality metrics."""
        fig, axes = plt.subplots(1, 2, figsize=(12, 4))
        
        # Amplitude by class
        sns.boxplot(data=df, x='class', y='mean_amplitude_uV', ax=axes[0])
        axes[0].set_title('Signal Amplitude by Class (Beta Band)')
        axes[0].set_ylabel('Mean Amplitude (μV)')
        axes[0].set_xlabel('Class')
        
        # SNR by class
        sns.boxplot(data=df, x='class', y='snr_db', ax=axes[1])
        axes[1].set_title('Signal-to-Noise Ratio by Class')
        axes[1].set_ylabel('SNR (dB)')
        axes[1].set_xlabel('Class')
        
        plt.tight_layout()
        path = self.output_dir / 'signal_quality.png'
        plt.savefig(path, dpi=300, bbox_inches='tight')
        print(f"✓ Saved figure: {path}")
        plt.close()
        
        self.figures.append(path)
    
    # ========================================================================
    # 2. COVARIANCE MATRIX ANALYSIS
    # ========================================================================
    
    def analyze_covariances(self, covs, labels, subject_groups):
        """
        Analyze covariance matrix properties.
        
        Research relevance: Covariance matrices are the SPD manifold
        representations used for transfer learning. Their properties
        determine separability and transferability.
        
        Parameters:
        -----------
        covs : np.ndarray
            Covariance matrices (n_epochs, n_channels, n_channels)
        labels : np.ndarray
            Class labels
        subject_groups : np.ndarray
            Subject IDs
            
        Returns:
        --------
        df : pd.DataFrame
            Covariance properties
        """
        print("\n" + "="*60)
        print("2. COVARIANCE MATRIX ANALYSIS")
        print("="*60)
        
        metrics = []
        
        for i in range(len(covs)):
            cov = covs[i]
            label = labels[i]
            subject = subject_groups[i]
            
            # Compute properties
            eigenvalues = np.linalg.eigvalsh(cov)
            
            metrics.append({
                'epoch': i,
                'subject_id': subject,
                'class': label,
                'trace': np.trace(cov),
                'determinant': np.linalg.det(cov),
                'condition_number': np.linalg.cond(cov),
                'frobenius_norm': np.linalg.norm(cov, 'fro'),
                'min_eigenvalue': np.min(eigenvalues),
                'max_eigenvalue': np.max(eigenvalues),
                'eigenvalue_ratio': np.max(eigenvalues) / np.min(eigenvalues),
                'is_positive_definite': np.all(eigenvalues > 0),
            })
        
        df = pd.DataFrame(metrics)
        
        # Save to Excel
        excel_path = self.output_dir / 'covariance_properties.xlsx'
        df.to_excel(excel_path, index=False)
        print(f"✓ Saved covariance properties to: {excel_path}")
        
        # Summary by class
        print("\nCovariance Properties by Class:")
        summary = df.groupby('class')[['trace', 'determinant', 'condition_number']].describe()
        print(summary)
        
        # Check for issues
        print(f"\nPotential Issues:")
        print(f"  Non-positive definite matrices: {np.sum(~df['is_positive_definite'])}")
        print(f"  High condition numbers (>1000): {np.sum(df['condition_number'] > 1000)}")
        
        # Visualize
        self._plot_covariance_properties(df)
        
        self.results['covariance_properties'] = df
        return df
    
    def _plot_covariance_properties(self, df):
        """Plot covariance properties."""
        fig, axes = plt.subplots(2, 2, figsize=(12, 10))
        
        # Trace (total variance)
        sns.violinplot(data=df, x='class', y='trace', ax=axes[0, 0])
        axes[0, 0].set_title('Covariance Trace (Total Variance)')
        axes[0, 0].set_ylabel('Trace')
        axes[0, 0].set_xlabel('Class')
        
        # Condition number
        sns.violinplot(data=df, x='class', y='condition_number', ax=axes[0, 1])
        axes[0, 1].set_title('Condition Number (Numerical Stability)')
        axes[0, 1].set_ylabel('Condition Number')
        axes[0, 1].set_xlabel('Class')
        axes[0, 1].set_yscale('log')
        
        # Eigenvalue spectrum
        for label in df['class'].unique():
            class_df = df[df['class'] == label]
            axes[1, 0].scatter(class_df['min_eigenvalue'], class_df['max_eigenvalue'], 
                              label=label, alpha=0.5)
        axes[1, 0].set_xlabel('Min Eigenvalue')
        axes[1, 0].set_ylabel('Max Eigenvalue')
        axes[1, 0].set_title('Eigenvalue Spectrum')
        axes[1, 0].legend()
        axes[1, 0].set_xscale('log')
        axes[1, 0].set_yscale('log')
        
        # Frobenius norm
        sns.boxplot(data=df, x='class', y='frobenius_norm', ax=axes[1, 1])
        axes[1, 1].set_title('Frobenius Norm (Matrix Magnitude)')
        axes[1, 1].set_ylabel('Frobenius Norm')
        axes[1, 1].set_xlabel('Class')
        
        plt.tight_layout()
        path = self.output_dir / 'covariance_properties.png'
        plt.savefig(path, dpi=300, bbox_inches='tight')
        print(f"✓ Saved figure: {path}")
        plt.close()
        
        self.figures.append(path)
    
    # ========================================================================
    # 3. RIEMANNIAN DISTANCE ANALYSIS
    # ========================================================================
    
    def analyze_class_separability(self, covs, labels, class_means, RiemannianGeometry):
        """
        Analyze class separability in Riemannian space.
        
        Research relevance: Distance between class means determines
        how well the classifier can distinguish classes. This is critical
        for transfer learning effectiveness.
        
        Parameters:
        -----------
        covs : np.ndarray
            Covariance matrices
        labels : np.ndarray
            Class labels
        class_means : dict
            Riemannian means for each class
        RiemannianGeometry : class
            Riemannian geometry class with distance_riemann method
            
        Returns:
        --------
        df_distances : pd.DataFrame
            Pairwise distances
        df_separability : pd.DataFrame
            Within-class vs between-class distances
        """
        print("\n" + "="*60)
        print("3. CLASS SEPARABILITY ANALYSIS (Riemannian Space)")
        print("="*60)
        
        # Compute pairwise distances between class means
        class_names = list(class_means.keys())
        n_classes = len(class_names)
        
        distance_matrix = np.zeros((n_classes, n_classes))
        
        for i, class1 in enumerate(class_names):
            for j, class2 in enumerate(class_names):
                dist = RiemannianGeometry.distance_riemann(
                    class_means[class1],
                    class_means[class2]
                )
                distance_matrix[i, j] = dist
        
        # Create DataFrame for class mean distances
        df_class_distances = pd.DataFrame(
            distance_matrix,
            index=class_names,
            columns=class_names
        )
        
        print("\nPairwise Distances Between Class Means:")
        print(df_class_distances)
        
        # Compute within-class and between-class distances
        separability_metrics = []
        
        for class_name in class_names:
            class_mask = labels == class_name
            class_covs = covs[class_mask]
            class_mean = class_means[class_name]
            
            # Within-class distances
            within_distances = []
            for cov in class_covs:
                dist = RiemannianGeometry.distance_riemann(cov, class_mean)
                within_distances.append(dist)
            
            # Between-class distances (to other class means)
            between_distances = []
            for other_class, other_mean in class_means.items():
                if other_class != class_name:
                    dist = RiemannianGeometry.distance_riemann(class_mean, other_mean)
                    between_distances.append(dist)
            
            separability_metrics.append({
                'class': class_name,
                'n_samples': len(class_covs),
                'within_class_mean': np.mean(within_distances),
                'within_class_std': np.std(within_distances),
                'between_class_min': np.min(between_distances) if between_distances else 0,
                'between_class_mean': np.mean(between_distances) if between_distances else 0,
                'separability_ratio': (np.mean(between_distances) / np.mean(within_distances)) 
                                      if within_distances and between_distances else 0,
            })
        
        df_separability = pd.DataFrame(separability_metrics)
        
        # Save to Excel with multiple sheets
        excel_path = self.output_dir / 'class_separability.xlsx'
        with pd.ExcelWriter(excel_path) as writer:
            df_class_distances.to_excel(writer, sheet_name='Class_Mean_Distances')
            df_separability.to_excel(writer, sheet_name='Separability_Metrics', index=False)
        
        print(f"✓ Saved class separability to: {excel_path}")
        
        print("\nSeparability Metrics:")
        print(df_separability)
        
        print("\nInterpretation:")
        print("  Separability ratio = between_class_mean / within_class_mean")
        print("  Higher ratio = better class separation")
        print("  Ratio > 2.0 is generally good for classification")
        
        # Visualize
        self._plot_class_separability(df_class_distances, df_separability)
        
        self.results['class_distances'] = df_class_distances
        self.results['separability'] = df_separability
        
        return df_class_distances, df_separability
    
    def _plot_class_separability(self, df_distances, df_separability):
        """Plot class separability."""
        fig, axes = plt.subplots(1, 2, figsize=(14, 5))
        
        # Heatmap of class mean distances
        sns.heatmap(df_distances, annot=True, fmt='.3f', cmap='YlOrRd', 
                   ax=axes[0], cbar_kws={'label': 'Riemannian Distance'})
        axes[0].set_title('Pairwise Distances Between Class Means')
        
        # Separability ratio
        axes[1].bar(df_separability['class'], df_separability['separability_ratio'])
        axes[1].axhline(y=2.0, color='r', linestyle='--', label='Good threshold (2.0)')
        axes[1].set_xlabel('Class')
        axes[1].set_ylabel('Separability Ratio')
        axes[1].set_title('Class Separability\n(Between/Within Distance Ratio)')
        axes[1].legend()
        axes[1].grid(axis='y', alpha=0.3)
        
        plt.tight_layout()
        path = self.output_dir / 'class_separability.png'
        plt.savefig(path, dpi=300, bbox_inches='tight')
        print(f"✓ Saved figure: {path}")
        plt.close()
        
        self.figures.append(path)
    
    # ========================================================================
    # 4. CLASSIFIER PERFORMANCE ANALYSIS
    # ========================================================================
    
    def analyze_classifier_performance(self, mdm_classifier, covs, labels, subject_groups):
        """
        Comprehensive classifier performance analysis.
        
        Research relevance: Within-subject and cross-subject performance
        indicates how well the model generalizes - critical for transfer
        learning to PD patients.
        
        Parameters:
        -----------
        mdm_classifier : MinimumDistanceToMean
            Trained classifier
        covs : np.ndarray
            Covariance matrices
        labels : np.ndarray
            True labels
        subject_groups : np.ndarray
            Subject IDs
            
        Returns:
        --------
        df_overall : pd.DataFrame
            Overall performance metrics
        df_per_subject : pd.DataFrame
            Per-subject performance
        """
        print("\n" + "="*60)
        print("4. CLASSIFIER PERFORMANCE ANALYSIS")
        print("="*60)
        
        # Overall performance
        predictions = mdm_classifier.predict(covs)
        probabilities = mdm_classifier.predict_proba(covs)
        
        accuracy = accuracy_score(labels, predictions)
        cm = confusion_matrix(labels, predictions)
        
        print(f"\nOverall Accuracy: {accuracy:.3f} ({accuracy*100:.1f}%)")
        
        # Per-class metrics
        report = classification_report(labels, predictions, output_dict=True)
        df_overall = pd.DataFrame(report).transpose()
        
        # Leave-One-Subject-Out Cross-Validation
        print("\nPerforming Leave-One-Subject-Out Cross-Validation...")
        logo = LeaveOneGroupOut()
        
        cv_results = []
        unique_subjects = np.unique(subject_groups)
        
        for train_idx, test_idx in logo.split(covs, labels, subject_groups):
            train_covs, test_covs = covs[train_idx], covs[test_idx]
            train_labels, test_labels = labels[train_idx], labels[test_idx]
            
            # Get test subject ID
            test_subject = subject_groups[test_idx][0]
            
            # Train on all other subjects
            mdm_cv = MinimumDistanceToMean()
            mdm_cv.fit(train_covs, train_labels)
            
            # Test on held-out subject
            preds = mdm_cv.predict(test_covs)
            acc = accuracy_score(test_labels, preds)
            
            # Per-class accuracy
            class_accs = {}
            for cls in np.unique(labels):
                cls_mask = test_labels == cls
                if np.sum(cls_mask) > 0:
                    class_accs[f'{cls}_accuracy'] = accuracy_score(
                        test_labels[cls_mask], 
                        preds[cls_mask]
                    )
            
            cv_results.append({
                'test_subject': test_subject,
                'n_train_samples': len(train_labels),
                'n_test_samples': len(test_labels),
                'accuracy': acc,
                **class_accs
            })
        
        df_per_subject = pd.DataFrame(cv_results)
        
        print(f"\nCross-Validation Results:")
        print(f"  Mean Accuracy: {df_per_subject['accuracy'].mean():.3f} ± {df_per_subject['accuracy'].std():.3f}")
        print(f"  Min Accuracy: {df_per_subject['accuracy'].min():.3f}")
        print(f"  Max Accuracy: {df_per_subject['accuracy'].max():.3f}")
        
        # Save to Excel
        excel_path = self.output_dir / 'classifier_performance.xlsx'
        with pd.ExcelWriter(excel_path) as writer:
            df_overall.to_excel(writer, sheet_name='Overall_Metrics')
            df_per_subject.to_excel(writer, sheet_name='Per_Subject_CV', index=False)
            
            # Add confusion matrix
            cm_df = pd.DataFrame(cm, 
                                index=[f'True_{c}' for c in np.unique(labels)],
                                columns=[f'Pred_{c}' for c in np.unique(labels)])
            cm_df.to_excel(writer, sheet_name='Confusion_Matrix')
        
        print(f"✓ Saved performance metrics to: {excel_path}")
        
        # Visualize
        self._plot_classifier_performance(cm, df_per_subject, labels, predictions)
        
        self.results['overall_performance'] = df_overall
        self.results['per_subject_performance'] = df_per_subject
        self.results['confusion_matrix'] = cm
        
        return df_overall, df_per_subject
    
    def _plot_classifier_performance(self, cm, df_per_subject, labels, predictions):
        """Plot classifier performance."""
        fig = plt.figure(figsize=(15, 5))
        gs = fig.add_gridspec(1, 3)
        
        # Confusion matrix
        ax1 = fig.add_subplot(gs[0, 0])
        unique_labels = np.unique(labels)
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
                   xticklabels=unique_labels, yticklabels=unique_labels,
                   ax=ax1)
        ax1.set_xlabel('Predicted')
        ax1.set_ylabel('True')
        ax1.set_title('Confusion Matrix\n(All Subjects)')
        
        # Cross-validation accuracy distribution
        ax2 = fig.add_subplot(gs[0, 1])
        ax2.hist(df_per_subject['accuracy'], bins=10, edgecolor='black', alpha=0.7)
        ax2.axvline(df_per_subject['accuracy'].mean(), color='r', linestyle='--', 
                   label=f"Mean: {df_per_subject['accuracy'].mean():.3f}")
        ax2.set_xlabel('Accuracy')
        ax2.set_ylabel('Number of Subjects')
        ax2.set_title('Cross-Subject Generalization\n(Leave-One-Subject-Out)')
        ax2.legend()
        ax2.grid(axis='y', alpha=0.3)
        
        # Per-subject accuracy
        ax3 = fig.add_subplot(gs[0, 2])
        subjects = df_per_subject['test_subject'].values
        accuracies = df_per_subject['accuracy'].values
        ax3.bar(range(len(subjects)), accuracies)
        ax3.axhline(y=df_per_subject['accuracy'].mean(), color='r', linestyle='--',
                   label='Mean')
        ax3.set_xlabel('Subject ID')
        ax3.set_ylabel('Accuracy')
        ax3.set_title('Per-Subject Accuracy (LOSO-CV)')
        ax3.set_xticks(range(len(subjects)))
        ax3.set_xticklabels(subjects, rotation=45)
        ax3.legend()
        ax3.grid(axis='y', alpha=0.3)
        
        plt.tight_layout()
        path = self.output_dir / 'classifier_performance.png'
        plt.savefig(path, dpi=300, bbox_inches='tight')
        print(f"✓ Saved figure: {path}")
        plt.close()
        
        self.figures.append(path)
    
    # ========================================================================
    # 5. TRANSFER LEARNING READINESS REPORT
    # ========================================================================
    
    def generate_transfer_learning_report(self):
        """
        Generate comprehensive report on transfer learning readiness.
        
        Research relevance: Summarizes all metrics relevant to answering
        the research question about healthy-to-PD transfer learning.
        """
        print("\n" + "="*60)
        print("5. TRANSFER LEARNING READINESS REPORT")
        print("="*60)
        
        report = []
        
        # 1. Data quality
        if 'signal_quality' in self.results:
            sq = self.results['signal_quality']
            report.append({
                'Metric': 'Signal Quality - Mean Amplitude (μV)',
                'Value': f"{sq['mean_amplitude_uV'].mean():.2f} ± {sq['mean_amplitude_uV'].std():.2f}",
                'Interpretation': 'Good' if sq['mean_amplitude_uV'].mean() > 1.0 else 'Low',
                'Relevance': 'Higher amplitude = stronger signal for transfer'
            })
        
        # 2. Class separability
        if 'separability' in self.results:
            sep = self.results['separability']
            mean_sep = sep['separability_ratio'].mean()
            report.append({
                'Metric': 'Class Separability Ratio',
                'Value': f"{mean_sep:.2f}",
                'Interpretation': 'Good' if mean_sep > 2.0 else 'Poor',
                'Relevance': 'Higher ratio = easier to transfer learned patterns'
            })
        
        # 3. Cross-subject generalization
        if 'per_subject_performance' in self.results:
            cv = self.results['per_subject_performance']
            mean_acc = cv['accuracy'].mean()
            std_acc = cv['accuracy'].std()
            report.append({
                'Metric': 'Cross-Subject Accuracy (LOSO-CV)',
                'Value': f"{mean_acc:.3f} ± {std_acc:.3f}",
                'Interpretation': 'Good' if mean_acc > 0.65 else 'Poor',
                'Relevance': 'CRITICAL: Predicts transfer learning performance'
            })
            
            report.append({
                'Metric': 'Cross-Subject Std Dev',
                'Value': f"{std_acc:.3f}",
                'Interpretation': 'Stable' if std_acc < 0.15 else 'Variable',
                'Relevance': 'Lower variance = more reliable transfer'
            })
        
        # 4. Covariance stability
        if 'covariance_properties' in self.results:
            cov_props = self.results['covariance_properties']
            mean_cond = cov_props['condition_number'].mean()
            report.append({
                'Metric': 'Mean Condition Number',
                'Value': f"{mean_cond:.1f}",
                'Interpretation': 'Good' if mean_cond < 100 else 'Needs regularization',
                'Relevance': 'Lower = more numerically stable for adaptation'
            })
        
        df_report = pd.DataFrame(report)
        
        # Save report
        excel_path = self.output_dir / 'transfer_learning_readiness_report.xlsx'
        df_report.to_excel(excel_path, index=False)
        print(f"\n✓ Saved transfer learning readiness report to: {excel_path}")
        
        print("\n" + "="*60)
        print("TRANSFER LEARNING READINESS SUMMARY")
        print("="*60)
        print(df_report.to_string(index=False))
        
        return df_report
    
    # ========================================================================
    # 6. GENERATE MASTER SUMMARY
    # ========================================================================
    
    def generate_master_database(self):
        """
        Combine all results into a master database file.
        
        Creates a comprehensive Excel file with all analysis results
        in separate sheets for easy analysis.
        """
        print("\n" + "="*60)
        print("6. GENERATING MASTER DATABASE")
        print("="*60)
        
        master_path = self.output_dir / 'MASTER_ANALYSIS_DATABASE.xlsx'
        
        with pd.ExcelWriter(master_path, engine='openpyxl') as writer:
            # Write all dataframes to separate sheets
            for name, df in self.results.items():
                if isinstance(df, pd.DataFrame):
                    sheet_name = name[:31]  # Excel sheet name limit
                    df.to_excel(writer, sheet_name=sheet_name, index=False)
                    print(f"  ✓ Added sheet: {sheet_name}")
        
        print(f"\n✓ Master database saved to: {master_path}")
        print(f"\nDatabase contains {len(self.results)} sheets with comprehensive metrics")
        
        return master_path


# ============================================================================
# MAIN EXECUTION FUNCTION
# ============================================================================

def run_comprehensive_analysis(preprocessor, trainer, subject_ids, 
                               exclude_rest=True, output_dir='analysis_results'):
    """
    Run complete analysis pipeline.
    
    Parameters:
    -----------
    preprocessor : PhysioNetPreprocessor
        Preprocessor from Phase 1
    trainer : HealthySubjectTrainer
        Trainer from Phase 2
    subject_ids : list
        Subject IDs to analyze
    exclude_rest : bool
        Whether to exclude rest class
    output_dir : str
        Output directory
        
    Returns:
    --------
    analyzer : ComprehensiveAnalyzer
        Analyzer with all results
    """
    print("\n" + "="*80)
    print(" COMPREHENSIVE ANALYSIS FOR TRANSFER LEARNING RESEARCH")
    print("="*80)
    
    # Initialize analyzer
    analyzer = ComprehensiveAnalyzer(output_dir=output_dir)
    
    # Collect preprocessed epochs
    print("\nCollecting preprocessed epochs...")
    epochs_list = []
    for subject_id in subject_ids:
        try:
            epochs, _, _, _ = preprocessor.process_subject(subject_id, return_covariances=False)
            epochs_list.append(epochs)
        except Exception as e:
            print(f"Error with subject {subject_id}: {e}")
    
    # 1. Signal quality analysis
    analyzer.analyze_signal_quality(epochs_list, subject_ids)
    
    # Prepare training data
    covs, labels, subject_groups = trainer.prepare_data(
        preprocessor, 
        subject_ids,
        exclude_rest=exclude_rest
    )
    
    # 2. Covariance analysis
    analyzer.analyze_covariances(covs, labels, subject_groups)
    
    # Train classifier
    mdm = trainer.train_mdm(covs, labels)
    
    # 3. Class separability analysis
    analyzer.analyze_class_separability(
        covs, labels, 
        trainer.class_means, 
        RiemannianGeometry
    )
    
    # 4. Classifier performance analysis
    analyzer.analyze_classifier_performance(
        trainer.mdm_classifier,
        covs, labels, subject_groups
    )
    
    # 5. Transfer learning readiness report
    analyzer.generate_transfer_learning_report()
    
    # 6. Generate master database
    analyzer.generate_master_database()
    
    print("\n" + "="*80)
    print(" ANALYSIS COMPLETE!")
    print("="*80)
    print(f"\nAll results saved to: {analyzer.output_dir}")
    print(f"Generated {len(analyzer.figures)} figures")
    print(f"\nKey files:")
    print(f"  - MASTER_ANALYSIS_DATABASE.xlsx (all metrics)")
    print(f"  - transfer_learning_readiness_report.xlsx (summary)")
    print(f"  - classifier_performance.xlsx (performance metrics)")
    print(f"  - class_separability.xlsx (separability analysis)")
    
    return analyzer


# Example usage
if __name__ == "__main__":
    """
    # Assuming you have preprocessor and trainer already initialized
    
    analyzer = run_comprehensive_analysis(
        preprocessor=preprocessor,
        trainer=trainer,
        subject_ids=range(1, 11),  # Or range(1, 51) for all subjects
        exclude_rest=True,
        output_dir='analysis_results'
    )
    
    print("\n\nTo answer the research question:")
    print("'Can transfer learning trained on healthy subjects enable accurate")
    print(" movement intention detection for Parkinson's patients?'")
    print("\nKey metrics to examine:")
    print("1. Cross-subject accuracy (LOSO-CV) - predicts PD transfer performance")
    print("2. Class separability ratio - indicates transferability")
    print("3. Covariance stability - important for domain adaptation")
    """
    pass

In [29]:
# After running Phase 2 training
analyzer = run_comprehensive_analysis(
    preprocessor=preprocessor,
    trainer=trainer,
    subject_ids=range(1, 11),  # or all 50
    exclude_rest=True,
    output_dir='analysis_results'
)


 COMPREHENSIVE ANALYSIS FOR TRANSFER LEARNING RESEARCH

Collecting preprocessed epochs...

Processing Subject 1

1. Loading data (Task 1: motor imagery + Task 2: motor execution)...

  Processing run 1/6: S001R03.edf.event
  2. Selecting channels...
Available channels: ['Fc5.', 'Fc3.', 'Fc1.', 'Fcz.', 'Fc2.', 'Fc4.', 'Fc6.', 'C5..', 'C3..', 'C1..']... (showing first 10)
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
Selected 9 channels: ['C3..', 'Cz..', 'C4..', 'C1..', 'C2..', 'C5..', 'C6..', 'Cp5.', 'Cp3.']
  3. Applying beta band filter...
Applied beta band filter: 13.0-30.0 Hz
Applied notch filter: 60.0 Hz
  4. Applying spatial filtering...
Applied Common Average Reference (CAR)
  5. Extracting epochs...
Attempting to extract events from EDF annotations...
  Found 30 annotations in EDF file
  Successfully extracted events from annotations
  Event mapping: {'rest': 1, 'left': 2, 'right': 3}
Found 30 events
Event mapping: {'rest': 1, 'left': 2, 'right

KeyError: 'mean_amplitude_uV'