In [11]:
import os
import numpy as np
from tqdm import tqdm
import mne
import gc
from scipy.io import loadmat
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.metrics import (
    accuracy_score, classification_report, confusion_matrix,
    roc_curve, auc, precision_recall_curve, average_precision_score
)
import matplotlib.pyplot as plt
import seaborn as sns
import joblib
from scipy.signal import welch
import pywt
import pandas as pd
from matplotlib.gridspec import GridSpec
import re
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.decomposition import PCA
from imblearn.over_sampling import SMOTE
from sklearn.ensemble import RandomForestClassifier
from xgboost import XGBClassifier

# Configure plotting: these calls change process-wide matplotlib/seaborn state,
# so every figure drawn after this cell picks them up.
plt.rcParams['figure.figsize'] = [12, 8]  # default figure size (width, height)
plt.rcParams['font.size'] = 12            # default font size
sns.set_style("whitegrid")
sns.set_palette("colorblind")

class EEGPipeline:
    def __init__(self):
        self.models = {}
        self.results = {}
        self.scaler = None
        self.label_encoder = None
        self.common_channels = None
        self.feature_names = [
            'Delta Power', 'Theta Power', 'Alpha Power', 'Beta Power', 'Gamma Power',
            'Wavelet Mean 1', 'Wavelet Mean 2', 'Wavelet Mean 3', 'Wavelet Mean 4', 'Wavelet Mean 5',
            'Mean', 'Std Dev', 'Median', 'Skewness', 'Kurtosis',
            'Hjorth Mobility', 'Hjorth Complexity', 'Spectral Entropy', 'Zero-Crossings', 'Peak-to-Peak'
        ]
        self.band_names = ['Delta', 'Theta', 'Alpha', 'Beta', 'Gamma']
        self.channel_mapping = {}
        self.expected_features_per_channel = len(self.feature_names)
        self._error_messages = []  # Track errors during pipeline execution

    def set_channel_mapping(self, mapping_dict):
        """Set manual channel name mapping between different naming conventions"""
        self.channel_mapping = mapping_dict

    def normalize_channel_name(self, channel_name):
        """Advanced channel name normalization with manual mapping support"""
        if isinstance(channel_name, (list, np.ndarray)):
            channel_name = channel_name[0]
        channel_name = str(channel_name).strip().upper()
        if channel_name in self.channel_mapping:
            return self.channel_mapping[channel_name]
        channel_name = re.sub(r'[^A-Z0-9]', '', channel_name)
        channel_name = re.sub(r'^CH', '', channel_name)
        channel_name = re.sub(r'^EEG', '', channel_name)
        channel_name = channel_name.lstrip('0')
        variations = {
            'FP1': 'Fp1', 'FP2': 'Fp2',
            'T3': 'T7', 'T4': 'T8',
            'T5': 'P7', 'T6': 'P8'
        }
        return variations.get(channel_name, channel_name)

    def get_channel_names_from_mat(self, mat_path):
        """Robust MAT file channel extraction supporting multiple formats"""
        try:
            mat_data = loadmat(mat_path)
            channels = []
            if 'data' in mat_data:
                data_struct = mat_data['data'][0][0]
                if 'channels' in data_struct.dtype.names:
                    channels = [str(ch[0]) for ch in data_struct['channels'][0]]
                elif 'chanlocs' in data_struct.dtype.names:
                    chanlocs = data_struct['chanlocs'][0]
                    channels = [str(chan['labels'][0]) for chan in chanlocs]
            elif 'EEG' in mat_data:
                eeg_struct = mat_data['EEG'][0][0]
                if 'chanlocs' in eeg_struct.dtype.names:
                    chanlocs = eeg_struct['chanlocs'][0]
                    channels = [str(chan['labels'][0]) for chan in chanlocs]
                elif 'chaninfo' in eeg_struct.dtype.names:
                    chaninfo = eeg_struct['chaninfo'][0][0]
                    if 'labels' in chaninfo.dtype.names:
                        channels = [str(ch[0]) for ch in chaninfo['labels'][0]]
            elif 'X' in mat_data and 'ch_names' in mat_data:
                channels = [str(ch[0]) for ch in mat_data['ch_names'][0]]
            return [self.normalize_channel_name(ch) for ch in channels if ch and str(ch).strip()]
        except Exception as e:
            error_msg = f"Error loading {mat_path}: {str(e)}"
            self._error_messages.append(error_msg)
            print(error_msg)
            return []

    def preprocess_data(self, data, labels, sfreq=250, common_channels=None):
        """Enhanced preprocessing with additional features"""
        try:
            if data is None or len(data) == 0:
                raise ValueError("Empty data array")
            n_channels = data.shape[1]
            ch_names = common_channels[:n_channels] if common_channels else [f'ch{i}' for i in range(n_channels)]
            info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types='eeg')
            raw = mne.io.RawArray(data.T, info)
            nyquist = sfreq / 2
            raw.filter(0.5, min(40, nyquist-1), fir_design='firwin', phase='zero-double')
            notch_freqs = [50, 60]
            notch_freqs = [f for f in notch_freqs if f < nyquist]
            if notch_freqs:
                raw.notch_filter(notch_freqs)
            events = mne.make_fixed_length_events(raw, duration=1.0)
            epochs = mne.Epochs(raw, events, tmin=0, tmax=1.0, baseline=None, preload=True)
            epochs_data = epochs.get_data()

            def extract_features(epoch_data):
                features = []
                for epoch in epoch_data:
                    epoch_features = []
                    for channel in epoch:
                        # Frequency features
                        freqs, psd = welch(channel, fs=sfreq, nperseg=min(256, len(channel)))
                        bands = {'delta': (0.5, 4), 'theta': (4, 8), 'alpha': (8, 12),
                                 'beta': (12, 30), 'gamma': (30, min(50, nyquist-1))}
                        band_powers = [np.sum(psd[(freqs >= low) & (freqs <= high)]) 
                                     for low, high in bands.values()]
                        # Wavelet features
                        coeffs = pywt.wavedec(channel, 'db4', level=4)
                        wavelet_features = [np.mean(c) for c in coeffs[:5]]
                        if len(wavelet_features) < 5:
                            wavelet_features += [0.0] * (5 - len(wavelet_features))
                        # Statistical features
                        stats = [
                            np.mean(channel), np.std(channel), np.median(channel),
                            pd.Series(channel).skew(), pd.Series(channel).kurtosis()
                        ]
                        # Hjorth parameters
                        mobility, complexity = self.hjorth_parameters(channel)
                        # Spectral entropy
                        spectral_entropy = -np.sum(psd * np.log(psd + 1e-10))
                        # Zero-crossing rate
                        zero_crossings = np.where(np.diff(np.sign(channel)))[0].size
                        # Peak-to-peak amplitude
                        peak_to_peak = np.max(channel) - np.min(channel)
                        epoch_features.extend(
                            band_powers + wavelet_features + stats + 
                            [mobility, complexity, spectral_entropy, zero_crossings, peak_to_peak]
                        )
                    features.append(epoch_features)
                return np.array(features)

            X = extract_features(epochs_data)
            y = labels[:len(X)]
            expected_features = n_channels * self.expected_features_per_channel
            if X.shape[1] != expected_features:
                if X.shape[1] < expected_features:
                    pad_width = ((0, 0), (0, expected_features - X.shape[1]))
                    X = np.pad(X, pad_width, mode='constant')
                else:
                    X = X[:, :expected_features]
            return X, y
        except Exception as e:
            error_msg = f"Error during preprocessing (sfreq={sfreq}): {str(e)}"
            self._error_messages.append(error_msg)
            print(error_msg)
            return np.array([]), np.array([])

    def hjorth_parameters(self, signal):
        """Calculate the Hjorth mobility and complexity of a 1-D signal.

        mobility   = sqrt(var(d1) / var(signal))
        complexity = sqrt(var(d2) / var(d1)) / mobility
        where d1 and d2 are the first and second discrete differences.

        Args:
            signal: 1-D sequence of samples.

        Returns:
            ``(mobility, complexity)``. ``(0.0, 0.0)`` is returned when the
            ratios are undefined: fewer than three samples, a constant signal,
            or a signal whose first difference is constant.
        """
        signal = np.asarray(signal, dtype=float)
        # The second difference needs at least three samples
        if signal.size < 3:
            return 0.0, 0.0
        var_zero = np.var(signal)
        var_d1 = np.var(np.diff(signal))
        # Either variance being zero would divide by zero below
        if var_zero == 0 or var_d1 == 0:
            return 0.0, 0.0
        var_d2 = np.var(np.diff(signal, 2))
        mobility = np.sqrt(var_d1 / var_zero)
        complexity = np.sqrt(var_d2 / var_d1) / mobility
        return mobility, complexity

    def load_preprocessing_artifacts(self, scaler_path, label_encoder_path):
        """Load preprocessing artifacts with dimension validation"""
        try:
            self.scaler = joblib.load(scaler_path)
            print(f"✓ Scaler loaded successfully (expecting {self.scaler.n_features_in_} features)")
        except Exception as e:
            error_msg = f"✗ Failed to load scaler: {str(e)}"
            self._error_messages.append(error_msg)
            print(error_msg)
            self.scaler = None
        try:
            self.label_encoder = joblib.load(label_encoder_path)
            print("✓ Label encoder loaded successfully")
        except Exception as e:
            error_msg = f"✗ Failed to load label encoder: {str(e)}"
            self._error_messages.append(error_msg)
            print(error_msg)
            self.label_encoder = None

    def tune_models(self, X_train, y_train):
        """Grid-search a random forest and an XGBoost classifier.

        Both searches use 3-fold cross-validation scored by accuracy. The best
        estimator of each search replaces ``self.models`` under the keys
        ``"RF"`` and ``"XGB"``.

        Args:
            X_train: training feature matrix.
            y_train: training labels.
        """
        def best_of(estimator, param_grid):
            # Same search settings for every estimator family
            search = GridSearchCV(estimator, param_grid, cv=3, scoring='accuracy', n_jobs=-1)
            search.fit(X_train, y_train)
            return search.best_estimator_

        rf_param_grid = {
            'n_estimators': [100, 200],
            'max_depth': [10, 20, None],
            'min_samples_split': [2, 5],
            'min_samples_leaf': [1, 2]
        }
        best_rf = best_of(RandomForestClassifier(random_state=42), rf_param_grid)

        xgb_param_grid = {
            'n_estimators': [100, 200],
            'max_depth': [3, 5, 7],
            'learning_rate': [0.01, 0.1]
        }
        # `use_label_encoder` is intentionally not passed: it is deprecated and
        # no longer accepted as a real parameter by current XGBoost releases.
        best_xgb = best_of(XGBClassifier(random_state=42, eval_metric='logloss'), xgb_param_grid)

        self.models = {
            "RF": best_rf,
            "XGB": best_xgb
        }

    def run_pipeline(self, config):
        results = None
        try:
            print("\n=== EEG Analysis Pipeline ===\n")
            # 1. Load data
            healthy_data, healthy_labels, h_sfreq = self.load_dataset(
                config['healthy_path'],
                max_files=config.get('max_files'),
                max_duration=config.get('max_duration', None))
            patient_data, patient_labels, p_sfreq = self.load_dataset(
                config['patient_path'],
                is_patient=True,
                max_files=config.get('max_files'),
                max_duration=config.get('max_duration', None))

            # 2. Channel matching
            if not config.get('skip_channel_matching', False):
                healthy_channels = self.collect_channel_names(config['healthy_path'])
                patient_channels = self.collect_channel_names(config['patient_path'], is_patient=True)
                self.common_channels = self.find_common_channels(healthy_channels, patient_channels)

            # 3. Preprocessing
            X_healthy, y_healthy = self.preprocess_data(
                healthy_data, healthy_labels,
                sfreq=h_sfreq,
                common_channels=self.common_channels)
            X_patients, y_patients = self.preprocess_data(
                patient_data, patient_labels,
                sfreq=p_sfreq,
                common_channels=self.common_channels)

            # 4. Apply transformations
            self.scaler = StandardScaler()
            X = np.vstack([X_healthy, X_patients])
            y = np.hstack([y_healthy, y_patients])
            X = self.scaler.fit_transform(X)

            # 5. Handle class imbalance
            smote = SMOTE(random_state=42)
            X_resampled, y_resampled = smote.fit_resample(X, y)

            # 6. Dimensionality reduction
            pca = PCA(n_components=0.95, random_state=42)
            X_reduced = pca.fit_transform(X_resampled)

            # 7. Train-test split
            X_train, X_test, y_train, y_test = train_test_split(X_reduced, y_resampled, test_size=0.2, random_state=42)

            # 8. Model tuning
            self.tune_models(X_train, y_train)

            # 9. Evaluate models
            self.evaluate_models(X_test, y_test, data_type='Patient')

            # 10. Generate visualizations
            self.plot_model_performance_comparison()

            print("\n=== Pipeline Completed ===\n")
            if self._error_messages:
                print("Completed with warnings/errors (see above messages)")
            else:
                print("Completed successfully!")
        except Exception as e:
            print(f"\n!!! Pipeline Failed !!!\nError: {str(e)}")
            import traceback
            traceback.print_exc()
        finally:
            gc.collect()
        return results

if __name__ == "__main__":
    # Manual overrides consumed by EEGPipeline.normalize_channel_name. Keys are
    # written upper-case because names are upper-cased before the lookup.
    channel_mapping = {
        'FP1': 'Fp1',
        'FP2': 'Fp2',
        'F3': 'F3',
        'F4': 'F4',
    }
    # Run configuration. run_pipeline as defined in this cell reads
    # 'healthy_path', 'patient_path', 'max_files', 'max_duration' and
    # 'skip_channel_matching'; the remaining keys are not read by it.
    # NOTE(review): the paths are absolute and machine-specific.
    config = {
        'healthy_path': "F:/shivani/VSCode/ml/worked on dataset/3(final)/Healthy",
        'patient_path': "F:/shivani/VSCode/ml/worked on dataset/4/dataset/Patients",
        'model_paths': {},
        'scaler_path': "F:/shivani/VSCode/ml/worked on dataset/3(final)/preprocessing_artifacts/eeg_scaler.joblib",
        'label_encoder_path': "F:/shivani/VSCode/ml/worked on dataset/3(final)/preprocessing_artifacts/label_encoder.joblib",
        'max_files': 5,  # at most this many files are loaded per dataset
    }
    pipeline = EEGPipeline()
    pipeline.set_channel_mapping(channel_mapping)
    # run_pipeline prints its own progress and failure messages
    results = pipeline.run_pipeline(config)


=== EEG Analysis Pipeline ===


!!! Pipeline Failed !!!
Error: 'EEGPipeline' object has no attribute 'load_dataset'


Traceback (most recent call last):
  File "C:\Users\shivani\AppData\Local\Temp\ipykernel_7468\357773262.py", line 236, in run_pipeline
    healthy_data, healthy_labels, h_sfreq = self.load_dataset(
                                            ^^^^^^^^^^^^^^^^^
AttributeError: 'EEGPipeline' object has no attribute 'load_dataset'


In [12]:
import os
import numpy as np
from tqdm import tqdm
import mne
import gc
from scipy.io import loadmat
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.metrics import (
    accuracy_score, classification_report, confusion_matrix,
    roc_curve, auc, precision_recall_curve, average_precision_score
)
import matplotlib.pyplot as plt
import seaborn as sns
import joblib
from scipy.signal import welch
import pywt
import pandas as pd
from matplotlib.gridspec import GridSpec
import re
from sklearn.model_selection import train_test_split
from matplotlib.ticker import FormatStrFormatter

# Configure plotting: these calls change process-wide matplotlib/seaborn state,
# so every figure drawn after this cell picks them up.
plt.rcParams['figure.figsize'] = [12, 8]  # default figure size (width, height)
plt.rcParams['font.size'] = 12            # default font size
sns.set_style("whitegrid")
sns.set_palette("colorblind")

class EEGPipeline:
    def __init__(self):
        self.models = {}
        self.results = {}
        self.scaler = None
        self.label_encoder = None
        self.common_channels = None
        self.feature_names = [
            'Delta Power', 'Theta Power', 'Alpha Power', 'Beta Power', 'Gamma Power',
            'Wavelet Mean 1', 'Wavelet Mean 2', 'Wavelet Mean 3', 'Wavelet Mean 4', 'Wavelet Mean 5',
            'Mean', 'Std Dev', 'Median'
        ]
        self.band_names = ['Delta', 'Theta', 'Alpha', 'Beta', 'Gamma']
        self.channel_mapping = {}
        self.expected_features_per_channel = 13
        self._error_messages = []  # Track errors during pipeline execution

    def set_channel_mapping(self, mapping_dict):
        """Set manual channel name mapping between different naming conventions"""
        self.channel_mapping = mapping_dict

    def normalize_channel_name(self, channel_name):
        """Advanced channel name normalization with manual mapping support"""
        if isinstance(channel_name, (list, np.ndarray)):
            channel_name = channel_name[0]
        channel_name = str(channel_name).strip().upper()
        
        if channel_name in self.channel_mapping:
            return self.channel_mapping[channel_name]
        
        channel_name = re.sub(r'[^A-Z0-9]', '', channel_name)
        channel_name = re.sub(r'^CH', '', channel_name)
        channel_name = re.sub(r'^EEG', '', channel_name)
        channel_name = channel_name.lstrip('0')
        
        # Handle common variations
        variations = {
            'FP1': 'Fp1', 'FP2': 'Fp2',
            'T3': 'T7', 'T4': 'T8',
            'T5': 'P7', 'T6': 'P8'
        }
        return variations.get(channel_name, channel_name)

    def get_channel_names_from_mat(self, mat_path):
        """Robust MAT file channel extraction supporting multiple formats"""
        try:
            mat_data = loadmat(mat_path)
            channels = []
            
            # Structure 1: Nested 'data' structure
            if 'data' in mat_data:
                data_struct = mat_data['data'][0][0]
                if 'channels' in data_struct.dtype.names:
                    channels = [str(ch[0]) for ch in data_struct['channels'][0]]
                elif 'chanlocs' in data_struct.dtype.names:
                    chanlocs = data_struct['chanlocs'][0]
                    channels = [str(chan['labels'][0]) for chan in chanlocs]
            
            # Structure 2: EEGLAB structure
            elif 'EEG' in mat_data:
                eeg_struct = mat_data['EEG'][0][0]
                if 'chanlocs' in eeg_struct.dtype.names:
                    chanlocs = eeg_struct['chanlocs'][0]
                    channels = [str(chan['labels'][0]) for chan in chanlocs]
                elif 'chaninfo' in eeg_struct.dtype.names:
                    chaninfo = eeg_struct['chaninfo'][0][0]
                    if 'labels' in chaninfo.dtype.names:
                        channels = [str(ch[0]) for ch in chaninfo['labels'][0]]
            
            # Structure 3: Simple X,y structure
            elif 'X' in mat_data and 'ch_names' in mat_data:
                channels = [str(ch[0]) for ch in mat_data['ch_names'][0]]
                
            return [self.normalize_channel_name(ch) for ch in channels if ch and str(ch).strip()]
        except Exception as e:
            error_msg = f"Error loading {mat_path}: {str(e)}"
            self._error_messages.append(error_msg)
            print(error_msg)
            return []

    def get_channel_names_from_vhdr(self, vhdr_path):
        """Extract channel names from BrainVision files with validation"""
        try:
            raw = mne.io.read_raw_brainvision(vhdr_path, preload=False, verbose=False)
            return [self.normalize_channel_name(ch) for ch in raw.ch_names]
        except Exception as e:
            error_msg = f"Error loading {vhdr_path}: {str(e)}"
            self._error_messages.append(error_msg)
            print(error_msg)
            return []

    def find_common_channels(self, healthy_channels, patient_channels):
        """Return the channel names shared by both datasets, sorted.

        Exact matches are preferred. Only when there are none does the method
        fall back to partial matching: a healthy channel counts as shared when
        it contains, or is contained in, some patient channel name. Empty names
        are ignored in that fallback because an empty string is a substring of
        every name.

        Args:
            healthy_channels: iterable of channel names from the healthy set.
            patient_channels: iterable of channel names from the patient set.

        Returns:
            Sorted list of names (healthy-side spelling for partial matches);
            empty when nothing matches.
        """
        healthy_set = set(healthy_channels)
        patient_set = set(patient_channels)

        # First try exact matching
        common = healthy_set & patient_set

        if not common:
            common_partial = {
                h_ch for h_ch in healthy_set
                if h_ch and any(p_ch and (h_ch in p_ch or p_ch in h_ch) for p_ch in patient_set)
            }
            if common_partial:
                print(f"Using partial channel matches: {common_partial}")
                return sorted(common_partial)

        print(f"Found {len(common)} common channels")
        return sorted(common)

    def collect_channel_names(self, folder_path, is_patient=False):
        """Collect channel names with extensive validation"""
        all_channels = set()
        files = [f for f in os.listdir(folder_path) if f.endswith('.vhdr' if not is_patient else '.mat')]
        
        if not files:
            raise FileNotFoundError(f"No valid files found in {folder_path}")
            
        for file in tqdm(files, desc=f"Collecting {'patient' if is_patient else 'healthy'} channels"):
            file_path = os.path.join(folder_path, file)
            channels = self.get_channel_names_from_mat(file_path) if is_patient else self.get_channel_names_from_vhdr(file_path)
            
            if not channels:
                print(f"Warning: No channels found in {file}")
                continue
                
            all_channels.update(channels)
            
        if not all_channels:
            raise ValueError(f"No channels collected from {folder_path}")
            
        return sorted(all_channels)

    def load_dataset(self, folder_path, is_patient=False, max_files=None, max_duration=None):
        """Robust dataset loading with comprehensive validation"""
        files = [f for f in os.listdir(folder_path) if f.endswith('.mat' if is_patient else '.vhdr')]
        loader_func = self.load_mat_data if is_patient else self.load_brainvision_data
            
        if not files:
            raise FileNotFoundError(f"No valid files found in {folder_path}")
            
        if max_files:
            files = files[:max_files]
            
        all_data = []
        all_labels = []
        sfreqs = []
        loaded_files = 0
        
        for file in tqdm(files, desc=f"Loading {'patient' if is_patient else 'healthy'} files"):
            file_path = os.path.join(folder_path, file)
            try:
                data, labels, sfreq = loader_func(file_path, max_duration)
                
                if len(data) > 0 and len(labels) > 0:
                    all_data.append(data)
                    all_labels.append(labels)
                    sfreqs.append(sfreq)
                    loaded_files += 1
                else:
                    print(f"Skipping {file} - no valid data")
            except Exception as e:
                error_msg = f"Error loading {file}: {str(e)}"
                self._error_messages.append(error_msg)
                print(error_msg)
                
            gc.collect()
            
        print(f"\nSuccessfully loaded {loaded_files}/{len(files)} files")
        
        if not all_data:
            raise ValueError("No valid data loaded - check file formats")
            
        avg_sfreq = np.mean(sfreqs) if sfreqs else (100 if is_patient else 250)
        return np.concatenate(all_data), np.concatenate(all_labels), avg_sfreq

    def load_brainvision_data(self, vhdr_path, max_duration=None):
        """Load BrainVision data with enhanced validation"""
        try:
            raw = mne.io.read_raw_brainvision(vhdr_path, preload=True, verbose=False)
            
            if max_duration:
                crop_end = min(max_duration, raw.times[-1])
                raw.crop(tmax=crop_end)
                
            data = raw.get_data().T.astype(np.float32)
            events, _ = mne.events_from_annotations(raw)
            labels = events[:, 2] if len(events) > 0 else np.zeros(len(data))
            
            return data, labels, raw.info['sfreq']
        except Exception as e:
            error_msg = f"Error loading {vhdr_path}: {str(e)}"
            self._error_messages.append(error_msg)
            print(error_msg)
            return np.array([]), np.array([]), None

    def load_mat_data(self, mat_path, max_duration=None, default_sfreq=100):
        """Flexible MAT file loader supporting multiple structures"""
        try:
            mat_data = loadmat(mat_path)
            eeg_data = None
            labels = None
            sfreq = default_sfreq
            
            if 'data' in mat_data:
                data_struct = mat_data['data'][0][0]
                if 'X' in data_struct.dtype.names:
                    eeg_data = data_struct['X']
                    if eeg_data.ndim > 2:
                        eeg_data = eeg_data.reshape(eeg_data.shape[0], -1)
                    if 'y' in data_struct.dtype.names:
                        labels = data_struct['y'].flatten()
                    if 'sfreq' in data_struct.dtype.names:
                        sfreq = float(data_struct['sfreq'][0][0])
                    elif 'Fs' in data_struct.dtype.names:
                        sfreq = float(data_struct['Fs'][0][0])
            
            elif 'EEG' in mat_data:
                eeg_struct = mat_data['EEG'][0][0]
                if 'data' in eeg_struct.dtype.names:
                    eeg_data = eeg_struct['data'].T
                if 'event' in eeg_struct.dtype.names:
                    events = eeg_struct['event'][0]
                    labels = np.array([ev[0]['type'][0] for ev in events])
                if 'srate' in eeg_struct.dtype.names:
                    sfreq = float(eeg_struct['srate'][0][0])
            
            elif 'X' in mat_data:
                eeg_data = mat_data['X']
                if 'y' in mat_data:
                    labels = mat_data['y'].flatten()
            
            if labels is None or len(np.unique(labels)) <= 1:
                labels = np.zeros(len(eeg_data)) if eeg_data is not None else np.array([])
            
            if eeg_data is None:
                raise ValueError("No EEG data found in MAT file")
            
            min_len = min(len(eeg_data), len(labels))
            eeg_data = eeg_data[:min_len]
            labels = labels[:min_len]
            
            if max_duration:
                max_samples = int(max_duration * sfreq)
                eeg_data = eeg_data[:max_samples]
                labels = labels[:max_samples]
                
            return eeg_data.astype(np.float32), labels, sfreq
        except Exception as e:
            error_msg = f"Error loading {mat_path}: {str(e)}"
            self._error_messages.append(error_msg)
            print(error_msg)
            return np.array([]), np.array([]), None

    def preprocess_data(self, data, labels, sfreq=250, common_channels=None):
    try:
        # Validate input
        if data is None or len(data) == 0:
            raise ValueError("Empty data array")
        
        n_channels = data.shape[1]
        ch_names = common_channels[:n_channels] if common_channels and len(common_channels) >= n_channels else [f'ch{i}' for i in range(n_channels)]
        info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types='eeg')
        raw = mne.io.RawArray(data.T, info)
        
        # Apply filters with safety checks
        nyquist = sfreq / 2
        raw.filter(0.5, min(40, nyquist - 1), fir_design='firwin', phase='zero-double')
        notch_freqs = [50, 60]  # Default frequencies
        notch_freqs = [f for f in notch_freqs if f < nyquist]  # Filter out invalid freqs
        if notch_freqs:
            raw.notch_filter(notch_freqs)
        
        # Create epochs
        events = mne.make_fixed_length_events(raw, duration=1.0)
        epochs = mne.Epochs(raw, events, tmin=0, tmax=1.0, baseline=None, preload=True)
        epochs_data = epochs.get_data()

        def extract_features(epoch_data):
            features = []
            for epoch in epoch_data:
                epoch_features = []
                for channel in epoch:
                    # Frequency features
                    freqs, psd = welch(channel, fs=sfreq, nperseg=min(256, len(channel)))
                    bands = {'delta': (0.5, 4), 'theta': (4, 8), 'alpha': (8, 12),
                             'beta': (12, 30), 'gamma': (30, min(50, nyquist - 1))}
                    band_powers = [np.sum(psd[(freqs >= low) & (freqs <= high)]) 
                                   for low, high in bands.values()]
                    
                    # Wavelet features
                    coeffs = pywt.wavedec(channel, 'db4', level=4)
                    wavelet_features = [np.mean(c) for c in coeffs[:5]]
                    if len(wavelet_features) < 5:
                        wavelet_features += [0.0] * (5 - len(wavelet_features))
                    
                    # Statistical features
                    stats = [
                        np.mean(channel), 
                        np.std(channel), 
                        np.median(channel),
                        pd.Series(channel).skew(),  # Skewness
                        pd.Series(channel).kurtosis()  # Kurtosis
                    ]
                    
                    # Hjorth parameters
                    mobility, complexity = self.hjorth_parameters(channel)
                    
                    # Spectral entropy
                    spectral_entropy = -np.sum(psd * np.log(psd + 1e-10))
                    
                    # Zero-crossing rate
                    zero_crossings = np.where(np.diff(np.sign(channel)))[0].size
                    
                    # Peak-to-peak amplitude
                    peak_to_peak = np.max(channel) - np.min(channel)
                    
                    epoch_features.extend(
                        band_powers + wavelet_features + stats + 
                        [mobility, complexity, spectral_entropy, zero_crossings, peak_to_peak]
                    )
                features.append(epoch_features)
            return np.array(features)
        
        X = extract_features(epochs_data)
        y = labels[:len(X)]
        
        expected_features = n_channels * self.expected_features_per_channel
        if X.shape[1] != expected_features:
            if X.shape[1] < expected_features:
                pad_width = ((0, 0), (0, expected_features - X.shape[1]))
                X = np.pad(X, pad_width, mode='constant')
            else:
                X = X[:, :expected_features]
        return X, y
    except Exception as e:
        error_msg = f"Error during preprocessing (sfreq={sfreq}): {str(e)}"
        self._error_messages.append(error_msg)
        print(error_msg)
        return np.array([]), np.array([])

    def load_preprocessing_artifacts(self, scaler_path, label_encoder_path):
        """Load preprocessing artifacts with dimension validation"""
        try:
            self.scaler = joblib.load(scaler_path)
            print(f"✓ Scaler loaded successfully (expecting {self.scaler.n_features_in_} features)")
        except Exception as e:
            error_msg = f"✗ Failed to load scaler: {str(e)}"
            self._error_messages.append(error_msg)
            print(error_msg)
            self.scaler = None
            
        try:
            self.label_encoder = joblib.load(label_encoder_path)
            print("✓ Label encoder loaded successfully")
        except Exception as e:
            error_msg = f"✗ Failed to load label encoder: {str(e)}"
            self._error_messages.append(error_msg)
            print(error_msg)
            self.label_encoder = None

    def match_features(self, X, expected_features):
        """Pad with zero columns or drop trailing columns to reach the expected width.

        Args:
            X: 2-D feature matrix.
            expected_features: required number of columns.

        Returns:
            ``X`` itself when the width already matches, otherwise a matrix with
            exactly ``expected_features`` columns.
        """
        missing = expected_features - X.shape[1]
        if missing == 0:
            return X
        if missing > 0:
            # Too narrow: append zero-filled columns on the right
            return np.pad(X, ((0, 0), (0, missing)), mode='constant')
        # Too wide: keep the leading columns
        return X[:, :expected_features]

    def load_models(self, model_paths):
        """Load every model listed in ``model_paths`` with joblib.

        Args:
            model_paths: Mapping of model name to the path of its serialized
                file.

        Returns:
            ``self.models``, holding only the models that loaded. Individual
            failures are appended to ``self._error_messages`` and printed.

        Raises:
            ValueError: If not a single model could be loaded.
        """
        self.models = {}
        for model_name, model_file in model_paths.items():
            try:
                self.models[model_name] = joblib.load(model_file)
                print(f"✓ {model_name} model loaded successfully")
            except Exception as exc:
                failure = f"✗ Failed to load {model_name} model: {str(exc)}"
                self._error_messages.append(failure)
                print(failure)

        if len(self.models) == 0:
            raise ValueError("No models were loaded successfully")
        return self.models

    def evaluate_models(self, X_data, y_data, data_type='Patient'):
        """Score every loaded model on one dataset and store the metrics.

        For each model in ``self.models`` this computes accuracy, the
        classification report (as a dict) and the confusion matrix. ROC data
        (``fpr``, ``tpr``, ``roc_auc``) is computed only when ``y_data`` has
        at most two distinct labels and the model exposes ``predict_proba``;
        otherwise those three entries are ``None``.

        Args:
            X_data: Feature matrix passed to ``model.predict``.
            y_data: True labels.
            data_type: Key under which results are stored in ``self.results``.

        Side effects:
            Replaces ``self.results[data_type]`` with a fresh dict keyed by
            model name. A model whose evaluation raises is skipped; the error
            is appended to ``self._error_messages`` and printed.
        """
        self.results[data_type] = {}
        for name, model in self.models.items():
            try:
                # Reset per model: ROC values from a previous model must never
                # be reported for a model that has no ROC curve of its own.
                fpr = tpr = roc_auc = None

                y_pred = model.predict(X_data)
                acc = accuracy_score(y_data, y_pred)
                report = classification_report(y_data, y_pred, output_dict=True, zero_division=0)
                cm = confusion_matrix(y_data, y_pred)

                is_multiclass = len(np.unique(y_data)) > 2
                if not is_multiclass and hasattr(model, "predict_proba"):
                    # Column 1 is used as the score for the ROC curve.
                    y_proba = model.predict_proba(X_data)[:, 1]
                    fpr, tpr, _ = roc_curve(y_data, y_proba)
                    roc_auc = auc(fpr, tpr)

                # Check if accuracy meets the custom threshold of 65%
                meets_custom_threshold = acc >= 0.65

                self.results[data_type][name] = {
                    'accuracy': acc,
                    'meets_custom_threshold': meets_custom_threshold,
                    'report': report,
                    'confusion_matrix': cm,
                    'roc_auc': roc_auc,
                    'fpr': fpr,
                    'tpr': tpr,
                    'y_true': y_data,
                    'y_pred': y_pred
                }
            except Exception as e:
                error_msg = f"Error evaluating {name} on {data_type} data: {str(e)}"
                self._error_messages.append(error_msg)
                print(error_msg)
    def plot_eeg_comparison(self, healthy_data, patient_data, healthy_sfreq=250, patient_sfreq=100, samples=500, channels=3):
        """Plot the first samples of healthy and patient EEG side by side.

        Args:
            healthy_data: 2-D array indexed as ``[sample, channel]``.
            patient_data: 2-D array indexed as ``[sample, channel]``.
            healthy_sfreq: Sampling rate used to build the healthy time axis.
            patient_sfreq: Sampling rate used to build the patient time axis.
            samples: Maximum number of leading samples drawn per panel.
            channels: Maximum number of leading channels drawn per panel.

        Values are multiplied by 1e6 before plotting and the axis is labelled
        in μV (assumes the inputs are in volts - TODO confirm). The y-range is
        fixed to [-100, 100].
        """
        # Fall back to a style that ships with every matplotlib version when
        # the seaborn style name is not available.
        try:
            plt.style.use('seaborn-v0_8-whitegrid')
        except (OSError, ValueError):
            plt.style.use('ggplot')

        fig, (ax_healthy, ax_patient) = plt.subplots(1, 2, figsize=(20, 8), sharey=True)

        panels = (
            (ax_healthy, healthy_data, healthy_sfreq, 'Healthy EEG Signals'),
            (ax_patient, patient_data, patient_sfreq, 'Patient EEG Signals'),
        )
        for ax, data, sfreq, title in panels:
            n_shown = min(samples, len(data))
            time_axis = np.arange(n_shown) / sfreq

            for ch in range(min(channels, data.shape[1])):
                ax.plot(time_axis, data[:n_shown, ch] * 1e6,
                        linewidth=2, alpha=0.8, label=f'Channel {ch+1}')

            ax.set_title(title, fontsize=16, pad=20)
            ax.set_xlabel('Time (seconds)', fontsize=14)
            ax.legend(loc='upper right', fontsize=12)
            ax.grid(True, linestyle='--', alpha=0.6)
            ax.set_ylim(-100, 100)
            ax.tick_params(axis='both', which='major', labelsize=12)

        # The y axis is shared, so only the left panel carries the label.
        ax_healthy.set_ylabel('Amplitude (μV)', fontsize=14)

        # Report the requested window instead of a hard-coded "500".
        plt.suptitle(f'EEG Signal Comparison (First {samples} Samples)', fontsize=18, y=1.02)
        plt.tight_layout()
        plt.show()

    def plot_patient_response_categories(self):
        """Plot how the first patient model's predictions split into response categories.

        Uses the first entry of ``self.results['Patient']``. Every sample is
        scored 1.0 when the prediction equals the true label and 0.0
        otherwise, then bucketed as "Good Response" (score >= 0.7),
        "Medium Response" (0.4 <= score < 0.7) or "Poor Response". Because the
        per-sample score is only ever 0.0 or 1.0, the "Medium Response" bucket
        stays empty with this scoring.

        The category counts are drawn as a pie chart and a bar chart. Each
        category keeps a fixed colour (green / amber / red) regardless of the
        order in which ``value_counts`` returns the categories.

        Prints a message and returns ``None`` when patient results are
        missing. Errors raised while plotting are printed with their
        traceback instead of being propagated.
        """
        if not self.results or 'Patient' not in self.results:
            print("No patient results available for visualization")
            return

        # Fixed colour per category so that e.g. "Poor Response" is never
        # drawn in green just because it happens to be the largest group.
        category_colors = {
            "Good Response": '#4CAF50',
            "Medium Response": '#FFC107',
            "Poor Response": '#F44336',
        }

        try:
            # Use the first available model's results
            model_name = next(iter(self.results['Patient']))
            results = self.results['Patient'][model_name]

            if 'y_true' not in results or 'y_pred' not in results:
                print("Missing required data in results")
                return

            y_true = np.asarray(results['y_true'])
            y_pred = np.asarray(results['y_pred'])

            # Per-sample correctness (1.0 = correct, 0.0 = wrong); the same
            # rule applies to binary and multiclass labels.
            sample_scores = (y_pred == y_true).astype(float)

            categories = []
            for score in sample_scores:
                if score >= 0.7:
                    categories.append("Good Response")
                elif 0.4 <= score < 0.7:
                    categories.append("Medium Response")
                else:
                    categories.append("Poor Response")

            category_counts = pd.Series(categories).value_counts()

            # Colours and explode offsets follow the categories actually present.
            colors = [category_colors[label] for label in category_counts.index]
            explode = [0.05] * len(category_counts)

            fig = plt.figure(figsize=(18, 8))
            gs = GridSpec(1, 2, width_ratios=[1, 1.5])

            # Subplot 1: pie chart
            ax1 = fig.add_subplot(gs[0])
            wedges, texts, autotexts = ax1.pie(
                category_counts,
                labels=category_counts.index,
                autopct=lambda p: f'{p:.1f}%' if p > 0 else '',
                startangle=90,
                colors=colors,
                explode=explode,
                textprops={'fontsize': 12}
            )

            for autotext in autotexts:
                autotext.set_color('white')
                autotext.set_fontweight('bold')

            ax1.set_title('Patient Response Distribution', fontsize=16, pad=20)

            # Subplot 2: bar plot
            ax2 = fig.add_subplot(gs[1])
            barplot = sns.barplot(
                x=category_counts.index,
                y=category_counts.values,
                ax=ax2,
                palette=colors,
                saturation=0.8
            )

            ax2.set_title('Patient Response Categories', fontsize=16, pad=20)
            ax2.set_xlabel('Response Category', fontsize=14)
            ax2.set_ylabel('Number of Patients', fontsize=14)

            # Add count annotations above every non-empty bar
            for p in barplot.patches:
                height = p.get_height()
                if not np.isnan(height) and height > 0:
                    barplot.annotate(
                        f'{int(height)}',
                        (p.get_x() + p.get_width() / 2., height),
                        ha='center', va='center',
                        xytext=(0, 10),
                        textcoords='offset points',
                        fontsize=12,
                        fontweight='bold'
                    )

            # Add overall accuracy if available
            if 'accuracy' in results:
                overall_acc = results['accuracy']
                fig.text(
                    0.5, -0.05,
                    f'Model: {model_name} | Overall Accuracy: {overall_acc:.1%}',
                    ha='center', va='center', fontsize=14
                )

            plt.suptitle('Patient Response Categorization', fontsize=18, y=1.05)
            plt.tight_layout()
            plt.show()

        except Exception as e:
            print(f"Error generating patient response visualization: {str(e)}")
            import traceback
            traceback.print_exc()

    def plot_confusion_matrices(self, data_type='Patient'):
        """Draw one row-normalised confusion-matrix heatmap per evaluated model.

        Args:
            data_type: Key of ``self.results`` whose models are plotted.

        Each matrix row is divided by its sum, so cells show the fraction of
        that true class. Rows without any sample are shown as zeros rather
        than NaN. Tick labels come from ``self.label_encoder.classes_`` when
        their count matches the matrix size; otherwise generic names are used
        ("Negative"/"Positive" for two classes, "Class i" beyond that).

        Prints a message and returns ``None`` when there is nothing to plot.
        Errors raised while plotting are printed with their traceback instead
        of being propagated.
        """
        if not self.results or data_type not in self.results:
            print(f"No results available for {data_type} data")
            return

        try:
            models = list(self.results[data_type].keys())
            num_models = len(models)

            if num_models == 0:
                print("No models available for visualization")
                return

            # squeeze=False always yields a 2-D array of axes, so one model
            # needs no special case.
            fig, axes = plt.subplots(1, num_models, figsize=(6*num_models, 5), squeeze=False)
            axes = axes[0]

            for i, model_name in enumerate(models):
                model_results = self.results[data_type][model_name]

                if 'confusion_matrix' not in model_results:
                    print(f"No confusion matrix for {model_name}")
                    continue

                cm = np.asarray(model_results['confusion_matrix'])
                n_classes = cm.shape[0]

                # Prefer the encoder's class names, but only when there is
                # exactly one name per matrix row.
                encoder_classes = None
                if self.label_encoder is not None:
                    encoder_classes = getattr(self.label_encoder, 'classes_', None)
                if encoder_classes is not None and len(encoder_classes) == n_classes:
                    classes = list(encoder_classes)
                elif n_classes == 2:
                    classes = ['Negative', 'Positive']
                else:
                    classes = [f'Class {idx}' for idx in range(n_classes)]

                # Row-normalise; rows that sum to zero stay at 0 instead of
                # producing NaN through a division by zero.
                row_totals = cm.sum(axis=1, keepdims=True)
                cm_normalized = np.divide(
                    cm.astype(float), row_totals,
                    out=np.zeros(cm.shape, dtype=float),
                    where=row_totals != 0
                )

                sns.heatmap(
                    cm_normalized,
                    annot=True,
                    fmt='.2f',
                    cmap='Blues',
                    xticklabels=classes,
                    yticklabels=classes,
                    ax=axes[i],
                    cbar=False,
                    annot_kws={'fontsize': 10},
                    vmin=0, vmax=1
                )

                axes[i].set_title(f'{model_name}\nConfusion Matrix', fontsize=14)
                axes[i].set_xlabel('Predicted Label', fontsize=12)
                axes[i].set_ylabel('True Label', fontsize=12)

            plt.suptitle(f'Model Performance on {data_type} Data', fontsize=16, y=1.05)
            plt.tight_layout()
            plt.show()

        except Exception as e:
            print(f"Error generating confusion matrices: {str(e)}")
            import traceback
            traceback.print_exc()
            
    def plot_roc_curves_comparison(self):
        """Overlay the ROC curve of every patient model that has one.

        Models whose stored ``roc_auc`` is ``None`` are left out. A dashed
        diagonal is drawn as the chance-level reference. Prints a message and
        returns ``None`` when no patient results exist.
        """
        has_patient_results = bool(self.results) and 'Patient' in self.results
        if not has_patient_results:
            print("No patient results available")
            return

        plt.figure(figsize=(10, 8))

        for model_name, entry in self.results['Patient'].items():
            if entry.get('roc_auc') is None:
                continue
            curve_label = f'{model_name} (AUC = {entry["roc_auc"]:.2f})'
            plt.plot(entry['fpr'], entry['tpr'], label=curve_label)

        # Chance-level reference line.
        plt.plot([0, 1], [0, 1], 'k--')

        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title('Receiver Operating Characteristic Comparison')
        plt.legend(loc="lower right")
        plt.grid(True)
        plt.show()
        
    def plot_model_performance_comparison(self):
        """Bar chart comparing accuracy, precision, recall and F1 per patient model.

        Precision, recall and F1 are taken from the report's ``'macro avg'``
        entry. Reports without it fall back to the entry for class ``'1'``
        when the report has more than three keys and contains that class.
        Reports matching neither rule, or lacking ``'accuracy'``, are skipped.

        Raises:
            ValueError: If no patient results exist, or no report yields
                usable metrics.
        """
        if not self.results or 'Patient' not in self.results:
            raise ValueError("Patient results not available")

        metrics_data = []
        for model, model_results in self.results['Patient'].items():
            report = model_results['report']
            if not (isinstance(report, dict) and 'accuracy' in report):
                continue

            if 'macro avg' in report:
                summary = report['macro avg']
            elif len(report.keys()) > 3 and '1' in report:
                # Fallback: metrics of the class labelled '1'. Checking for
                # the key avoids a bare KeyError on reports with other labels.
                summary = report['1']
            else:
                continue

            metrics_data.append({
                'Model': model,
                'Accuracy': report['accuracy'],
                'Precision': summary['precision'],
                'Recall': summary['recall'],
                'F1-Score': summary['f1-score']
            })

        if not metrics_data:
            raise ValueError("No valid metric data found")

        df = pd.DataFrame(metrics_data)
        df_melted = df.melt(id_vars='Model', var_name='Metric', value_name='Score')

        plt.figure(figsize=(12, 6))
        barplot = sns.barplot(
            x='Model',
            y='Score',
            hue='Metric',
            data=df_melted,
            palette='viridis',
            alpha=0.8
        )

        plt.title('Model Performance Metrics Comparison', fontsize=16, pad=20)
        plt.xlabel('Model', fontsize=14)
        plt.ylabel('Score', fontsize=14)
        plt.ylim(0, 1.1)
        plt.legend(title='Metric', bbox_to_anchor=(1.05, 1), loc='upper left')

        # Label the bars through their containers so that only real bars get
        # a value (ax.patches may hold other rectangles too - NOTE(review):
        # confirm for the installed seaborn version).
        for container in barplot.containers:
            barplot.bar_label(container, fmt='%.2f', padding=3, fontsize=10)

        plt.tight_layout()
        plt.show()

    def _get_error_messages(self):
        """Helper to collect error messages"""
        return self._error_messages

    def run_pipeline(self, config):
        """Run loading, preprocessing, evaluation and plotting end to end.

        Args:
            config: Dict with the required keys ``'healthy_path'``,
                ``'patient_path'``, ``'scaler_path'``,
                ``'label_encoder_path'`` and ``'model_paths'``, and the
                optional keys ``'max_files'``, ``'max_duration'`` and
                ``'skip_channel_matching'``.

        Returns:
            A dict with the transformed healthy and patient data, the
            evaluation results and run metadata, or ``None`` when a step
            outside the visualisation loop raised. That failure is printed
            with its traceback rather than propagated. A failing
            visualisation is recorded in ``self._error_messages`` and does
            not abort the run.
        """
        results = None
        # Bound up front so the cleanup in ``finally`` needs no locals() check.
        healthy_data = patient_data = None
        try:
            print("\n=== EEG Analysis Pipeline ===\n")

            # 1. Load data
            print("[1/7] Loading data...")
            healthy_data, healthy_labels, h_sfreq = self.load_dataset(
                config['healthy_path'],
                max_files=config.get('max_files'),
                max_duration=config.get('max_duration', None))
            patient_data, patient_labels, p_sfreq = self.load_dataset(
                config['patient_path'],
                is_patient=True,
                max_files=config.get('max_files'),
                max_duration=config.get('max_duration', None))

            # 2. Channel matching
            print("[2/7] Finding common channels...")
            if not config.get('skip_channel_matching', False):
                healthy_channels = self.collect_channel_names(config['healthy_path'])
                patient_channels = self.collect_channel_names(config['patient_path'], is_patient=True)
                self.common_channels = self.find_common_channels(healthy_channels, patient_channels)

            # 3. Preprocessing
            print("[3/7] Preprocessing data...")
            X_healthy, y_healthy = self.preprocess_data(
                healthy_data, healthy_labels,
                sfreq=h_sfreq,
                common_channels=self.common_channels)
            X_patients, y_patients = self.preprocess_data(
                patient_data, patient_labels,
                sfreq=p_sfreq,
                common_channels=self.common_channels)

            # 4. Load artifacts
            print("[4/7] Loading preprocessing artifacts...")
            self.load_preprocessing_artifacts(
                config['scaler_path'],
                config['label_encoder_path'])

            # 5. Apply transformations
            print("[5/7] Applying transformations...")
            if self.scaler:
                # Align the column count with what the scaler was fitted on.
                expected_features = self.scaler.n_features_in_
                X_healthy = self.match_features(X_healthy, expected_features)
                X_patients = self.match_features(X_patients, expected_features)
                X_healthy = self.scaler.transform(X_healthy)
                X_patients = self.scaler.transform(X_patients)

            if self.label_encoder:
                y_patients = self.label_encoder.transform(y_patients)

            # 6. Load models
            print("[6/7] Loading models...")
            self.load_models(config['model_paths'])

            # 7. Evaluate models (only on patient data)
            print("[7/7] Evaluating models...")
            self.evaluate_models(X_patients, y_patients, data_type='Patient')

            # Generate visualizations
            print("\n=== Generating Visualizations ===\n")
            try:
                plt.style.use('seaborn-v0_8-whitegrid')
            except Exception:
                # Style name unknown to this matplotlib: use a fallback style.
                # ``Exception`` (not a bare except) keeps Ctrl-C working.
                plt.style.use('ggplot')

            vis_funcs = [
                ('EEG Comparison', self.plot_eeg_comparison, [healthy_data, patient_data, h_sfreq, p_sfreq]),
                ('Patient Response', self.plot_patient_response_categories, []),
                ('Confusion Matrix', self.plot_confusion_matrices, ['Patient']),
                ('Model Performance', self.plot_model_performance_comparison, []),
                ('ROC Curve', self.plot_roc_curves_comparison, [])
            ]

            for name, func, args in vis_funcs:
                try:
                    print(f"Generating {name} visualization...")
                    func(*args)
                    plt.close('all')
                except Exception as e:
                    error_msg = f"Failed to generate {name}: {str(e)}"
                    self._error_messages.append(error_msg)
                    print(error_msg)

            print("\n=== Pipeline Completed ===\n")
            if self._error_messages:
                print("Completed with warnings/errors (see above messages)")
            else:
                print("Completed successfully!")

            results = {
                'healthy_data': (X_healthy, y_healthy),
                'patient_data': (X_patients, y_patients),
                'results': self.results,
                'metadata': {
                    'healthy_samples': len(X_healthy),
                    'patient_samples': len(X_patients),
                    'common_channels': self.common_channels,
                    'features_per_channel': self.expected_features_per_channel,
                    'errors': self._error_messages
                }
            }

        except Exception as e:
            print(f"\n!!! Pipeline Failed !!!\nError: {str(e)}")
            import traceback
            traceback.print_exc()

        finally:
            # Drop the references to the raw recordings before collecting.
            healthy_data = patient_data = None
            gc.collect()

        return results


if __name__ == "__main__":
    # Manual channel-name overrides handed to EEGPipeline.set_channel_mapping.
    channel_mapping = {
        'FP1': 'Fp1',
        'FP2': 'Fp2',
        'F3': 'F3',
        'F4': 'F4',
    }

    # Root folder of the datasets and trained artifacts. The default is the
    # location used so far; set the EEG_DATA_DIR environment variable to run
    # on another machine without editing this cell.
    base_dir = os.environ.get("EEG_DATA_DIR", "F:/shivani/VSCode/ml/worked on dataset")
    models_dir = f"{base_dir}/3(final)"
    artifacts_dir = f"{models_dir}/preprocessing_artifacts"

    config = {
        'healthy_path': f"{models_dir}/Healthy",
        'patient_path': f"{base_dir}/4/dataset/Patients",
        'model_paths': {
            "SVM": f"{models_dir}/svm_model.pkl",
            "RF": f"{models_dir}/rf_model.pkl",
            "XGB": f"{models_dir}/xgb_model.pkl"
        },
        'scaler_path': f"{artifacts_dir}/eeg_scaler.joblib",
        'label_encoder_path': f"{artifacts_dir}/label_encoder.joblib",
        # Only the first few files of each folder are loaded.
        'max_files': 5,
    }

    pipeline = EEGPipeline()
    pipeline.set_channel_mapping(channel_mapping)
    results = pipeline.run_pipeline(config)

IndentationError: expected an indented block after function definition on line 287 (2664835516.py, line 288)

In [16]:
import os
import numpy as np
from tqdm import tqdm
import mne
import gc
from scipy.io import loadmat
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.metrics import (
    accuracy_score, classification_report, confusion_matrix,
    roc_curve, auc, precision_recall_curve, average_precision_score
)
import matplotlib.pyplot as plt
import seaborn as sns
import joblib
from scipy.signal import welch
import pywt
import pandas as pd
import re
from sklearn.model_selection import train_test_split

# Notebook-wide plot defaults: figure size, base font size, seaborn style and palette
plt.rcParams.update({
    'figure.figsize': [12, 8],
    'font.size': 12,
})
sns.set_style("whitegrid")
sns.set_palette("colorblind")

class EEGPipeline:
    def __init__(self):
        self.models = {}
        self.results = {}
        self.scaler = None
        self.label_encoder = None
        self.common_channels = None
        self.feature_names = [
            'Delta Power', 'Theta Power', 'Alpha Power', 'Beta Power', 'Gamma Power',
            'Wavelet Mean 1', 'Wavelet Mean 2', 'Wavelet Mean 3', 'Wavelet Mean 4', 'Wavelet Mean 5',
            'Mean', 'Std Dev', 'Median', 'Skewness', 'Kurtosis', 'Hjorth Mobility', 'Hjorth Complexity'
        ]
        self.band_names = ['Delta', 'Theta', 'Alpha', 'Beta', 'Gamma']
        self.channel_mapping = {}
        self.expected_features_per_channel = 17  # Updated for new features
        self._error_messages = []

    def set_channel_mapping(self, mapping_dict):
        """Set manual channel name mapping between different naming conventions.

        Keys are stored stripped and upper-cased, because
        ``normalize_channel_name`` strips and upper-cases a name before
        looking it up; a key given as ``'fp1'`` or ``' Fp1 '`` would otherwise
        never match. Values are kept exactly as given.

        Args:
            mapping_dict: Mapping of raw channel name to the name to use.
                ``None`` clears the mapping. The dict is copied, so later
                changes to it do not affect the pipeline.
        """
        self.channel_mapping = {
            str(raw_name).strip().upper(): target
            for raw_name, target in (mapping_dict or {}).items()
        }

    def normalize_channel_name(self, channel_name):
        """Map a raw channel label onto the pipeline's canonical spelling.

        Steps, in order: take the first element of a list/array input, strip
        and upper-case the text, return the manual mapping's value if the
        result is one of its keys, otherwise remove non-alphanumeric
        characters, a leading ``CH``, a leading ``EEG`` and leading zeros,
        and finally translate a few alternative electrode names.

        Args:
            channel_name: Label as a string, or a list / numpy array whose
                first element is the label.

        Returns:
            The normalized channel name.
        """
        if isinstance(channel_name, (list, np.ndarray)):
            channel_name = channel_name[0]
        label = str(channel_name).strip().upper()

        # Manual overrides win over every automatic rule.
        overrides = self.channel_mapping
        if label in overrides:
            return overrides[label]

        # Applied one after the other: junk characters, then each prefix.
        for pattern in (r'[^A-Z0-9]', r'^CH', r'^EEG'):
            label = re.sub(pattern, '', label)
        label = label.lstrip('0')

        alternative_names = {
            'FP1': 'Fp1', 'FP2': 'Fp2',
            'T3': 'T7', 'T4': 'T8',
            'T5': 'P7', 'T6': 'P8'
        }
        if label in alternative_names:
            return alternative_names[label]
        return label

    def get_channel_names_from_mat(self, mat_path):
        """Read and normalize the channel labels stored in a MAT file.

        Three layouts are probed, in this order: a ``data`` struct (fields
        ``channels`` or ``chanlocs``), an ``EEG`` struct (fields ``chanlocs``
        or ``chaninfo.labels``), and top-level ``X`` plus ``ch_names``.

        Args:
            mat_path: Path of the MAT file.

        Returns:
            Normalized channel names; an empty list when no layout matches or
            reading fails. Failures are appended to ``self._error_messages``
            and printed.
        """
        try:
            contents = loadmat(mat_path)
            raw_labels = []

            if 'data' in contents:
                record = contents['data'][0][0]
                fields = record.dtype.names
                if 'channels' in fields:
                    raw_labels = [str(entry[0]) for entry in record['channels'][0]]
                elif 'chanlocs' in fields:
                    locations = record['chanlocs'][0]
                    raw_labels = [str(loc['labels'][0]) for loc in locations]

            elif 'EEG' in contents:
                record = contents['EEG'][0][0]
                fields = record.dtype.names
                if 'chanlocs' in fields:
                    locations = record['chanlocs'][0]
                    raw_labels = [str(loc['labels'][0]) for loc in locations]
                elif 'chaninfo' in fields:
                    info = record['chaninfo'][0][0]
                    if 'labels' in info.dtype.names:
                        raw_labels = [str(entry[0]) for entry in info['labels'][0]]

            elif 'X' in contents and 'ch_names' in contents:
                raw_labels = [str(entry[0]) for entry in contents['ch_names'][0]]

            # Blank labels are dropped before normalization.
            normalized = []
            for label in raw_labels:
                if not (label and str(label).strip()):
                    continue
                normalized.append(self.normalize_channel_name(label))
            return normalized
        except Exception as exc:
            failure = f"Error loading {mat_path}: {str(exc)}"
            self._error_messages.append(failure)
            print(failure)
            return []

    def get_channel_names_from_vhdr(self, vhdr_path):
        """Read and normalize the channel names of a BrainVision recording.

        Only the header is needed, so the file is opened with
        ``preload=False``.

        Args:
            vhdr_path: Path of the ``.vhdr`` header file.

        Returns:
            Normalized channel names; an empty list when reading fails. The
            failure is appended to ``self._error_messages`` and printed.
        """
        try:
            recording = mne.io.read_raw_brainvision(vhdr_path, preload=False, verbose=False)
            normalized = []
            for label in recording.ch_names:
                normalized.append(self.normalize_channel_name(label))
            return normalized
        except Exception as exc:
            failure = f"Error loading {vhdr_path}: {str(exc)}"
            self._error_messages.append(failure)
            print(failure)
            return []

    def find_common_channels(self, healthy_channels, patient_channels):
        """Return the sorted channel names shared by both datasets.

        Exact matches are preferred. When there are none, a healthy channel
        is kept if its name contains, or is contained in, any patient channel
        name (substring match); those healthy names are then returned.

        Args:
            healthy_channels: Channel names of the healthy recordings.
            patient_channels: Channel names of the patient recordings.

        Returns:
            Sorted list of matching names; empty when nothing matches.
        """
        healthy_set = set(healthy_channels)
        patient_set = set(patient_channels)

        exact = healthy_set & patient_set
        if exact:
            print(f"Found {len(exact)} common channels")
            return sorted(exact)

        # No exact match: fall back to substring matching in either direction.
        partial = {
            h_name
            for h_name in healthy_set
            if any(h_name in p_name or p_name in h_name for p_name in patient_set)
        }
        if partial:
            print(f"Using partial channel matches: {partial}")
            return sorted(partial)

        print(f"Found {len(exact)} common channels")
        return sorted(exact)

    def collect_channel_names(self, folder_path, is_patient=False):
        """Gather the union of channel names over all recordings in a folder.

        Patient folders are scanned for ``.mat`` files, healthy folders for
        ``.vhdr`` files. Files that yield no channel names are reported and
        skipped.

        Args:
            folder_path: Directory containing the recordings.
            is_patient: Selects the patient (MAT) or healthy (BrainVision)
                file type and reader.

        Returns:
            Sorted list of the distinct channel names found.

        Raises:
            FileNotFoundError: If the folder has no file of the expected type.
            ValueError: If none of the files yielded a channel name.
        """
        extension = '.mat' if is_patient else '.vhdr'
        group = 'patient' if is_patient else 'healthy'

        candidates = [entry for entry in os.listdir(folder_path) if entry.endswith(extension)]
        if len(candidates) == 0:
            raise FileNotFoundError(f"No valid files found in {folder_path}")

        read_channels = self.get_channel_names_from_mat if is_patient else self.get_channel_names_from_vhdr

        collected = set()
        for filename in tqdm(candidates, desc=f"Collecting {group} channels"):
            names = read_channels(os.path.join(folder_path, filename))
            if names:
                collected.update(names)
            else:
                print(f"Warning: No channels found in {filename}")

        if len(collected) == 0:
            raise ValueError(f"No channels collected from {folder_path}")

        return sorted(collected)

    def load_dataset(self, folder_path, is_patient=False, max_files=None, max_duration=None):
        """Load and concatenate every recording of one group from a folder.

        Args:
            folder_path: Directory containing the recordings.
            is_patient: ``True`` loads ``.mat`` files with ``load_mat_data``;
                ``False`` loads ``.vhdr`` files with
                ``load_brainvision_data``.
            max_files: If set, only the first ``max_files`` files in sorted
                name order are loaded.
            max_duration: Forwarded to the per-file loader.

        Returns:
            Tuple ``(data, labels, sfreq)``: the concatenated data and labels
            of all files that loaded, and the mean of their sampling rates.

        Raises:
            FileNotFoundError: If the folder has no file of the expected type.
            ValueError: If no file produced usable data.
        """
        extension = '.mat' if is_patient else '.vhdr'
        # os.listdir returns entries in arbitrary order; sorting makes the
        # max_files subset identical on every run and platform.
        files = sorted(f for f in os.listdir(folder_path) if f.endswith(extension))
        loader_func = self.load_mat_data if is_patient else self.load_brainvision_data

        if not files:
            raise FileNotFoundError(f"No valid files found in {folder_path}")

        if max_files:
            files = files[:max_files]

        all_data = []
        all_labels = []
        sfreqs = []
        loaded_files = 0

        for file in tqdm(files, desc=f"Loading {'patient' if is_patient else 'healthy'} files"):
            file_path = os.path.join(folder_path, file)
            try:
                data, labels, sfreq = loader_func(file_path, max_duration)

                # Loaders signal failure with empty arrays; skip those files.
                if len(data) > 0 and len(labels) > 0:
                    all_data.append(data)
                    all_labels.append(labels)
                    sfreqs.append(sfreq)
                    loaded_files += 1
                else:
                    print(f"Skipping {file} - no valid data")
            except Exception as e:
                error_msg = f"Error loading {file}: {str(e)}"
                self._error_messages.append(error_msg)
                print(error_msg)

            gc.collect()

        print(f"\nSuccessfully loaded {loaded_files}/{len(files)} files")

        if not all_data:
            raise ValueError("No valid data loaded - check file formats")

        distinct_sfreqs = sorted(set(sfreqs))
        if len(distinct_sfreqs) > 1:
            # The mean of different rates matches none of the files; make the
            # situation visible instead of averaging silently.
            print(f"Warning: files have different sampling rates {distinct_sfreqs}; using their mean")

        avg_sfreq = np.mean(sfreqs) if sfreqs else (100 if is_patient else 250)
        return np.concatenate(all_data), np.concatenate(all_labels), avg_sfreq

    def load_brainvision_data(self, vhdr_path, max_duration=None):
        """Load, crop and filter one BrainVision recording.

        The recording is cropped to ``max_duration`` seconds when given,
        band-pass filtered between 0.5 and 45 Hz, and notch filtered at the
        multiples of 50 Hz (up to 250 Hz) that lie below the Nyquist
        frequency of the recording.

        Args:
            vhdr_path: Path of the ``.vhdr`` header file.
            max_duration: Optional upper bound, in seconds, on the amount of
                signal kept.

        Returns:
            Tuple ``(data, labels, sfreq)``. ``data`` is the transposed
            output of ``raw.get_data()`` as float32. ``labels`` holds the
            event codes derived from the annotations, or zeros with one entry
            per row of ``data`` when there are no events. On failure the
            tuple is ``(empty array, empty array, None)`` and the error is
            appended to ``self._error_messages`` and printed.

        NOTE(review): with events present, ``labels`` has one entry per
        event while ``data`` has one row per sample - confirm that callers
        expect this.
        """
        try:
            raw = mne.io.read_raw_brainvision(vhdr_path, preload=True, verbose=False)

            if max_duration:
                crop_end = min(max_duration, raw.times[-1])
                raw.crop(tmax=crop_end)

            # Apply standard preprocessing to match patient data
            raw.filter(0.5, 45, fir_design='firwin', phase='zero-double')

            # A notch can only be placed below the Nyquist frequency, so keep
            # only the 50 Hz harmonics this recording can represent.
            nyquist = raw.info['sfreq'] / 2.0
            notch_freqs = np.arange(50, 251, 50)
            notch_freqs = notch_freqs[notch_freqs < nyquist]
            if notch_freqs.size > 0:
                raw.notch_filter(notch_freqs, filter_length='auto', phase='zero')

            data = raw.get_data().T.astype(np.float32)
            events, _ = mne.events_from_annotations(raw)
            labels = events[:, 2] if len(events) > 0 else np.zeros(len(data))

            return data, labels, raw.info['sfreq']
        except Exception as e:
            error_msg = f"Error loading {vhdr_path}: {str(e)}"
            self._error_messages.append(error_msg)
            print(error_msg)
            return np.array([]), np.array([]), None

    def load_mat_data(self, mat_path, max_duration=None, default_sfreq=100):
        """Load signal, labels and sampling rate from one MAT file.

        Three layouts are probed, in this order: a ``data`` struct with field
        ``X`` (optional ``y`` and ``sfreq``/``Fs``), an ``EEG`` struct
        (``data``, ``event``, ``srate``), and top-level ``X`` (optional
        ``y``). Labels that are missing or hold at most one distinct value
        are replaced by zeros, one per data row. Data and labels are cut to
        their common length, then to ``max_duration`` seconds if given.

        Args:
            mat_path: Path of the MAT file.
            max_duration: Optional upper bound, in seconds, on the amount of
                signal kept.
            default_sfreq: Sampling rate used when the file stores none.

        Returns:
            Tuple ``(data, labels, sfreq)`` with ``data`` as float32. On
            failure the tuple is ``(empty array, empty array, None)`` and the
            error is appended to ``self._error_messages`` and printed.
        """
        try:
            contents = loadmat(mat_path)
            signal = None
            targets = None
            rate = default_sfreq

            if 'data' in contents:
                record = contents['data'][0][0]
                if 'X' in record.dtype.names:
                    signal = record['X']
                    # Collapse everything beyond the first axis into columns.
                    if signal.ndim > 2:
                        signal = signal.reshape(signal.shape[0], -1)
                    if 'y' in record.dtype.names:
                        targets = record['y'].flatten()
                    if 'sfreq' in record.dtype.names:
                        rate = float(record['sfreq'][0][0])
                    elif 'Fs' in record.dtype.names:
                        rate = float(record['Fs'][0][0])

            elif 'EEG' in contents:
                record = contents['EEG'][0][0]
                if 'data' in record.dtype.names:
                    signal = record['data'].T
                if 'event' in record.dtype.names:
                    event_list = record['event'][0]
                    targets = np.array([event[0]['type'][0] for event in event_list])
                if 'srate' in record.dtype.names:
                    rate = float(record['srate'][0][0])

            elif 'X' in contents:
                signal = contents['X']
                if 'y' in contents:
                    targets = contents['y'].flatten()

            # Missing or constant labels carry no class information.
            if targets is None or len(np.unique(targets)) <= 1:
                if signal is not None:
                    targets = np.zeros(len(signal))
                else:
                    targets = np.array([])

            if signal is None:
                raise ValueError("No EEG data found in MAT file")

            # Keep only the part where both data and labels exist.
            shared_len = min(len(signal), len(targets))
            signal, targets = signal[:shared_len], targets[:shared_len]

            if max_duration:
                sample_limit = int(max_duration * rate)
                signal, targets = signal[:sample_limit], targets[:sample_limit]

            return signal.astype(np.float32), targets, rate
        except Exception as exc:
            failure = f"Error loading {mat_path}: {str(exc)}"
            self._error_messages.append(failure)
            print(failure)
            return np.array([]), np.array([]), None

    def hjorth_parameters(self, signal):
        """Compute the Hjorth mobility and complexity of a 1-D signal.

        mobility   = sqrt(var(d1) / var(signal))
        complexity = sqrt(var(d2) / var(d1)) / mobility

        where ``d1`` and ``d2`` are the first and second differences.

        Args:
            signal: 1-D array-like of samples.

        Returns:
            ``(mobility, complexity)``. ``(0.0, 0.0)`` is returned when the
            ratios are undefined: fewer than three samples, a flat signal, or
            a signal with constant slope. NumPy would otherwise yield NaN/inf
            with only a warning, which callers guarding with ``try/except``
            cannot detect.
        """
        signal = np.asarray(signal)
        if signal.size < 3:
            # The second difference would be empty.
            return 0.0, 0.0

        var_zero = np.var(signal)
        var_d1 = np.var(np.diff(signal))

        # `not (x > 0)` also rejects NaN variances coming from NaN samples.
        if not (var_zero > 0 and var_d1 > 0):
            return 0.0, 0.0

        var_d2 = np.var(np.diff(signal, 2))

        mobility = np.sqrt(var_d1 / var_zero)
        complexity = np.sqrt(var_d2 / var_d1) / mobility

        return mobility, complexity

    def preprocess_data(self, data, labels, sfreq=250, common_channels=None):
        """Enhanced preprocessing with detailed debugging"""
        print("\n=== Starting Preprocessing ===")
        print(f"Input data shape: {data.shape if data is not None else 'None'}")
        print(f"Sample rate: {sfreq} Hz")
        
        try:
            # 1. Validate input data
            if data is None or len(data) == 0:
                raise ValueError("Empty data array received")
                
            if len(data.shape) != 2:
                raise ValueError(f"Expected 2D array, got {len(data.shape)}D array")
                
            n_channels = data.shape[1]
            print(f"Number of channels: {n_channels}")
            
            if n_channels == 0:
                raise ValueError("No channels found in data")
    
            # 2. Create RawArray
            ch_names = common_channels[:n_channels] if common_channels else [f'ch{i}' for i in range(n_channels)]
            print(f"First 3 channel names: {ch_names[:3]}")
            
            info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types='eeg')
            print("Created info structure")
            
            try:
                raw = mne.io.RawArray(data.T, info)
                print("Successfully created RawArray")
            except Exception as e:
                raise ValueError(f"Failed to create RawArray: {str(e)}")
    
            # 3. Apply filters
            nyquist = sfreq / 2
            print(f"Nyquist frequency: {nyquist} Hz")
            
            try:
                print("Applying bandpass filter...")
                raw.filter(0.5, min(45, nyquist-1), fir_design='firwin', phase='zero-double')
                
                print("Applying notch filter...")
                notch_freqs = [f for f in np.arange(50, 251, 50) if f < nyquist]
                if notch_freqs:
                    raw.notch_filter(notch_freqs, filter_length='auto', phase='zero')
            except Exception as e:
                raise ValueError(f"Filtering failed: {str(e)}")
    
            # 4. Create epochs
            try:
                print("Creating epochs...")
                events = mne.make_fixed_length_events(raw, duration=1.0)
                print(f"Created {len(events)} events")
                
                epochs = mne.Epochs(raw, events, tmin=0, tmax=1.0, baseline=None, preload=True)
                epochs_data = epochs.get_data()
                print(f"Epochs shape: {epochs_data.shape}")
                
                if len(epochs_data) == 0:
                    raise ValueError("No epochs created - check duration parameters")
            except Exception as e:
                raise ValueError(f"Epoch creation failed: {str(e)}")
    
            # 5. Feature extraction
            print("Extracting features...")
            X, y = self._extract_features(epochs_data, labels[:len(epochs_data)], sfreq, nyquist)
            print(f"Extracted features shape: {X.shape}")
            
            if len(X) == 0:
                raise ValueError("Feature extraction returned empty array")
                
            return X, y
            
        except Exception as e:
            error_msg = f"Preprocessing failed: {str(e)}"
            print(error_msg)
            self._error_messages.append(error_msg)
            return np.array([]), np.array([])

    def _extract_features(self, epochs_data, labels, sfreq, nyquist):
            """Internal feature extraction method"""
            features = []
            n_channels = epochs_data.shape[1]
            
            for epoch in epochs_data:
                epoch_features = []
                for channel in epoch:
                    # Frequency features
                    try:
                        freqs, psd = welch(channel, fs=sfreq, nperseg=min(256, len(channel)))
                        bands = {'delta': (0.5, 4), 'theta': (4, 8), 'alpha': (8, 12),
                                'beta': (12, 30), 'gamma': (30, min(45, nyquist-1))}
                        band_powers = [np.sum(psd[(freqs >= low) & (freqs <= high)]) 
                                     for low, high in bands.values()]
                    except:
                        band_powers = [0.0] * len(self.band_names)
        
                    # Wavelet features
                    try:
                        coeffs = pywt.wavedec(channel, 'db4', level=4)
                        wavelet_features = [np.mean(c) for c in coeffs[:5]]
                        if len(wavelet_features) < 5:
                            wavelet_features += [0.0] * (5 - len(wavelet_features))
                    except:
                        wavelet_features = [0.0] * 5
        
                    # Statistical features
                    stats = [
                        np.mean(channel),
                        np.std(channel),
                        np.median(channel),
                        pd.Series(channel).skew(),
                        pd.Series(channel).kurtosis()
                    ]
                    
                    # Hjorth parameters
                    try:
                        mobility, complexity = self.hjorth_parameters(channel)
                        hjorth_features = [mobility, complexity]
                    except:
                        hjorth_features = [0.0, 0.0]
        
                    epoch_features.extend(band_powers + wavelet_features + stats + hjorth_features)
                
                features.append(epoch_features)
            
            X = np.array(features)
            y = np.array(labels)
            
            # Handle empty case
            if len(X) == 0:
                return np.empty((0, self.expected_features_per_channel * n_channels)), np.array([])
            
            return X, y

    def load_preprocessing_artifacts(self, scaler_path, label_encoder_path):
        """Load preprocessing artifacts with dimension validation"""
        try:
            self.scaler = joblib.load(scaler_path)
            print(f"✓ Scaler loaded successfully (expecting {self.scaler.n_features_in_} features)")
        except Exception as e:
            error_msg = f"✗ Failed to load scaler: {str(e)}"
            self._error_messages.append(error_msg)
            print(error_msg)
            self.scaler = None
            
        try:
            self.label_encoder = joblib.load(label_encoder_path)
            print("✓ Label encoder loaded successfully")
        except Exception as e:
            error_msg = f"✗ Failed to load label encoder: {str(e)}"
            self._error_messages.append(error_msg)
            print(error_msg)
            self.label_encoder = None

    def match_features(self, X, expected_features):
        """Force a feature matrix to have exactly ``expected_features`` columns.

        Too few columns are zero-padded on the right, surplus columns are
        dropped from the right, and a matrix that already fits is returned
        as-is. ``None`` or an empty input yields an empty 1-D array.
        """
        if X is None or len(X) == 0:
            return np.array([])

        missing = expected_features - X.shape[1]
        if missing == 0:
            return X
        if missing > 0:
            # Append `missing` zero columns; rows are left untouched.
            return np.pad(X, ((0, 0), (0, missing)), mode='constant')
        return X[:, :expected_features]

    def load_models(self, model_paths):
        """Load models with comprehensive validation"""
        self.models = {}
        for name, path in model_paths.items():
            try:
                self.models[name] = joblib.load(path)
                print(f"✓ {name} model loaded successfully")
            except Exception as e:
                error_msg = f"✗ Failed to load {name} model: {str(e)}"
                self._error_messages.append(error_msg)
                print(error_msg)
                
        if not self.models:
            raise ValueError("No models were loaded successfully")
        return self.models

    def evaluate_models(self, X_data, y_data, data_type='Patient'):
        """Enhanced evaluation with better ROC curve handling"""
        self.results[data_type] = {}
        for name, model in self.models.items():
            try:
                y_pred = model.predict(X_data)
                y_proba = model.predict_proba(X_data)[:, 1] if hasattr(model, "predict_proba") else None
                
                acc = accuracy_score(y_data, y_pred)
                report = classification_report(y_data, y_pred, output_dict=True, zero_division=0)
                cm = confusion_matrix(y_data, y_pred)
                
                # ROC curve calculation
                roc_auc, fpr, tpr = None, None, None
                if y_proba is not None and len(np.unique(y_data)) == 2:
                    fpr, tpr, _ = roc_curve(y_data, y_proba)
                    roc_auc = auc(fpr, tpr)
                
                # Precision-Recall curve
                pr_auc, precision, recall = None, None, None
                if y_proba is not None and len(np.unique(y_data)) == 2:
                    precision, recall, _ = precision_recall_curve(y_data, y_proba)
                    pr_auc = average_precision_score(y_data, y_proba)
                
                self.results[data_type][name] = {
                    'accuracy': acc,
                    'report': report,
                    'confusion_matrix': cm,
                    'roc_auc': roc_auc,
                    'fpr': fpr,
                    'tpr': tpr,
                    'pr_auc': pr_auc,
                    'precision': precision,
                    'recall': recall,
                    'y_true': y_data,
                    'y_pred': y_pred,
                    'y_proba': y_proba
                }
            except Exception as e:
                error_msg = f"Error evaluating {name} on {data_type} data: {str(e)}"
                self._error_messages.append(error_msg)
                print(error_msg)

    def plot_eeg_comparison(self, healthy_data, patient_data, healthy_sfreq=250, patient_sfreq=100, samples=500, channels=3):
        """Plot the first samples of healthy and patient EEG side by side.

        Args:
            healthy_data, patient_data: 2-D arrays indexed ``[sample, channel]``.
            healthy_sfreq, patient_sfreq: Sample rates in Hz, used only to
                build the time axes.
            samples: Maximum number of leading samples to draw per panel.
            channels: Maximum number of leading channels to draw per panel.

        Both panels share the y-axis and are fixed to +/-150. Values are
        multiplied by 1e6 before plotting.
        NOTE(review): the 1e6 factor presumes the inputs are in volts — confirm
        for data loaded from MAT files.
        """
        # The versioned seaborn style name is missing from some Matplotlib
        # releases; fall back the same way run_pipeline does, so this method
        # does not undo that fallback by raising.
        try:
            plt.style.use('seaborn-v0_8-whitegrid')
        except (OSError, ValueError):
            plt.style.use('ggplot')

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 8), sharey=True)

        panels = [
            (ax1, 'Healthy EEG', healthy_data, healthy_sfreq),
            (ax2, 'Patient EEG', patient_data, patient_sfreq),
        ]
        for ax, title, data, sfreq in panels:
            n_samples = min(samples, len(data))
            time_axis = np.arange(n_samples) / sfreq

            for ch in range(min(channels, data.shape[1])):
                ax.plot(time_axis, data[:n_samples, ch] * 1e6,
                        linewidth=1.5, alpha=0.8, label=f'Ch{ch+1}')

            ax.set_title(title, fontsize=16, pad=20)
            ax.set_xlabel('Time (s)', fontsize=14)
            ax.legend(loc='upper right', fontsize=10)
            ax.grid(True, linestyle=':', alpha=0.5)
            ax.set_ylim(-150, 150)

        # The y-axis is shared, so only the left panel carries the label.
        ax1.set_ylabel('Amplitude (μV)', fontsize=14)

        plt.suptitle('EEG Signal Comparison', fontsize=18, y=1.02)
        plt.tight_layout()
        plt.show()

    def plot_roc_curves(self):
        """Enhanced ROC curve plotting with AUC values"""
        if not self.results or 'Patient' not in self.results:
            print("No patient results available for ROC curves")
            return
            
        plt.figure(figsize=(10, 8))
        
        for model_name, results in self.results['Patient'].items():
            if results.get('roc_auc') is not None:
                plt.plot(results['fpr'], results['tpr'],
                        lw=2, label=f'{model_name} (AUC = {results["roc_auc"]:.2f})')
        
        plt.plot([0, 1], [0, 1], 'k--', lw=1)
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate', fontsize=12)
        plt.ylabel('True Positive Rate', fontsize=12)
        plt.title('Receiver Operating Characteristic', fontsize=16)
        plt.legend(loc="lower right", fontsize=10)
        plt.grid(True, alpha=0.3)
        plt.show()

    def plot_precision_recall_curves(self):
        """Plot precision-recall curves"""
        if not self.results or 'Patient' not in self.results:
            print("No patient results available for PR curves")
            return
            
        plt.figure(figsize=(10, 8))
        
        for model_name, results in self.results['Patient'].items():
            if results.get('pr_auc') is not None:
                plt.plot(results['recall'], results['precision'],
                        lw=2, label=f'{model_name} (AP = {results["pr_auc"]:.2f})')
        
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('Recall', fontsize=12)
        plt.ylabel('Precision', fontsize=12)
        plt.title('Precision-Recall Curve', fontsize=16)
        plt.legend(loc="upper right", fontsize=10)
        plt.grid(True, alpha=0.3)
        plt.show()

    def run_pipeline(self, config):
        """Enhanced pipeline with better error handling"""
        results = None
        try:
            print("\n=== EEG Analysis Pipeline ===\n")
            
            # 1. Load data
            print("[1/7] Loading data...")
            healthy_data, healthy_labels, h_sfreq = self.load_dataset(
                config['healthy_path'],
                max_files=config.get('max_files'),
                max_duration=config.get('max_duration', None))
            patient_data, patient_labels, p_sfreq = self.load_dataset(
                config['patient_path'],
                is_patient=True,
                max_files=config.get('max_files'),
                max_duration=config.get('max_duration', None))
            
            # 2. Channel matching
            print("[2/7] Finding common channels...")
            if not config.get('skip_channel_matching', False):
                healthy_channels = self.collect_channel_names(config['healthy_path'])
                patient_channels = self.collect_channel_names(config['patient_path'], is_patient=True)
                self.common_channels = self.find_common_channels(healthy_channels, patient_channels)
            
            # 3. Preprocessing
            print("[3/7] Preprocessing data...")
            X_healthy, y_healthy = self.preprocess_data(
                healthy_data, healthy_labels,
                sfreq=h_sfreq,
                common_channels=self.common_channels)
            X_patients, y_patients = self.preprocess_data(
                patient_data, patient_labels,
                sfreq=p_sfreq,
                common_channels=self.common_channels)
            
            # 4. Load artifacts
            print("[4/7] Loading preprocessing artifacts...")
            self.load_preprocessing_artifacts(
                config['scaler_path'],
                config['label_encoder_path'])
            
            # 5. Apply transformations
            print("[5/7] Applying transformations...")
            if self.scaler:
                # Validate data before transformation
                if len(X_healthy) == 0 or len(X_patients) == 0:
                    raise ValueError("Empty feature arrays - check preprocessing output")
                
                try:
                    X_healthy = self.scaler.transform(X_healthy)
                    X_patients = self.scaler.transform(X_patients)
                except Exception as e:
                    error_msg = f"Scaling failed: {str(e)}"
                    self._error_messages.append(error_msg)
                    raise ValueError(error_msg)
            
            if self.label_encoder:
                try:
                    y_patients = self.label_encoder.transform(y_patients)
                except Exception as e:
                    error_msg = f"Label encoding failed: {str(e)}"
                    self._error_messages.append(error_msg)
                    raise ValueError(error_msg)
            
            # 6. Load models
            print("[6/7] Loading models...")
            self.load_models(config['model_paths'])
            
            # 7. Evaluate models
            print("[7/7] Evaluating models...")
            self.evaluate_models(X_patients, y_patients, data_type='Patient')
            
            # Generate visualizations
            print("\n=== Generating Visualizations ===\n")
            try:
                plt.style.use('seaborn-v0_8-whitegrid')
            except:
                plt.style.use('ggplot')
            
            vis_funcs = [
                ('EEG Comparison', self.plot_eeg_comparison, [healthy_data, patient_data, h_sfreq, p_sfreq]),
                ('ROC Curves', self.plot_roc_curves, []),
                ('PR Curves', self.plot_precision_recall_curves, [])
            ]
            
            for name, func, args in vis_funcs:
                try:
                    print(f"Generating {name} visualization...")
                    func(*args)
                    plt.close('all')
                except Exception as e:
                    error_msg = f"Failed to generate {name}: {str(e)}"
                    self._error_messages.append(error_msg)
                    print(error_msg)
            
            print("\n=== Pipeline Completed ===\n")
            if self._error_messages:
                print("Completed with warnings/errors (see above messages)")
            else:
                print("Completed successfully!")
            
            results = {
                'healthy_data': (X_healthy, y_healthy),
                'patient_data': (X_patients, y_patients),
                'results': self.results,
                'metadata': {
                    'healthy_samples': len(X_healthy),
                    'patient_samples': len(X_patients),
                    'common_channels': self.common_channels,
                    'features_per_channel': self.expected_features_per_channel,
                    'errors': self._error_messages
                }
            }
            
        except Exception as e:
            print(f"\n!!! Pipeline Failed !!!\nError: {str(e)}")
            import traceback
            traceback.print_exc()
            
        finally:
            if 'healthy_data' in locals():
                del healthy_data
            if 'patient_data' in locals():
                del patient_data
            gc.collect()
            
        return results


if __name__ == "__main__":
    # Manual aliases applied when normalising channel names across datasets.
    channel_mapping = {
        'FP1': 'Fp1',
        'FP2': 'Fp2',
        'F3': 'F3',
        'F4': 'F4',
    }

    # Data locations. The defaults reproduce the original machine-specific
    # paths; set EEG_MODEL_ROOT / EEG_PATIENT_ROOT to run elsewhere without
    # editing this cell.
    model_root = os.environ.get(
        "EEG_MODEL_ROOT", "F:/shivani/VSCode/ml/worked on dataset/3(final)")
    patient_root = os.environ.get(
        "EEG_PATIENT_ROOT", "F:/shivani/VSCode/ml/worked on dataset/4/dataset")
    artifacts_dir = f"{model_root}/preprocessing_artifacts"

    config = {
        'healthy_path': f"{model_root}/Healthy",
        'patient_path': f"{patient_root}/Patients",
        'model_paths': {
            "SVM": f"{model_root}/svm_model.pkl",
            "RF": f"{model_root}/rf_model.pkl",
            "XGB": f"{model_root}/xgb_model.pkl"
        },
        'scaler_path': f"{artifacts_dir}/eeg_scaler.joblib",
        'label_encoder_path': f"{artifacts_dir}/label_encoder.joblib",
        'max_files': 5,
        'skip_channel_matching': True  # Skip dynamic channel matching
    }

    pipeline = EEGPipeline()
    pipeline.set_channel_mapping(channel_mapping)
    results = pipeline.run_pipeline(config)


=== EEG Analysis Pipeline ===

[1/7] Loading data...


Loading healthy files:   0%|                                                                     | 0/5 [00:00<?, ?it/s]

Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 0.5 - 45 Hz

FIR filter parameters
---------------------
Designing a two-pass forward and reverse, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 0.50
- Lower transition bandwidth: 0.50 Hz (-12 dB cutoff frequency: 0.25 Hz)
- Upper passband edge: 45.00 Hz
- Upper transition bandwidth: 11.25 Hz (-12 dB cutoff frequency: 50.62 Hz)
- Filter length: 16501 samples (6.600 s)



[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    5.4s


Filtering raw data in 1 contiguous segment
Setting up band-stop filter

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandstop filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower transition bandwidth: 0.50 Hz
- Upper transition bandwidth: 0.50 Hz
- Filter length: 16501 samples (6.600 s)



[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    2.6s


Used Annotations descriptions: ['Comment/no USB Connection to actiCAP', 'New Segment/', 'Stimulus/S  1', 'Stimulus/S200', 'Stimulus/S210']


Loading healthy files:  20%|████████████▏                                                | 1/5 [00:36<02:24, 36.10s/it]

Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 0.5 - 45 Hz

FIR filter parameters
---------------------
Designing a two-pass forward and reverse, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 0.50
- Lower transition bandwidth: 0.50 Hz (-12 dB cutoff frequency: 0.25 Hz)
- Upper passband edge: 45.00 Hz
- Upper transition bandwidth: 11.25 Hz (-12 dB cutoff frequency: 50.62 Hz)
- Filter length: 16501 samples (6.600 s)



[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    5.0s


Filtering raw data in 1 contiguous segment
Setting up band-stop filter

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandstop filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower transition bandwidth: 0.50 Hz
- Upper transition bandwidth: 0.50 Hz
- Filter length: 16501 samples (6.600 s)



[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    2.6s


Used Annotations descriptions: ['Comment/no USB Connection to actiCAP', 'New Segment/', 'Stimulus/S  1', 'Stimulus/S200', 'Stimulus/S210']


Loading healthy files:  40%|████████████████████████▍                                    | 2/5 [01:11<01:46, 35.54s/it]

Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 0.5 - 45 Hz

FIR filter parameters
---------------------
Designing a two-pass forward and reverse, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 0.50
- Lower transition bandwidth: 0.50 Hz (-12 dB cutoff frequency: 0.25 Hz)
- Upper passband edge: 45.00 Hz
- Upper transition bandwidth: 11.25 Hz (-12 dB cutoff frequency: 50.62 Hz)
- Filter length: 16501 samples (6.600 s)



[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    5.2s


Filtering raw data in 1 contiguous segment
Setting up band-stop filter

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandstop filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower transition bandwidth: 0.50 Hz
- Upper transition bandwidth: 0.50 Hz
- Filter length: 16501 samples (6.600 s)



[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    2.6s


Used Annotations descriptions: ['Comment/no USB Connection to actiCAP', 'New Segment/', 'Stimulus/S  1', 'Stimulus/S200', 'Stimulus/S210']


Loading healthy files:  60%|████████████████████████████████████▌                        | 3/5 [01:46<01:11, 35.52s/it]

Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 0.5 - 45 Hz

FIR filter parameters
---------------------
Designing a two-pass forward and reverse, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 0.50
- Lower transition bandwidth: 0.50 Hz (-12 dB cutoff frequency: 0.25 Hz)
- Upper passband edge: 45.00 Hz
- Upper transition bandwidth: 11.25 Hz (-12 dB cutoff frequency: 50.62 Hz)
- Filter length: 16501 samples (6.600 s)



[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    5.4s


Filtering raw data in 1 contiguous segment
Setting up band-stop filter

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandstop filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower transition bandwidth: 0.50 Hz
- Upper transition bandwidth: 0.50 Hz
- Filter length: 16501 samples (6.600 s)



[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    2.9s


Used Annotations descriptions: ['Comment/no USB Connection to actiCAP', 'New Segment/', 'Stimulus/S  1', 'Stimulus/S200', 'Stimulus/S210']


Loading healthy files:  80%|████████████████████████████████████████████████▊            | 4/5 [02:22<00:35, 35.79s/it]

Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 0.5 - 45 Hz

FIR filter parameters
---------------------
Designing a two-pass forward and reverse, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 0.50
- Lower transition bandwidth: 0.50 Hz (-12 dB cutoff frequency: 0.25 Hz)
- Upper passband edge: 45.00 Hz
- Upper transition bandwidth: 11.25 Hz (-12 dB cutoff frequency: 50.62 Hz)
- Filter length: 16501 samples (6.600 s)



[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    5.3s


Filtering raw data in 1 contiguous segment
Setting up band-stop filter

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandstop filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower transition bandwidth: 0.50 Hz
- Upper transition bandwidth: 0.50 Hz
- Filter length: 16501 samples (6.600 s)



[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    2.6s


Used Annotations descriptions: ['Comment/actiCAP Data On', 'New Segment/', 'Stimulus/S  1', 'Stimulus/S200', 'Stimulus/S210']


Loading healthy files: 100%|█████████████████████████████████████████████████████████████| 5/5 [02:58<00:00, 35.79s/it]



Successfully loaded 5/5 files


Loading patient files: 100%|█████████████████████████████████████████████████████████████| 5/5 [00:02<00:00,  2.11it/s]



Successfully loaded 5/5 files
[2/7] Finding common channels...
[3/7] Preprocessing data...

=== Starting Preprocessing ===
Input data shape: (12704800, 62)
Sample rate: 2500.0 Hz
Number of channels: 62
First 3 channel names: ['ch0', 'ch1', 'ch2']
Created info structure
Creating RawArray with float64 data, n_channels=62, n_times=12704800
    Range : 0 ... 12704799 =      0.000 ...  5081.920 secs
Ready.
Successfully created RawArray
Nyquist frequency: 1250.0 Hz
Applying bandpass filter...
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 0.5 - 45 Hz

FIR filter parameters
---------------------
Designing a two-pass forward and reverse, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 0.50
- Lower transition bandwidth: 0.50 Hz (-12 dB cutoff frequency: 0.25 Hz)
- Upper passband edge: 45.00 Hz
- Upper transition bandwidth: 11.25 H

[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:   29.9s


Applying notch filter...
Filtering raw data in 1 contiguous segment
Setting up band-stop filter

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandstop filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower transition bandwidth: 0.50 Hz
- Upper transition bandwidth: 0.50 Hz
- Filter length: 16501 samples (6.600 s)



[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:   19.6s


Creating epochs...
Created 5081 events
Not setting metadata
5081 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 5081 events and 2501 original time points ...
0 bad epochs dropped
Epochs shape: (5081, 62, 2501)
Extracting features...
Extracted features shape: (5081, 1054)

=== Starting Preprocessing ===
Input data shape: (1738520, 8)
Sample rate: 100.0 Hz
Number of channels: 8
First 3 channel names: ['ch0', 'ch1', 'ch2']
Created info structure
Creating RawArray with float64 data, n_channels=8, n_times=1738520
    Range : 0 ... 1738519 =      0.000 ... 17385.190 secs
Ready.
Successfully created RawArray
Nyquist frequency: 50.0 Hz
Applying bandpass filter...
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 0.5 - 45 Hz

FIR filter parameters
---------------------
Designing a two-pass forward and reverse, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamm



Extracted features shape: (17385, 136)
[4/7] Loading preprocessing artifacts...
✓ Scaler loaded successfully (expecting 992 features)
✓ Label encoder loaded successfully
[5/7] Applying transformations...

!!! Pipeline Failed !!!
Error: Scaling failed: X has 1054 features, but StandardScaler is expecting 992 features as input.


Traceback (most recent call last):
  File "C:\Users\shivani\AppData\Local\Temp\ipykernel_7468\3810426759.py", line 655, in run_pipeline
    X_healthy = self.scaler.transform(X_healthy)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\shivani\AppData\Local\Programs\Python\Python312\Lib\site-packages\sklearn\utils\_set_output.py", line 319, in wrapped
    data_to_wrap = f(self, X, *args, **kwargs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\shivani\AppData\Local\Programs\Python\Python312\Lib\site-packages\sklearn\preprocessing\_data.py", line 1062, in transform
    X = validate_data(
        ^^^^^^^^^^^^^^
  File "C:\Users\shivani\AppData\Local\Programs\Python\Python312\Lib\site-packages\sklearn\utils\validation.py", line 2965, in validate_data
    _check_n_features(_estimator, X, reset=reset)
  File "C:\Users\shivani\AppData\Local\Programs\Python\Python312\Lib\site-packages\sklearn\utils\validation.py", line 2829, in _check_n_features
    raise Value

In [7]:
#2
import os
import numpy as np
from tqdm import tqdm
import mne
import gc
from scipy.io import loadmat
from sklearn.preprocessing import StandardScaler, LabelEncoder, RobustScaler
from sklearn.metrics import (
    accuracy_score, classification_report, confusion_matrix,
    roc_curve, auc, precision_recall_curve, average_precision_score
)
import matplotlib.pyplot as plt
import seaborn as sns
import joblib
from scipy.signal import welch
import pywt
import pandas as pd
from matplotlib.gridspec import GridSpec
import re
from sklearn.model_selection import train_test_split
from matplotlib.ticker import FormatStrFormatter
from sklearn.decomposition import PCA
from imblearn.over_sampling import SMOTE
from sklearn.feature_selection import SelectKBest, f_classif

# Notebook-wide plotting defaults: matplotlib figure/font size first,
# then the seaborn grid style and colour-blind-friendly palette.
plt.rcParams.update({
    'figure.figsize': [12, 8],
    'font.size': 12,
})
sns.set_style("whitegrid")
sns.set_palette("colorblind")

class EEGPipeline:
    def __init__(self):
        self.models = {}
        self.results = {}
        self.scaler = None
        self.label_encoder = None
        self.common_channels = None
        self.feature_names = [
            'Delta Power', 'Theta Power', 'Alpha Power', 'Beta Power', 'Gamma Power',
            'Wavelet Mean 1', 'Wavelet Mean 2', 'Wavelet Mean 3', 'Wavelet Mean 4', 'Wavelet Mean 5',
            'Mean', 'Std Dev', 'Median', 'Skewness', 'Kurtosis', 'Hjorth Mobility', 'Hjorth Complexity'
        ]
        self.band_names = ['Delta', 'Theta', 'Alpha', 'Beta', 'Gamma']
        self.channel_mapping = {}
        self.expected_features_per_channel = 17  # Updated for new features
        self._error_messages = []
        self.pca = None
        self.feature_selector = None

    def set_channel_mapping(self, mapping_dict):
        """Set manual channel name mapping between different naming conventions"""
        self.channel_mapping = mapping_dict

    def normalize_channel_name(self, channel_name):
        """Advanced channel name normalization with manual mapping support"""
        if isinstance(channel_name, (list, np.ndarray)):
            channel_name = channel_name[0]
        channel_name = str(channel_name).strip().upper()
        
        if channel_name in self.channel_mapping:
            return self.channel_mapping[channel_name]
        
        channel_name = re.sub(r'[^A-Z0-9]', '', channel_name)
        channel_name = re.sub(r'^CH', '', channel_name)
        channel_name = re.sub(r'^EEG', '', channel_name)
        channel_name = channel_name.lstrip('0')
        
        variations = {
            'FP1': 'Fp1', 'FP2': 'Fp2',
            'T3': 'T7', 'T4': 'T8',
            'T5': 'P7', 'T6': 'P8'
        }
        return variations.get(channel_name, channel_name)

    def get_channel_names_from_mat(self, mat_path):
        """Robust MAT file channel extraction supporting multiple formats"""
        try:
            mat_data = loadmat(mat_path)
            channels = []
            
            if 'data' in mat_data:
                data_struct = mat_data['data'][0][0]
                if 'channels' in data_struct.dtype.names:
                    channels = [str(ch[0]) for ch in data_struct['channels'][0]]
                elif 'chanlocs' in data_struct.dtype.names:
                    chanlocs = data_struct['chanlocs'][0]
                    channels = [str(chan['labels'][0]) for chan in chanlocs]
            
            elif 'EEG' in mat_data:
                eeg_struct = mat_data['EEG'][0][0]
                if 'chanlocs' in eeg_struct.dtype.names:
                    chanlocs = eeg_struct['chanlocs'][0]
                    channels = [str(chan['labels'][0]) for chan in chanlocs]
                elif 'chaninfo' in eeg_struct.dtype.names:
                    chaninfo = eeg_struct['chaninfo'][0][0]
                    if 'labels' in chaninfo.dtype.names:
                        channels = [str(ch[0]) for ch in chaninfo['labels'][0]]
            
            elif 'X' in mat_data and 'ch_names' in mat_data:
                channels = [str(ch[0]) for ch in mat_data['ch_names'][0]]
                
            return [self.normalize_channel_name(ch) for ch in channels if ch and str(ch).strip()]
        except Exception as e:
            error_msg = f"Error loading {mat_path}: {str(e)}"
            self._error_messages.append(error_msg)
            print(error_msg)
            return []

    def get_channel_names_from_vhdr(self, vhdr_path):
        """Return the normalized channel names of a BrainVision recording.

        The header is opened without preloading the signal. Any failure is
        printed, appended to self._error_messages, and reported as [].
        """
        try:
            recording = mne.io.read_raw_brainvision(vhdr_path, preload=False, verbose=False)
            normalized = list(map(self.normalize_channel_name, recording.ch_names))
        except Exception as e:
            error_msg = f"Error loading {vhdr_path}: {str(e)}"
            self._error_messages.append(error_msg)
            print(error_msg)
            return []
        return normalized

    def find_common_channels(self, healthy_channels, patient_channels):
        """Flexible channel matching with multiple strategies"""
        common = set(healthy_channels).intersection(patient_channels)
        
        if not common:
            healthy_set = set(healthy_channels)
            patient_set = set(patient_channels)
            common = healthy_set.intersection(patient_set)
            
        if not common:
            common_partial = set()
            for h_ch in healthy_set:
                for p_ch in patient_set:
                    if h_ch in p_ch or p_ch in h_ch:
                        common_partial.add(h_ch)
            if common_partial:
                print(f"Using partial channel matches: {common_partial}")
                return sorted(common_partial)
            
        print(f"Found {len(common)} common channels")
        return sorted(common)

    def collect_channel_names(self, folder_path, is_patient=False):
        """Collect channel names with extensive validation"""
        all_channels = set()
        files = [f for f in os.listdir(folder_path) if f.endswith('.vhdr' if not is_patient else '.mat')]
        
        if not files:
            raise FileNotFoundError(f"No valid files found in {folder_path}")
            
        for file in tqdm(files, desc=f"Collecting {'patient' if is_patient else 'healthy'} channels"):
            file_path = os.path.join(folder_path, file)
            channels = self.get_channel_names_from_mat(file_path) if is_patient else self.get_channel_names_from_vhdr(file_path)
            
            if not channels:
                print(f"Warning: No channels found in {file}")
                continue
                
            all_channels.update(channels)
            
        if not all_channels:
            raise ValueError(f"No channels collected from {folder_path}")
            
        return sorted(all_channels)

    def load_dataset(self, folder_path, is_patient=False, max_files=None, max_duration=None):
        """Robust dataset loading with comprehensive validation"""
        files = [f for f in os.listdir(folder_path) if f.endswith('.mat' if is_patient else '.vhdr')]
        loader_func = self.load_mat_data if is_patient else self.load_brainvision_data
            
        if not files:
            raise FileNotFoundError(f"No valid files found in {folder_path}")
            
        if max_files:
            files = files[:max_files]
            
        all_data = []
        all_labels = []
        sfreqs = []
        loaded_files = 0
        
        for file in tqdm(files, desc=f"Loading {'patient' if is_patient else 'healthy'} files"):
            file_path = os.path.join(folder_path, file)
            try:
                data, labels, sfreq = loader_func(file_path, max_duration)
                
                if len(data) > 0 and len(labels) > 0:
                    all_data.append(data)
                    all_labels.append(labels)
                    sfreqs.append(sfreq)
                    loaded_files += 1
                else:
                    print(f"Skipping {file} - no valid data")
            except Exception as e:
                error_msg = f"Error loading {file}: {str(e)}"
                self._error_messages.append(error_msg)
                print(error_msg)
                
            gc.collect()
            
        print(f"\nSuccessfully loaded {loaded_files}/{len(files)} files")
        
        if not all_data:
            raise ValueError("No valid data loaded - check file formats")
            
        avg_sfreq = np.mean(sfreqs) if sfreqs else (100 if is_patient else 250)
        return np.concatenate(all_data), np.concatenate(all_labels), avg_sfreq

    def load_brainvision_data(self, vhdr_path, max_duration=None):
        """Load, crop and filter one BrainVision recording.

        Args:
            vhdr_path: path of the .vhdr header file.
            max_duration: if truthy, keep at most this many seconds.

        Returns:
            tuple: (data as float32 with shape (n_samples, n_channels),
            labels, sampling rate). labels holds the event codes derived from
            the annotations, or zeros with one entry per sample when there are
            no events. On any failure the error is printed, appended to
            self._error_messages, and (empty array, empty array, None) is
            returned.
        """
        try:
            raw = mne.io.read_raw_brainvision(vhdr_path, preload=True, verbose=False)

            if max_duration:
                crop_end = min(max_duration, raw.times[-1])
                raw.crop(tmax=crop_end)

            sfreq = raw.info['sfreq']
            nyquist = sfreq / 2

            # Same filter limits as preprocess_data: the band-pass edge and the
            # notch frequencies must stay below Nyquist, otherwise filter
            # design fails for recordings with a low sampling rate.
            raw.filter(0.5, min(45, nyquist - 1), fir_design='firwin', phase='zero-double')
            notch_freqs = [f for f in np.arange(50, 251, 50) if f < nyquist]
            if notch_freqs:
                raw.notch_filter(notch_freqs, filter_length='auto', phase='zero')

            data = raw.get_data().T.astype(np.float32)
            events, _ = mne.events_from_annotations(raw)
            labels = events[:, 2] if len(events) > 0 else np.zeros(len(data))

            return data, labels, sfreq
        except Exception as e:
            error_msg = f"Error loading {vhdr_path}: {str(e)}"
            self._error_messages.append(error_msg)
            print(error_msg)
            return np.array([]), np.array([]), None

    def load_mat_data(self, mat_path, max_duration=None, default_sfreq=100):
        """Flexible MAT file loader supporting multiple structures"""
        try:
            mat_data = loadmat(mat_path)
            eeg_data = None
            labels = None
            sfreq = default_sfreq
            
            if 'data' in mat_data:
                data_struct = mat_data['data'][0][0]
                if 'X' in data_struct.dtype.names:
                    eeg_data = data_struct['X']
                    if eeg_data.ndim > 2:
                        eeg_data = eeg_data.reshape(eeg_data.shape[0], -1)
                    if 'y' in data_struct.dtype.names:
                        labels = data_struct['y'].flatten()
                    if 'sfreq' in data_struct.dtype.names:
                        sfreq = float(data_struct['sfreq'][0][0])
                    elif 'Fs' in data_struct.dtype.names:
                        sfreq = float(data_struct['Fs'][0][0])
            
            elif 'EEG' in mat_data:
                eeg_struct = mat_data['EEG'][0][0]
                if 'data' in eeg_struct.dtype.names:
                    eeg_data = eeg_struct['data'].T
                if 'event' in eeg_struct.dtype.names:
                    events = eeg_struct['event'][0]
                    labels = np.array([ev[0]['type'][0] for ev in events])
                if 'srate' in eeg_struct.dtype.names:
                    sfreq = float(eeg_struct['srate'][0][0])
            
            elif 'X' in mat_data:
                eeg_data = mat_data['X']
                if 'y' in mat_data:
                    labels = mat_data['y'].flatten()
            
            if labels is None or len(np.unique(labels)) <= 1:
                labels = np.zeros(len(eeg_data)) if eeg_data is not None else np.array([])
            
            if eeg_data is None:
                raise ValueError("No EEG data found in MAT file")
            
            min_len = min(len(eeg_data), len(labels))
            eeg_data = eeg_data[:min_len]
            labels = labels[:min_len]
            
            if max_duration:
                max_samples = int(max_duration * sfreq)
                eeg_data = eeg_data[:max_samples]
                labels = labels[:max_samples]
                
            return eeg_data.astype(np.float32), labels, sfreq
        except Exception as e:
            error_msg = f"Error loading {mat_path}: {str(e)}"
            self._error_messages.append(error_msg)
            print(error_msg)
            return np.array([]), np.array([]), None

    def hjorth_parameters(self, signal):
        """Calculate Hjorth mobility and complexity parameters"""
        first_deriv = np.diff(signal)
        second_deriv = np.diff(signal, 2)
        
        var_zero = np.var(signal)
        var_d1 = np.var(first_deriv)
        var_d2 = np.var(second_deriv)
        
        mobility = np.sqrt(var_d1 / var_zero)
        complexity = np.sqrt(var_d2 / var_d1) / mobility
        
        return mobility, complexity

    def preprocess_data(self, data, labels, sfreq=250, common_channels=None):
        """Enhanced preprocessing with detailed debugging"""
        print("\n=== Starting Preprocessing ===")
        print(f"Input data shape: {data.shape if data is not None else 'None'}")
        print(f"Sample rate: {sfreq} Hz")
        
        try:
            # 1. Validate input data
            if data is None or len(data) == 0:
                raise ValueError("Empty data array received")
                
            if len(data.shape) != 2:
                raise ValueError(f"Expected 2D array, got {len(data.shape)}D array")
                
            n_channels = data.shape[1]
            print(f"Number of channels: {n_channels}")
            
            if n_channels == 0:
                raise ValueError("No channels found in data")
    
            # 2. Create RawArray
            ch_names = common_channels[:n_channels] if common_channels else [f'ch{i}' for i in range(n_channels)]
            print(f"First 3 channel names: {ch_names[:3]}")
            
            info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types='eeg')
            print("Created info structure")
            
            try:
                raw = mne.io.RawArray(data.T, info)
                print("Successfully created RawArray")
            except Exception as e:
                raise ValueError(f"Failed to create RawArray: {str(e)}")
    
            # 3. Apply filters
            nyquist = sfreq / 2
            print(f"Nyquist frequency: {nyquist} Hz")
            
            try:
                print("Applying bandpass filter...")
                raw.filter(0.5, min(45, nyquist-1), fir_design='firwin', phase='zero-double')
                
                print("Applying notch filter...")
                notch_freqs = [f for f in np.arange(50, 251, 50) if f < nyquist]
                if notch_freqs:
                    raw.notch_filter(notch_freqs, filter_length='auto', phase='zero')
            except Exception as e:
                raise ValueError(f"Filtering failed: {str(e)}")
    
            # 4. Create epochs
            try:
                print("Creating epochs...")
                events = mne.make_fixed_length_events(raw, duration=1.0)
                print(f"Created {len(events)} events")
                
                epochs = mne.Epochs(raw, events, tmin=0, tmax=1.0, baseline=None, preload=True)
                epochs_data = epochs.get_data()
                print(f"Epochs shape: {epochs_data.shape}")
                
                if len(epochs_data) == 0:
                    raise ValueError("No epochs created - check duration parameters")
            except Exception as e:
                raise ValueError(f"Epoch creation failed: {str(e)}")
    
            # 5. Feature extraction
            print("Extracting features...")
            X, y = self._extract_features(epochs_data, labels[:len(epochs_data)], sfreq, nyquist)
            print(f"Extracted features shape: {X.shape}")
            
            if len(X) == 0:
                raise ValueError("Feature extraction returned empty array")
                
            return X, y
            
        except Exception as e:
            error_msg = f"Preprocessing failed: {str(e)}"
            print(error_msg)
            self._error_messages.append(error_msg)
            return np.array([]), np.array([])

    def _extract_features(self, epochs_data, labels, sfreq, nyquist):
            """Internal feature extraction method"""
            features = []
            n_channels = epochs_data.shape[1]
            
            for epoch in epochs_data:
                epoch_features = []
                for channel in epoch:
                    # Frequency features
                    try:
                        freqs, psd = welch(channel, fs=sfreq, nperseg=min(256, len(channel)))
                        bands = {'delta': (0.5, 4), 'theta': (4, 8), 'alpha': (8, 12),
                                'beta': (12, 30), 'gamma': (30, min(45, nyquist-1))}
                        band_powers = [np.sum(psd[(freqs >= low) & (freqs <= high)]) 
                                     for low, high in bands.values()]
                    except:
                        band_powers = [0.0] * len(self.band_names)
        
                    # Wavelet features
                    try:
                        coeffs = pywt.wavedec(channel, 'db4', level=4)
                        wavelet_features = [np.mean(c) for c in coeffs[:5]]
                        if len(wavelet_features) < 5:
                            wavelet_features += [0.0] * (5 - len(wavelet_features))
                    except:
                        wavelet_features = [0.0] * 5
        
                    # Statistical features
                    stats = [
                        np.mean(channel),
                        np.std(channel),
                        np.median(channel),
                        pd.Series(channel).skew(),
                        pd.Series(channel).kurtosis()
                    ]
                    
                    # Hjorth parameters
                    try:
                        mobility, complexity = self.hjorth_parameters(channel)
                        hjorth_features = [mobility, complexity]
                    except:
                        hjorth_features = [0.0, 0.0]
        
                    epoch_features.extend(band_powers + wavelet_features + stats + hjorth_features)
                
                features.append(epoch_features)
            
            X = np.array(features)
            y = np.array(labels)
            
            # Handle empty case
            if len(X) == 0:
                return np.empty((0, self.expected_features_per_channel * n_channels)), np.array([])
            
            return X, y

    def load_preprocessing_artifacts(self, scaler_path, label_encoder_path):
        """Load the fitted scaler and label encoder from disk.

        Each artifact is loaded independently. If one fails, the error is
        printed, appended to self._error_messages, and the matching attribute
        is set to None; the other artifact is still attempted.
        """
        # NOTE: joblib.load unpickles its input and can execute arbitrary
        # code; only point these paths at trusted files.

        # --- scaler ---
        try:
            self.scaler = joblib.load(scaler_path)
            # Reading n_features_in_ stays inside the try block: an object
            # without that attribute is treated as a failed load.
            print(f"✓ Scaler loaded successfully (expecting {self.scaler.n_features_in_} features)")
        except Exception as e:
            self.scaler = None
            error_msg = f"✗ Failed to load scaler: {str(e)}"
            self._error_messages.append(error_msg)
            print(error_msg)

        # --- label encoder ---
        try:
            self.label_encoder = joblib.load(label_encoder_path)
            print("✓ Label encoder loaded successfully")
        except Exception as e:
            self.label_encoder = None
            error_msg = f"✗ Failed to load label encoder: {str(e)}"
            self._error_messages.append(error_msg)
            print(error_msg)

    def match_features(self, X, expected_features):
        """Ensure feature matrix matches expected dimensions"""
        if X is None or len(X) == 0:
            return np.array([])
            
        if X.shape[1] == expected_features:
            return X
        elif X.shape[1] < expected_features:
            pad_width = ((0, 0), (0, expected_features - X.shape[1]))
            return np.pad(X, pad_width, mode='constant')
        else:
            return X[:, :expected_features]

    def load_models(self, model_paths):
        """Load models with comprehensive validation"""
        self.models = {}
        for name, path in model_paths.items():
            try:
                self.models[name] = joblib.load(path)
                print(f"✓ {name} model loaded successfully")
            except Exception as e:
                error_msg = f"✗ Failed to load {name} model: {str(e)}"
                self._error_messages.append(error_msg)
                print(error_msg)
                
        if not self.models:
            raise ValueError("No models were loaded successfully")
        return self.models

    def evaluate_models(self, X_data, y_data, data_type='Patient'):
        """Enhanced evaluation with better ROC curve handling"""
        self.results[data_type] = {}
        for name, model in self.models.items():
            try:
                y_pred = model.predict(X_data)
                y_proba = model.predict_proba(X_data)[:, 1] if hasattr(model, "predict_proba") else None
                
                acc = accuracy_score(y_data, y_pred)
                report = classification_report(y_data, y_pred, output_dict=True, zero_division=0)
                cm = confusion_matrix(y_data, y_pred)
                
                # ROC curve calculation
                roc_auc, fpr, tpr = None, None, None
                if y_proba is not None and len(np.unique(y_data)) == 2:
                    fpr, tpr, _ = roc_curve(y_data, y_proba)
                    roc_auc = auc(fpr, tpr)
                
                # Precision-Recall curve
                pr_auc, precision, recall = None, None, None
                if y_proba is not None and len(np.unique(y_data)) == 2:
                    precision, recall, _ = precision_recall_curve(y_data, y_proba)
                    pr_auc = average_precision_score(y_data, y_proba)
                
                self.results[data_type][name] = {
                    'accuracy': acc,
                    'report': report,
                    'confusion_matrix': cm,
                    'roc_auc': roc_auc,
                    'fpr': fpr,
                    'tpr': tpr,
                    'pr_auc': pr_auc,
                    'precision': precision,
                    'recall': recall,
                    'y_true': y_data,
                    'y_pred': y_pred,
                    'y_proba': y_proba
                }
            except Exception as e:
                error_msg = f"Error evaluating {name} on {data_type} data: {str(e)}"
                self._error_messages.append(error_msg)
                print(error_msg)

    def plot_eeg_comparison(self, healthy_data, patient_data, healthy_sfreq=250, patient_sfreq=100, samples=500, channels=3):
        """Plot the first samples of healthy and patient EEG side by side.

        Up to `channels` channels and `samples` samples are drawn per panel,
        multiplied by 1e6 and labelled in microvolts; both panels share a
        fixed y-range of [-150, 150].
        """
        plt.style.use('seaborn-v0_8-whitegrid')
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 8), sharey=True)

        time_h = np.arange(min(samples, len(healthy_data))) / healthy_sfreq
        time_p = np.arange(min(samples, len(patient_data))) / patient_sfreq

        # (axis, signal, time axis, title, draw y-label?) -- only the left
        # panel carries the y-label because the axes share it.
        panels = (
            (ax1, healthy_data, time_h, 'Healthy EEG', True),
            (ax2, patient_data, time_p, 'Patient EEG', False),
        )
        for axis, signal, time_axis, title, with_ylabel in panels:
            n_drawn = min(channels, signal.shape[1])
            for ch in range(n_drawn):
                axis.plot(time_axis, signal[:len(time_axis), ch] * 1e6,
                          linewidth=1.5, alpha=0.8, label=f'Ch{ch+1}')

            axis.set_title(title, fontsize=16, pad=20)
            axis.set_xlabel('Time (s)', fontsize=14)
            if with_ylabel:
                axis.set_ylabel('Amplitude (μV)', fontsize=14)
            axis.legend(loc='upper right', fontsize=10)
            axis.grid(True, linestyle=':', alpha=0.5)
            axis.set_ylim(-150, 150)

        plt.suptitle('EEG Signal Comparison', fontsize=18, y=1.02)
        plt.tight_layout()
        plt.show()

    def plot_roc_curves(self):
        """Enhanced ROC curve plotting with AUC values"""
        if not self.results or 'Patient' not in self.results:
            print("No patient results available for ROC curves")
            return
            
        plt.figure(figsize=(10, 8))
        
        for model_name, results in self.results['Patient'].items():
            if results.get('roc_auc') is not None:
                plt.plot(results['fpr'], results['tpr'],
                        lw=2, label=f'{model_name} (AUC = {results["roc_auc"]:.2f})')
        
        plt.plot([0, 1], [0, 1], 'k--', lw=1)
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate', fontsize=12)
        plt.ylabel('True Positive Rate', fontsize=12)
        plt.title('Receiver Operating Characteristic', fontsize=16)
        plt.legend(loc="lower right", fontsize=10)
        plt.grid(True, alpha=0.3)
        plt.show()

    def plot_precision_recall_curves(self):
        """Plot precision-recall curves"""
        if not self.results or 'Patient' not in self.results:
            print("No patient results available for PR curves")
            return
            
        plt.figure(figsize=(10, 8))
        
        for model_name, results in self.results['Patient'].items():
            if results.get('pr_auc') is not None:
                plt.plot(results['recall'], results['precision'],
                        lw=2, label=f'{model_name} (AP = {results["pr_auc"]:.2f})')
        
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('Recall', fontsize=12)
        plt.ylabel('Precision', fontsize=12)
        plt.title('Precision-Recall Curve', fontsize=16)
        plt.legend(loc="upper right", fontsize=10)
        plt.grid(True, alpha=0.3)
        plt.show()

    def plot_feature_importance(self):
        """Plot feature importance if available"""
        for model_name, model in self.models.items():
            if hasattr(model, 'feature_importances_'):
                importances = model.feature_importances_
                indices = np.argsort(importances)[::-1]
                
                plt.figure(figsize=(12, 6))
                plt.title(f"Feature Importances - {model_name}", fontsize=16)
                plt.bar(range(20), importances[indices[:20]], align="center")
                plt.xticks(range(20), indices[:20], rotation=45)
                plt.xlim([-1, 20])
                plt.ylabel("Importance", fontsize=12)
                plt.xlabel("Feature Index", fontsize=12)
                plt.tight_layout()
                plt.show()

    def run_pipeline(self, config):
        """Enhanced pipeline with better error handling"""
        results = None
        try:
            print("\n=== EEG Analysis Pipeline ===\n")
            print("\n=== Data Validation ===")
            print("Healthy files:")
            healthy_files = [f for f in os.listdir(config['healthy_path']) if f.endswith('.vhdr')]
            print(f"Found {len(healthy_files)} files")
            print("Sample healthy file:", healthy_files[0] if healthy_files else "None")
            
            print("\nPatient files:")
            patient_files = [f for f in os.listdir(config['patient_path']) if f.endswith('.mat')]
            print(f"Found {len(patient_files)} files")
            print("Sample patient file:", patient_files[0] if patient_files else "None")
            # 1. Load data
            print("[1/7] Loading data...")
            healthy_data, healthy_labels, h_sfreq = self.load_dataset(
                config['healthy_path'],
                max_files=config.get('max_files'),
                max_duration=config.get('max_duration', None))
            patient_data, patient_labels, p_sfreq = self.load_dataset(
                config['patient_path'],
                is_patient=True,
                max_files=config.get('max_files'),
                max_duration=config.get('max_duration', None))
            
            # 2. Channel matching
            print("[2/7] Finding common channels...")
            if not config.get('skip_channel_matching', False):
                healthy_channels = self.collect_channel_names(config['healthy_path'])
                patient_channels = self.collect_channel_names(config['patient_path'], is_patient=True)
                self.common_channels = self.find_common_channels(healthy_channels, patient_channels)
            
            # 3. Preprocessing
            print("[3/7] Preprocessing data...")
            X_healthy, y_healthy = self.preprocess_data(
                healthy_data, healthy_labels,
                sfreq=h_sfreq,
                common_channels=self.common_channels)
            X_patients, y_patients = self.preprocess_data(
                patient_data, patient_labels,
                sfreq=p_sfreq,
                common_channels=self.common_channels)
            
            # 4. Load artifacts
            print("[4/7] Loading preprocessing artifacts...")
            self.load_preprocessing_artifacts(
                config['scaler_path'],
                config['label_encoder_path'])
            
            # 5. Apply transformations
            # 5. Apply transformations
            print("[5/7] Applying transformations...")
            if self.scaler:
                # Validate data before transformation
                if len(X_healthy) == 0 or len(X_patients) == 0:
                    raise ValueError("Empty feature arrays - check preprocessing output")
                
                try:
                    X_healthy = self.scaler.transform(X_healthy)
                    X_patients = self.scaler.transform(X_patients)
                except Exception as e:
                    error_msg = f"Scaling failed: {str(e)}"
                    self._error_messages.append(error_msg)
                    raise ValueError(error_msg)
            
            if self.label_encoder:
                try:
                    y_patients = self.label_encoder.transform(y_patients)
                except Exception as e:
                    error_msg = f"Label encoding failed: {str(e)}"
                    self._error_messages.append(error_msg)
                    raise ValueError(error_msg)
            
            # 6. Load models
            print("[6/7] Loading models...")
            self.load_models(config['model_paths'])
            
            # 7. Evaluate models
            print("[7/7] Evaluating models...")
            self.evaluate_models(X_patients, y_patients, data_type='Patient')
            
            # Generate visualizations
            print("\n=== Generating Visualizations ===\n")
            try:
                plt.style.use('seaborn-v0_8-whitegrid')
            except:
                plt.style.use('ggplot')
            
            vis_funcs = [
                ('EEG Comparison', self.plot_eeg_comparison, [healthy_data, patient_data, h_sfreq, p_sfreq]),
                ('Patient Response', self.plot_patient_response_categories, []),
                ('Confusion Matrix', self.plot_confusion_matrices, ['Patient']),
                ('Model Performance', self.plot_model_performance_comparison, []),
                ('ROC Curves', self.plot_roc_curves, []),
                ('PR Curves', self.plot_precision_recall_curves, []),
                ('Feature Importance', self.plot_feature_importance, [])
            ]
            
            for name, func, args in vis_funcs:
                try:
                    print(f"Generating {name} visualization...")
                    func(*args)
                    plt.close('all')
                except Exception as e:
                    error_msg = f"Failed to generate {name}: {str(e)}"
                    self._error_messages.append(error_msg)
                    print(error_msg)
            
            print("\n=== Pipeline Completed ===\n")
            if self._error_messages:
                print("Completed with warnings/errors (see above messages)")
            else:
                print("Completed successfully!")
            
            results = {
                'healthy_data': (X_healthy, y_healthy),
                'patient_data': (X_patients, y_patients),
                'results': self.results,
                'metadata': {
                    'healthy_samples': len(X_healthy),
                    'patient_samples': len(X_patients),
                    'common_channels': self.common_channels,
                    'features_per_channel': self.expected_features_per_channel,
                    'errors': self._error_messages
                }
            }
            
        except Exception as e:
            print(f"\n!!! Pipeline Failed !!!\nError: {str(e)}")
            import traceback
            traceback.print_exc()
            
        finally:
            if 'healthy_data' in locals():
                del healthy_data
            if 'patient_data' in locals():
                del patient_data
            gc.collect()
            
        return results


if __name__ == "__main__":
    # Driver for this cell: build the run configuration, create the pipeline,
    # register the manual channel-name overrides and execute the full run.

    # Manual channel-name overrides handed to EEGPipeline.set_channel_mapping.
    # NOTE(review): keys are upper-case spellings; presumably they are matched
    # against an upper-cased channel name - confirm in normalize_channel_name.
    channel_mapping = {
        'FP1': 'Fp1',
        'FP2': 'Fp2',
        'F3': 'F3',
        'F4': 'F4',
    }

    # Base directories. The defaults are the original workstation locations, so
    # behaviour is unchanged when the environment variables are not set; set
    # EEG_ARTIFACT_DIR / EEG_PATIENT_DIR to run the notebook on another machine
    # without editing this cell. Trailing separators are stripped so the
    # "/"-joined paths below never contain a doubled separator.
    artifact_dir = os.environ.get(
        "EEG_ARTIFACT_DIR", "F:/shivani/VSCode/ml/worked on dataset/3(final)"
    ).rstrip("/\\")
    patient_dir = os.environ.get(
        "EEG_PATIENT_DIR", "F:/shivani/VSCode/ml/worked on dataset/4/dataset/Patients"
    ).rstrip("/\\")
    preprocessing_dir = f"{artifact_dir}/preprocessing_artifacts"

    config = {
        'healthy_path': f"{artifact_dir}/Healthy",
        'patient_path': patient_dir,
        # Model name -> file path; run_pipeline forwards this dict to load_models.
        'model_paths': {
            "SVM": f"{artifact_dir}/svm_model.pkl",
            "RF": f"{artifact_dir}/rf_model.pkl",
            "XGB": f"{artifact_dir}/xgb_model.pkl"
        },
        'scaler_path': f"{preprocessing_dir}/eeg_scaler.joblib",
        'label_encoder_path': f"{preprocessing_dir}/label_encoder.joblib",
        # Cap on the number of files loaded per dataset (the recorded run
        # loaded 5 healthy and 5 patient files with this setting).
        'max_files': 5,
    }

    pipeline = EEGPipeline()
    pipeline.set_channel_mapping(channel_mapping)
    results = pipeline.run_pipeline(config)


=== EEG Analysis Pipeline ===


=== Data Validation ===
Healthy files:
Found 40 files
Sample healthy file: sub-010002 - Copy.vhdr

Patient files:
Found 8 files
Sample patient file: A01.mat
[1/7] Loading data...


Loading healthy files:   0%|                                                                     | 0/5 [00:00<?, ?it/s]

Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 0.5 - 45 Hz

FIR filter parameters
---------------------
Designing a two-pass forward and reverse, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 0.50
- Lower transition bandwidth: 0.50 Hz (-12 dB cutoff frequency: 0.25 Hz)
- Upper passband edge: 45.00 Hz
- Upper transition bandwidth: 11.25 Hz (-12 dB cutoff frequency: 50.62 Hz)
- Filter length: 16501 samples (6.600 s)



[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    5.6s


Filtering raw data in 1 contiguous segment
Setting up band-stop filter

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandstop filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower transition bandwidth: 0.50 Hz
- Upper transition bandwidth: 0.50 Hz
- Filter length: 16501 samples (6.600 s)



[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    3.0s


Used Annotations descriptions: ['Comment/no USB Connection to actiCAP', 'New Segment/', 'Stimulus/S  1', 'Stimulus/S200', 'Stimulus/S210']


Loading healthy files:  20%|████████████▏                                                | 1/5 [00:37<02:30, 37.55s/it]

Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 0.5 - 45 Hz

FIR filter parameters
---------------------
Designing a two-pass forward and reverse, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 0.50
- Lower transition bandwidth: 0.50 Hz (-12 dB cutoff frequency: 0.25 Hz)
- Upper passband edge: 45.00 Hz
- Upper transition bandwidth: 11.25 Hz (-12 dB cutoff frequency: 50.62 Hz)
- Filter length: 16501 samples (6.600 s)



[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    5.6s


Filtering raw data in 1 contiguous segment
Setting up band-stop filter

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandstop filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower transition bandwidth: 0.50 Hz
- Upper transition bandwidth: 0.50 Hz
- Filter length: 16501 samples (6.600 s)



[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    3.0s


Used Annotations descriptions: ['Comment/no USB Connection to actiCAP', 'New Segment/', 'Stimulus/S  1', 'Stimulus/S200', 'Stimulus/S210']


Loading healthy files:  40%|████████████████████████▍                                    | 2/5 [01:14<01:51, 37.29s/it]

Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 0.5 - 45 Hz

FIR filter parameters
---------------------
Designing a two-pass forward and reverse, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 0.50
- Lower transition bandwidth: 0.50 Hz (-12 dB cutoff frequency: 0.25 Hz)
- Upper passband edge: 45.00 Hz
- Upper transition bandwidth: 11.25 Hz (-12 dB cutoff frequency: 50.62 Hz)
- Filter length: 16501 samples (6.600 s)



[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    5.6s


Filtering raw data in 1 contiguous segment
Setting up band-stop filter

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandstop filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower transition bandwidth: 0.50 Hz
- Upper transition bandwidth: 0.50 Hz
- Filter length: 16501 samples (6.600 s)



[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    2.9s


Used Annotations descriptions: ['Comment/no USB Connection to actiCAP', 'New Segment/', 'Stimulus/S  1', 'Stimulus/S200', 'Stimulus/S210']


Loading healthy files:  60%|████████████████████████████████████▌                        | 3/5 [01:51<01:14, 37.17s/it]

Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 0.5 - 45 Hz

FIR filter parameters
---------------------
Designing a two-pass forward and reverse, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 0.50
- Lower transition bandwidth: 0.50 Hz (-12 dB cutoff frequency: 0.25 Hz)
- Upper passband edge: 45.00 Hz
- Upper transition bandwidth: 11.25 Hz (-12 dB cutoff frequency: 50.62 Hz)
- Filter length: 16501 samples (6.600 s)



[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    5.4s


Filtering raw data in 1 contiguous segment
Setting up band-stop filter

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandstop filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower transition bandwidth: 0.50 Hz
- Upper transition bandwidth: 0.50 Hz
- Filter length: 16501 samples (6.600 s)



[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    3.0s


Used Annotations descriptions: ['Comment/no USB Connection to actiCAP', 'New Segment/', 'Stimulus/S  1', 'Stimulus/S200', 'Stimulus/S210']


Loading healthy files:  80%|████████████████████████████████████████████████▊            | 4/5 [02:28<00:37, 37.11s/it]

Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 0.5 - 45 Hz

FIR filter parameters
---------------------
Designing a two-pass forward and reverse, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 0.50
- Lower transition bandwidth: 0.50 Hz (-12 dB cutoff frequency: 0.25 Hz)
- Upper passband edge: 45.00 Hz
- Upper transition bandwidth: 11.25 Hz (-12 dB cutoff frequency: 50.62 Hz)
- Filter length: 16501 samples (6.600 s)



[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    5.5s


Filtering raw data in 1 contiguous segment
Setting up band-stop filter

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandstop filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower transition bandwidth: 0.50 Hz
- Upper transition bandwidth: 0.50 Hz
- Filter length: 16501 samples (6.600 s)



[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    2.8s


Used Annotations descriptions: ['Comment/actiCAP Data On', 'New Segment/', 'Stimulus/S  1', 'Stimulus/S200', 'Stimulus/S210']


Loading healthy files: 100%|█████████████████████████████████████████████████████████████| 5/5 [03:05<00:00, 37.11s/it]



Successfully loaded 5/5 files


Loading patient files: 100%|█████████████████████████████████████████████████████████████| 5/5 [00:02<00:00,  2.50it/s]



Successfully loaded 5/5 files
[2/7] Finding common channels...


  raw = mne.io.read_raw_brainvision(vhdr_path, preload=False, verbose=False)
Collecting healthy channels:  50%|██████████████████████████▌                          | 20/40 [00:01<00:00, 20.86it/s]

Error loading F:/shivani/VSCode/ml/worked on dataset/3(final)/Healthy\sub-010020.vhdr: [Errno 2] No such file or directory: 'F:\\shivani\\VSCode\\ml\\worked on dataset\\3(final)\\Healthy\\Untitled.vmrk'


Collecting healthy channels: 100%|█████████████████████████████████████████████████████| 40/40 [00:02<00:00, 17.47it/s]


Error loading F:/shivani/VSCode/ml/worked on dataset/3(final)/Healthy\sub-010044.vhdr: [Errno 2] No such file or directory: 'F:\\shivani\\VSCode\\ml\\worked on dataset\\3(final)\\Healthy\\sub-010044eeg.vmrk'


Collecting patient channels: 100%|███████████████████████████████████████████████████████| 8/8 [00:02<00:00,  3.60it/s]


Found 8 common channels
[3/7] Preprocessing data...

=== Starting Preprocessing ===
Input data shape: (12704800, 62)
Sample rate: 2500.0 Hz
Number of channels: 62
First 3 channel names: ['CZ', 'FZ', 'OZ']
Created info structure
Preprocessing failed: Failed to create RawArray: len(data) (62) does not match len(info["ch_names"]) (8)

=== Starting Preprocessing ===
Input data shape: (1738520, 8)
Sample rate: 100.0 Hz
Number of channels: 8
First 3 channel names: ['CZ', 'FZ', 'OZ']
Created info structure
Creating RawArray with float64 data, n_channels=8, n_times=1738520
    Range : 0 ... 1738519 =      0.000 ... 17385.190 secs
Ready.
Successfully created RawArray
Nyquist frequency: 50.0 Hz
Applying bandpass filter...
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 0.5 - 45 Hz

FIR filter parameters
---------------------
Designing a two-pass forward and reverse, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window 



Extracted features shape: (17385, 136)
[4/7] Loading preprocessing artifacts...
✓ Scaler loaded successfully (expecting 992 features)
✓ Label encoder loaded successfully
[5/7] Applying transformations...

!!! Pipeline Failed !!!
Error: Empty feature arrays - check preprocessing output


Traceback (most recent call last):
  File "C:\Users\shivani\AppData\Local\Temp\ipykernel_7468\1448636875.py", line 687, in run_pipeline
    raise ValueError("Empty feature arrays - check preprocessing output")
ValueError: Empty feature arrays - check preprocessing output


In [None]:
# import os
# import numpy as np
# from tqdm import tqdm
# import mne
# import gc
# from scipy.io import loadmat
# from sklearn.preprocessing import StandardScaler, LabelEncoder
# from sklearn.metrics import (
#     accuracy_score, classification_report, confusion_matrix,
#     roc_curve, auc, precision_recall_curve, average_precision_score
# )
# import matplotlib.pyplot as plt
# import seaborn as sns
# import joblib
# from scipy.signal import welch
# import pywt
# import pandas as pd
# from matplotlib.gridspec import GridSpec
# import re
# from sklearn.model_selection import train_test_split
# from matplotlib.ticker import FormatStrFormatter
# from keras.models import load_model

# HAS_ENTROPY_PACKAGE = False
# try:
#     from entropy import sample_entropy as entropy_sample_entropy, spectral_entropy as entropy_spectral_entropy
#     HAS_ENTROPY_PACKAGE = True
# except ImportError:
#     print("Entropy package not found - using custom implementations")
#     HAS_ENTROPY_PACKAGE = False
    
# # Configure plotting
# plt.rcParams['figure.figsize'] = [12, 8]
# plt.rcParams['font.size'] = 12
# sns.set_style("whitegrid")
# sns.set_palette("colorblind")

# class EEGPipeline:
#     def __init__(self):
#         self.models = {}
#         self.results = {}
#         self.scaler = None
#         self.label_encoder = None
#         self.common_channels = None
#         self.feature_names = [
#             'Delta Power', 'Theta Power', 'Alpha Power', 'Beta Power', 'Gamma Power',
#             'Wavelet Mean 1', 'Wavelet Mean 2', 'Wavelet Mean 3', 'Wavelet Mean 4', 'Wavelet Mean 5',
#             'Mean', 'Std Dev', 'Median', 'Sample Entropy', 'Spectral Entropy'
#         ]
#         self.band_names = ['Delta', 'Theta', 'Alpha', 'Beta', 'Gamma']
#         self.channel_mapping = {}
#         self.expected_features_per_channel = 15  # Updated to include entropy features
#         self._error_messages = []  # Track errors during pipeline execution

#     def custom_sample_entropy(self, signal, m=2, r_factor=0.2):
#         """Custom implementation of sample entropy"""
#         n = len(signal)
#         r = r_factor * np.std(signal)
        
#         def _maxdist(x, y):
#             return np.max(np.abs(x - y))
            
#         def _phi(m):
#             x = np.array([signal[i:i+m] for i in range(n - m + 1)])
#             C = np.zeros(len(x))
#             for i in range(len(x)):
#                 for j in range(len(x)):
#                     if i != j and _maxdist(x[i], x[j]) <= r:
#                         C[i] += 1
#             return np.sum(C) / (len(x) * (len(x) - 1))
            
#         if n == 0 or m > n:
#             return 0.0
            
#         return -np.log(_phi(m+1) / _phi(m)) if _phi(m) != 0 else 0.0

#     def custom_spectral_entropy(self, signal, sfreq, bands=None):
#         """Custom implementation of spectral entropy"""
#         if bands is None:
#             bands = {
#                 'delta': (0.5, 4),
#                 'theta': (4, 8),
#                 'alpha': (8, 12),
#                 'beta': (12, 30),
#                 'gamma': (30, 50)
#             }
        
#         freqs, psd = welch(signal, fs=sfreq, nperseg=min(256, len(signal)))
#         total_power = np.sum(psd)
        
#         if total_power == 0:
#             return 0.0
            
#         prob = psd / total_power
#         prob = prob[prob > 0]  # Avoid log(0)
        
#         # Calculate entropy for each band
#         band_entropies = []
#         for low, high in bands.values():
#             band_mask = (freqs >= low) & (freqs <= high)
#             band_prob = prob[band_mask]
#             if len(band_prob) > 0:
#                 band_entropy = -np.sum(band_prob * np.log(band_prob))
#                 band_entropies.append(band_entropy)
        
#         return np.sum(band_entropies) if band_entropies else 0.0

#     def calculate_sample_entropy(self, signal, m=2, r_factor=0.2):
#         """Calculate sample entropy using available implementation"""
#         if HAS_ENTROPY_PACKAGE:
#             return entropy_sample_entropy(signal, m=m, r=r_factor*np.std(signal))
#         return self.custom_sample_entropy(signal, m=m, r_factor=r_factor)

#     def calculate_spectral_entropy(self, signal, sfreq):
#         """Calculate spectral entropy using available implementation"""
#         bands = {'delta': (0.5, 4), 'theta': (4, 8), 'alpha': (8, 12),
#                 'beta': (12, 30), 'gamma': (30, 50)}
#         if HAS_ENTROPY_PACKAGE:
#             return entropy_spectral_entropy(signal, sfreq, bands)
#         return self.custom_spectral_entropy(signal, sfreq, bands)

#     def set_channel_mapping(self, mapping_dict):
#         """Set manual channel name mapping between different naming conventions"""
#         self.channel_mapping = mapping_dict

#     def normalize_channel_name(self, channel_name):
#         """Advanced channel name normalization with manual mapping support"""
#         if isinstance(channel_name, (list, np.ndarray)):
#             channel_name = channel_name[0]
#         channel_name = str(channel_name).strip().upper()
#         if channel_name in self.channel_mapping:
#             return self.channel_mapping[channel_name]
#         channel_name = re.sub(r'[^A-Z0-9]', '', channel_name)
#         channel_name = re.sub(r'^CH', '', channel_name)
#         channel_name = re.sub(r'^EEG', '', channel_name)
#         channel_name = channel_name.lstrip('0')
#         # Handle common variations
#         variations = {
#             'FP1': 'Fp1', 'FP2': 'Fp2',
#             'T3': 'T7', 'T4': 'T8',
#             'T5': 'P7', 'T6': 'P8'
#         }
#         return variations.get(channel_name, channel_name)

#     def get_channel_names_from_mat(self, mat_path):
#         """Robust MAT file channel extraction supporting multiple formats"""
#         try:
#             mat_data = loadmat(mat_path)
#             channels = []
#             # Structure 1: Nested 'data' structure
#             if 'data' in mat_data:
#                 data_struct = mat_data['data'][0][0]
#                 if 'channels' in data_struct.dtype.names:
#                     channels = [str(ch[0]) for ch in data_struct['channels'][0]]
#                 elif 'chanlocs' in data_struct.dtype.names:
#                     chanlocs = data_struct['chanlocs'][0]
#                     channels = [str(chan['labels'][0]) for chan in chanlocs]
#             # Structure 2: EEGLAB structure
#             elif 'EEG' in mat_data:
#                 eeg_struct = mat_data['EEG'][0][0]
#                 if 'chanlocs' in eeg_struct.dtype.names:
#                     chanlocs = eeg_struct['chanlocs'][0]
#                     channels = [str(chan['labels'][0]) for chan in chanlocs]
#                 elif 'chaninfo' in eeg_struct.dtype.names:
#                     chaninfo = eeg_struct['chaninfo'][0][0]
#                     if 'labels' in chaninfo.dtype.names:
#                         channels = [str(ch[0]) for ch in chaninfo['labels'][0]]
#             # Structure 3: Simple X,y structure
#             elif 'X' in mat_data and 'ch_names' in mat_data:
#                 channels = [str(ch[0]) for ch in mat_data['ch_names'][0]]
#             return [self.normalize_channel_name(ch) for ch in channels if ch and str(ch).strip()]
#         except Exception as e:
#             error_msg = f"Error loading {mat_path}: {str(e)}"
#             self._error_messages.append(error_msg)
#             print(error_msg)
#             return []

#     def get_channel_names_from_vhdr(self, vhdr_path):
#         """Extract channel names from BrainVision files with validation"""
#         try:
#             raw = mne.io.read_raw_brainvision(vhdr_path, preload=False, verbose=False)
#             return [self.normalize_channel_name(ch) for ch in raw.ch_names]
#         except Exception as e:
#             error_msg = f"Error loading {vhdr_path}: {str(e)}"
#             self._error_messages.append(error_msg)
#             print(error_msg)
#             return []

#     def find_common_channels(self, healthy_channels, patient_channels):
#         """Flexible channel matching with multiple strategies"""
#         # First try exact matching
#         common = set(healthy_channels).intersection(patient_channels)
#         if not common:
#             healthy_set = set(healthy_channels)
#             patient_set = set(patient_channels)
#             common = healthy_set.intersection(patient_set)
#         if not common:
#             common_partial = set()
#             for h_ch in healthy_set:
#                 for p_ch in patient_set:
#                     if h_ch in p_ch or p_ch in h_ch:
#                         common_partial.add(h_ch)
#             if common_partial:
#                 print(f"Using partial channel matches: {common_partial}")
#                 return sorted(common_partial)
#         print(f"Found {len(common)} common channels")
#         return sorted(common)

#     def collect_channel_names(self, folder_path, is_patient=False):
#         """Collect channel names with extensive validation"""
#         all_channels = set()
#         files = [f for f in os.listdir(folder_path) if f.endswith('.vhdr' if not is_patient else '.mat')]
#         if not files:
#             raise FileNotFoundError(f"No valid files found in {folder_path}")
#         for file in tqdm(files, desc=f"Collecting {'patient' if is_patient else 'healthy'} channels"):
#             file_path = os.path.join(folder_path, file)
#             channels = self.get_channel_names_from_mat(file_path) if is_patient else self.get_channel_names_from_vhdr(file_path)
#             if not channels:
#                 print(f"Warning: No channels found in {file}")
#                 continue
#             all_channels.update(channels)
#         if not all_channels:
#             raise ValueError(f"No channels collected from {folder_path}")
#         return sorted(all_channels)

#     def load_dataset(self, folder_path, is_patient=False, max_files=None, max_duration=None):
#         """Robust dataset loading with comprehensive validation"""
#         files = [f for f in os.listdir(folder_path) if f.endswith('.mat' if is_patient else '.vhdr')]
#         loader_func = self.load_mat_data if is_patient else self.load_brainvision_data
#         if not files:
#             raise FileNotFoundError(f"No valid files found in {folder_path}")
#         if max_files:
#             files = files[:max_files]
#         all_data = []
#         all_labels = []
#         sfreqs = []
#         loaded_files = 0
#         for file in tqdm(files, desc=f"Loading {'patient' if is_patient else 'healthy'} files"):
#             file_path = os.path.join(folder_path, file)
#             try:
#                 data, labels, sfreq = loader_func(file_path, max_duration)
#                 if len(data) > 0 and len(labels) > 0:
#                     all_data.append(data)
#                     all_labels.append(labels)
#                     sfreqs.append(sfreq)
#                     loaded_files += 1
#                 else:
#                     print(f"Skipping {file} - no valid data")
#             except Exception as e:
#                 error_msg = f"Error loading {file}: {str(e)}"
#                 self._error_messages.append(error_msg)
#                 print(error_msg)
#             gc.collect()
#         print(f"\nSuccessfully loaded {loaded_files}/{len(files)} files")
#         if not all_data:
#             raise ValueError("No valid data loaded - check file formats")
#         avg_sfreq = np.mean(sfreqs) if sfreqs else (100 if is_patient else 250)
#         return np.concatenate(all_data), np.concatenate(all_labels), avg_sfreq

#     def load_brainvision_data(self, vhdr_path, max_duration=None):
#         """Load BrainVision data with enhanced validation"""
#         try:
#             raw = mne.io.read_raw_brainvision(vhdr_path, preload=True, verbose=False)
#             if max_duration:
#                 crop_end = min(max_duration, raw.times[-1])
#                 raw.crop(tmax=crop_end)
#             data = raw.get_data().T.astype(np.float32)
#             events, _ = mne.events_from_annotations(raw)
#             labels = events[:, 2] if len(events) > 0 else np.zeros(len(data))
#             return data, labels, raw.info['sfreq']
#         except Exception as e:
#             error_msg = f"Error loading {vhdr_path}: {str(e)}"
#             self._error_messages.append(error_msg)
#             print(error_msg)
#             return np.array([]), np.array([]), None

#     def load_mat_data(self, mat_path, max_duration=None, default_sfreq=100):
#         """Flexible MAT file loader supporting multiple structures"""
#         try:
#             mat_data = loadmat(mat_path)
#             eeg_data = None
#             labels = None
#             sfreq = default_sfreq
#             if 'data' in mat_data:
#                 data_struct = mat_data['data'][0][0]
#                 if 'X' in data_struct.dtype.names:
#                     eeg_data = data_struct['X']
#                     if eeg_data.ndim > 2:
#                         eeg_data = eeg_data.reshape(eeg_data.shape[0], -1)
#                     if 'y' in data_struct.dtype.names:
#                         labels = data_struct['y'].flatten()
#                     if 'sfreq' in data_struct.dtype.names:
#                         sfreq = float(data_struct['sfreq'][0][0])
#                     elif 'Fs' in data_struct.dtype.names:
#                         sfreq = float(data_struct['Fs'][0][0])
#             elif 'EEG' in mat_data:
#                 eeg_struct = mat_data['EEG'][0][0]
#                 if 'data' in eeg_struct.dtype.names:
#                     eeg_data = eeg_struct['data'].T
#                 if 'event' in eeg_struct.dtype.names:
#                     events = eeg_struct['event'][0]
#                     labels = np.array([ev[0]['type'][0] for ev in events])
#                 if 'srate' in eeg_struct.dtype.names:
#                     sfreq = float(eeg_struct['srate'][0][0])
#             elif 'X' in mat_data:
#                 eeg_data = mat_data['X']
#                 if 'y' in mat_data:
#                     labels = mat_data['y'].flatten()
#             if labels is None or len(np.unique(labels)) <= 1:
#                 labels = np.zeros(len(eeg_data)) if eeg_data is not None else np.array([])
#             if eeg_data is None:
#                 raise ValueError("No EEG data found in MAT file")
#             min_len = min(len(eeg_data), len(labels))
#             eeg_data = eeg_data[:min_len]
#             labels = labels[:min_len]
#             if max_duration:
#                 max_samples = int(max_duration * sfreq)
#                 eeg_data = eeg_data[:max_samples]
#                 labels = labels[:max_samples]
#             return eeg_data.astype(np.float32), labels, sfreq
#         except Exception as e:
#             error_msg = f"Error loading {mat_path}: {str(e)}"
#             self._error_messages.append(error_msg)
#             print(error_msg)
#             return np.array([]), np.array([]), None

#     def preprocess_data(self, data, labels, sfreq=250, common_channels=None):
#         """Robust preprocessing with fixed feature dimensions including entropy features"""
#         try:
#             n_channels = data.shape[1]
#             ch_names = common_channels[:n_channels] if common_channels and len(common_channels) >= n_channels else [f'ch{i}' for i in range(n_channels)]
#             info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types='eeg')
#             raw = mne.io.RawArray(data.T, info)
            
#             # Apply bandpass filter (0.5 - 40 Hz)
#             raw.filter(0.5, 40, fir_design='firwin', phase='zero-double')
            
#             # Apply notch filter only for frequencies below Nyquist (sfreq/2)
#             notch_freqs = [f for f in [50, 60] if f < (sfreq / 2)]
#             if notch_freqs:  # Only apply if there are valid frequencies
#                 raw.notch_filter(notch_freqs)
            
#             # Create epochs
#             events = mne.make_fixed_length_events(raw, duration=1.0)
#             epochs = mne.Epochs(raw, events, tmin=0, tmax=1.0, baseline=None, preload=True)
#             epochs_data = epochs.get_data()

#             def extract_features(epoch_data):
#                 features = []
#                 for epoch in epoch_data:
#                     epoch_features = []
#                     for channel in epoch:
#                         # Traditional features
#                         freqs, psd = welch(channel, fs=sfreq, nperseg=min(256, len(channel)))
#                         bands = {'delta': (0.5, 4), 'theta': (4, 8), 'alpha': (8, 12),
#                                 'beta': (12, 30), 'gamma': (30, 50)}
#                         band_powers = [np.sum(psd[(freqs >= low) & (freqs <= high)]) 
#                                      for low, high in bands.values()]
                        
#                         # Wavelet features with adjusted level to avoid boundary effects
#                         max_level = pywt.dwt_max_level(len(channel), 'db4')
#                         level = min(4, max_level) if max_level is not None else 4
#                         coeffs = pywt.wavedec(channel, 'db4', level=level)
#                         wavelet_features = [np.mean(c) for c in coeffs[:5]]
#                         if len(wavelet_features) < 5:
#                             wavelet_features += [0.0] * (5 - len(wavelet_features))
                        
#                         # Statistical features
#                         stats = [np.mean(channel), np.std(channel), np.median(channel)]
                        
#                         # Entropy features
#                         samp_entropy = self.calculate_sample_entropy(channel)
#                         spec_entropy = self.calculate_spectral_entropy(channel, sfreq)
                        
#                         # Combine all features
#                         epoch_features.extend(band_powers + wavelet_features + stats + [samp_entropy, spec_entropy])
#                     features.append(epoch_features)
#                 return np.array(features)

#             X = extract_features(epochs_data)
#             y = labels[:len(X)]
#             expected_features = n_channels * self.expected_features_per_channel
#             if X.shape[1] != expected_features:
#                 if X.shape[1] < expected_features:
#                     pad_width = ((0, 0), (0, expected_features - X.shape[1]))
#                     X = np.pad(X, pad_width, mode='constant')
#                 else:
#                     X = X[:, :expected_features]
#             return X, y
#         except Exception as e:
#             error_msg = f"Error during preprocessing: {e}"
#             self._error_messages.append(error_msg)
#             print(error_msg)
#             return None, None

#     def load_preprocessing_artifacts(self, scaler_path, label_encoder_path):
#         """Load preprocessing artifacts with dimension validation"""
#         try:
#             self.scaler = joblib.load(scaler_path)
#             print(f"✓ Scaler loaded successfully (expecting {self.scaler.n_features_in_} features)")
#         except Exception as e:
#             error_msg = f"✗ Failed to load scaler: {str(e)}"
#             self._error_messages.append(error_msg)
#             print(error_msg)
#             self.scaler = None
#         try:
#             self.label_encoder = joblib.load(label_encoder_path)
#             print("✓ Label encoder loaded successfully")
#         except Exception as e:
#             error_msg = f"✗ Failed to load label encoder: {str(e)}"
#             self._error_messages.append(error_msg)
#             print(error_msg)
#             self.label_encoder = None

#     def match_features(self, X, expected_features):
#         """Ensure feature matrix matches expected dimensions"""
#         if X.shape[1] == expected_features:
#             return X
#         elif X.shape[1] < expected_features:
#             pad_width = ((0, 0), (0, expected_features - X.shape[1]))
#             return np.pad(X, pad_width, mode='constant')
#         else:
#             return X[:, :expected_features]

#     def load_models(self, model_paths):
#         """Load models with comprehensive validation including Keras models"""
#         self.models = {}
#         for name, path in model_paths.items():
#             try:
#                 if path.endswith('.h5') or path.endswith('.keras'):
#                     self.models[name] = load_model(path)
#                     print(f"✓ {name} (Keras) model loaded successfully")
#                 else:
#                     self.models[name] = joblib.load(path)
#                     print(f"✓ {name} model loaded successfully")
#             except Exception as e:
#                 error_msg = f"✗ Failed to load {name} model: {str(e)}"
#                 self._error_messages.append(error_msg)
#                 print(error_msg)
#         if not self.models:
#             raise ValueError("No models were loaded successfully")
#         return self.models

#     def evaluate_models(self, X_data, y_data, data_type='Patient'):
#         """Updated evaluation method with multiclass support and Keras model handling"""
#         self.results[data_type] = {}
#         for name, model in self.models.items():
#             try:
#                 # Handle Keras models differently
#                 if hasattr(model, 'predict'):
#                     # For Keras models, we need to reshape the data if needed
#                     if len(X_data.shape) == 2:  # If it's 2D, reshape for LSTM/CNN
#                         if 'LSTM' in name or 'CNN' in name:
#                             X_reshaped = X_data.reshape(X_data.shape[0], X_data.shape[1], 1)
#                             y_pred = model.predict(X_reshaped)
#                             if y_pred.shape[1] > 1:  # Multiclass
#                                 y_pred = np.argmax(y_pred, axis=1)
#                             else:  # Binary
#                                 y_pred = (y_pred > 0.5).astype(int)
#                         else:
#                             y_pred = model.predict(X_data)
#                             if y_pred.shape[1] > 1:
#                                 y_pred = np.argmax(y_pred, axis=1)
#                             else:
#                                 y_pred = (y_pred > 0.5).astype(int)
#                     else:
#                         y_pred = model.predict(X_data)
#                         if y_pred.shape[1] > 1:
#                             y_pred = np.argmax(y_pred, axis=1)
#                         else:
#                             y_pred = (y_pred > 0.5).astype(int)
#                 else:
#                     y_pred = model.predict(X_data)
                
#                 if len(np.unique(y_data)) > 2:
#                     acc = accuracy_score(y_data, y_pred)
#                     report = classification_report(y_data, y_pred, output_dict=True, zero_division=0)
#                     cm = confusion_matrix(y_data, y_pred)
#                     roc_auc = None
#                 else:
#                     if hasattr(model, 'predict_proba'):
#                         y_proba = model.predict_proba(X_data)[:, 1]
#                     elif hasattr(model, 'predict'):
#                         if len(X_data.shape) == 2 and ('LSTM' in name or 'CNN' in name):
#                             X_reshaped = X_data.reshape(X_data.shape[0], X_data.shape[1], 1)
#                             y_proba = model.predict(X_reshaped).flatten()
#                         else:
#                             y_proba = model.predict(X_data).flatten()
#                     else:
#                         y_proba = None
                    
#                     acc = accuracy_score(y_data, y_pred)
#                     report = classification_report(y_data, y_pred, output_dict=True, zero_division=0)
#                     cm = confusion_matrix(y_data, y_pred)
#                     if y_proba is not None:
#                         fpr, tpr, _ = roc_curve(y_data, y_proba)
#                         roc_auc = auc(fpr, tpr)
#                     else:
#                         roc_auc = None
                
#                 self.results[data_type][name] = {
#                     'accuracy': acc,
#                     'report': report,
#                     'confusion_matrix': cm,
#                     'roc_auc': roc_auc,
#                     'fpr': fpr if 'fpr' in locals() else None,
#                     'tpr': tpr if 'tpr' in locals() else None,
#                     'y_true': y_data,
#                     'y_pred': y_pred
#                 }
#             except Exception as e:
#                 error_msg = f"Error evaluating {name} on {data_type} data: {str(e)}"
#                 self._error_messages.append(error_msg)
#                 print(error_msg)

#     def plot_eeg_comparison(self, healthy_data, patient_data, healthy_sfreq=250, patient_sfreq=100, samples=500, channels=3):
#         """Horizontal EEG signal comparison plot with side-by-side channels"""
#         plt.style.use('seaborn-v0_8-whitegrid')
        
#         # Create figure with subplots for each channel
#         fig, axes = plt.subplots(channels, 2, figsize=(20, channels*3), sharey=True)
        
#         # Time vectors
#         time_h = np.arange(min(samples, len(healthy_data))) / healthy_sfreq
#         time_p = np.arange(min(samples, len(patient_data))) / patient_sfreq
        
#         # Plot each channel horizontally
#         for ch in range(min(channels, healthy_data.shape[1], patient_data.shape[1])):
#             # Healthy EEG
#             axes[ch, 0].plot(time_h, healthy_data[:len(time_h), ch] * 1e6,
#                             linewidth=1.5, alpha=0.8, color='blue')
#             axes[ch, 0].set_title(f'Healthy - Channel {ch+1}', fontsize=12)
#             axes[ch, 0].grid(True, linestyle='--', alpha=0.6)
#             axes[ch, 0].set_ylim(-100, 100)
            
#             # Patient EEG
#             axes[ch, 1].plot(time_p, patient_data[:len(time_p), ch] * 1e6,
#                             linewidth=1.5, alpha=0.8, color='red')
#             axes[ch, 1].set_title(f'Patient - Channel {ch+1}', fontsize=12)
#             axes[ch, 1].grid(True, linestyle='--', alpha=0.6)
#             axes[ch, 1].set_ylim(-100, 100)
        
#         # Set common labels
#         for ax in axes[:, 0]:
#             ax.set_ylabel('Amplitude (μV)', fontsize=12)
        
#         for ax in axes[-1, :]:
#             ax.set_xlabel('Time (seconds)', fontsize=12)
        
#         plt.suptitle('EEG Signal Comparison (First 500 Samples)', fontsize=16, y=1.02)
#         plt.tight_layout()
#         plt.show()

#     def plot_patient_response_categories(self):
#         """Enhanced patient response categorization visualization with robust error handling"""
#         if not self.results or 'Patient' not in self.results:
#             print("No patient results available for visualization")
#             return
#         try:
#             # Use the first available model's results
#             model_name = next(iter(self.results['Patient']))
#             results = self.results['Patient'][model_name]
#             if 'y_true' not in results or 'y_pred' not in results:
#                 print("Missing required data in results")
#                 return
#             y_true = results['y_true']
#             y_pred = results['y_pred']
#             # Calculate accuracy for each sample
#             if len(np.unique(y_true)) == 2:
#                 accuracies = (y_pred == y_true).astype(float)
#             else:
#                 accuracies = np.array([1.0 if pred == true else 0.0 
#                                      for pred, true in zip(y_pred, y_true)])
#             # Categorize patients with safe division
#             categories = []
#             for acc in accuracies:
#                 if acc >= 0.7:
#                     categories.append("Good Response")
#                 elif 0.4 <= acc < 0.7:
#                     categories.append("Medium Response")
#                 else:
#                     categories.append("Poor Response")
#             category_counts = pd.Series(categories).value_counts()
#             # Create figure with proper layout
#             fig = plt.figure(figsize=(18, 8))
#             gs = GridSpec(1, 2, width_ratios=[1, 1.5])
#             # Subplot 1: Pie chart with safe explode parameter
#             ax1 = fig.add_subplot(gs[0])
#             colors = ['#4CAF50', '#FFC107', '#F44336']
#             # Ensure explode matches the number of categories
#             explode = (0.05, 0.05, 0.05)[:len(category_counts)]
#             # Handle case where we might have fewer than 3 categories
#             if len(category_counts) < 3:
#                 colors = colors[:len(category_counts)]
#                 explode = explode[:len(category_counts)]
#             wedges, texts, autotexts = ax1.pie(
#                 category_counts, 
#                 labels=category_counts.index, 
#                 autopct=lambda p: f'{p:.1f}%' if p > 0 else '',
#                 startangle=90, 
#                 colors=colors,
#                 explode=explode,
#                 textprops={'fontsize': 12}
#             )
#             for autotext in autotexts:
#                 autotext.set_color('white')
#                 autotext.set_fontweight('bold')
#             ax1.set_title('Patient Response Distribution', fontsize=16, pad=20)
#             # Subplot 2: Bar plot
#             ax2 = fig.add_subplot(gs[1])
#             barplot = sns.barplot(
#                 x=category_counts.index, 
#                 y=category_counts.values, 
#                 ax=ax2,
#                 palette=colors,
#                 saturation=0.8
#             )
#             ax2.set_title('Patient Response Categories', fontsize=16, pad=20)
#             ax2.set_xlabel('Response Category', fontsize=14)
#             ax2.set_ylabel('Number of Patients', fontsize=14)
#             # Add count annotations
#             for p in barplot.patches:
#                 height = p.get_height()
#                 if not np.isnan(height) and height > 0:
#                     barplot.annotate(
#                         f'{int(height)}',
#                         (p.get_x() + p.get_width() / 2., height),
#                         ha='center', va='center',
#                         xytext=(0, 10),
#                         textcoords='offset points',
#                         fontsize=12,
#                         fontweight='bold'
#                     )
#             # Add overall accuracy if available
#             if 'accuracy' in results:
#                 overall_acc = results['accuracy']
#                 fig.text(
#                     0.5, -0.05,
#                     f'Model: {model_name} | Overall Accuracy: {overall_acc:.1%}',
#                     ha='center', va='center', fontsize=14
#                 )
#             plt.suptitle('Patient Response Categorization', fontsize=18, y=1.05)
#             plt.tight_layout()
#             plt.show()
#         except Exception as e:
#             print(f"Error generating patient response visualization: {str(e)}")
#             import traceback
#             traceback.print_exc()

#     def plot_confusion_matrices(self, data_type='Patient'):
#         """Enhanced confusion matrix visualization with better formatting"""
#         if not self.results or data_type not in self.results:
#             print(f"No results available for {data_type} data")
#             return
#         try:
#             models = list(self.results[data_type].keys())
#             num_models = len(models)
#             if num_models == 0:
#                 print("No models available for visualization")
#                 return
#             # Create figure with appropriate size
#             fig, axes = plt.subplots(1, num_models, figsize=(6*num_models, 5))
#             if num_models == 1:
#                 axes = [axes]
#             for i, model_name in enumerate(models):
#                 model_results = self.results[data_type][model_name]
#                 if 'confusion_matrix' not in model_results:
#                     print(f"No confusion matrix for {model_name}")
#                     continue
#                 cm = model_results['confusion_matrix']
#                 # Get class names
#                 if self.label_encoder:
#                     classes = self.label_encoder.classes_
#                 else:
#                     # Handle binary and multiclass cases
#                     n_classes = cm.shape[0]
#                     classes = [f'Class {i}' for i in range(n_classes)]
#                     if n_classes == 2:
#                         classes = ['Negative', 'Positive']
#                 # Normalize the confusion matrix
#                 cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
#                 # Plot with annotations
#                 sns.heatmap(
#                     cm_normalized, 
#                     annot=True, 
#                     fmt='.2f',
#                     cmap='Blues',
#                     xticklabels=classes,
#                     yticklabels=classes,
#                     ax=axes[i],
#                     cbar=False,
#                     annot_kws={'fontsize': 10},
#                     vmin=0, vmax=1
#                 )
#                 axes[i].set_title(f'{model_name} Confusion Matrix', fontsize=14)
#                 axes[i].set_xlabel('Predicted Label', fontsize=12)
#                 axes[i].set_ylabel('True Label', fontsize=12)
#             plt.suptitle(f'Model Performance on {data_type} Data', fontsize=16, y=1.05)
#             plt.tight_layout()
#             plt.show()
#         except Exception as e:
#             print(f"Error generating confusion matrices: {str(e)}")
#             import traceback
#             traceback.print_exc()

#     def plot_roc_curves_comparison(self):
#         """Plot ROC curves for all models for comparison"""
#         if not self.results or 'Patient' not in self.results:
#             print("No patient results available")
#             return
#         plt.figure(figsize=(10, 8))
#         for model_name, results in self.results['Patient'].items():
#             if results.get('roc_auc') is not None:
#                 plt.plot(results['fpr'], results['tpr'],
#                         label=f'{model_name} (AUC = {results["roc_auc"]:.2f})')
#         plt.plot([0, 1], [0, 1], 'k--')
#         plt.xlim([0.0, 1.0])
#         plt.ylim([0.0, 1.05])
#         plt.xlabel('False Positive Rate')
#         plt.ylabel('True Positive Rate')
#         plt.title('Receiver Operating Characteristic Comparison')
#         plt.legend(loc="lower right")
#         plt.grid(True)
#         plt.show()

#     def plot_model_performance_comparison(self):
#         """Enhanced model performance comparison visualization"""
#         if not self.results or 'Patient' not in self.results:
#             raise ValueError("Patient results not available")
#         models = list(self.results['Patient'].keys())
#         metrics = ['accuracy', 'precision', 'recall', 'f1-score']
#         metrics_data = []
#         for model in models:
#             report = self.results['Patient'][model]['report']
#             if isinstance(report, dict) and 'accuracy' in report:
#                 if 'macro avg' in report:
#                     metrics_data.append({
#                         'Model': model,
#                         'Accuracy': report['accuracy'],
#                         'Precision': report['macro avg']['precision'],
#                         'Recall': report['macro avg']['recall'],
#                         'F1-Score': report['macro avg']['f1-score']
#                     })
#                 elif len(report.keys()) > 3:
#                     metrics_data.append({
#                         'Model': model,
#                         'Accuracy': report['accuracy'],
#                         'Precision': report['1']['precision'],
#                         'Recall': report['1']['recall'],
#                         'F1-Score': report['1']['f1-score']
#                     })
#         if not metrics_data:
#             raise ValueError("No valid metric data found")
#         df = pd.DataFrame(metrics_data)
#         df_melted = df.melt(id_vars='Model', var_name='Metric', value_name='Score')
#         plt.figure(figsize=(12, 6))
#         barplot = sns.barplot(
#             x='Model', 
#             y='Score', 
#             hue='Metric', 
#             data=df_melted,
#             palette='viridis',
#             alpha=0.8
#         )
#         plt.title('Model Performance Metrics Comparison', fontsize=16, pad=20)
#         plt.xlabel('Model', fontsize=14)
#         plt.ylabel('Score', fontsize=14)
#         plt.ylim(0, 1.1)
#         plt.legend(title='Metric', bbox_to_anchor=(1.05, 1), loc='upper left')
#         for p in barplot.patches:
#             barplot.annotate(
#                 format(p.get_height(), '.2f'),
#                 (p.get_x() + p.get_width() / 2., p.get_height()),
#                 ha='center', va='center',
#                 xytext=(0, 10),
#                 textcoords='offset points',
#                 fontsize=10
#             )
#         plt.tight_layout()
#         plt.show()

#     def run_pipeline(self, config):
#         """Complete EEG analysis pipeline with robust error handling"""
#         results = None
#         try:
#             print("\n=== EEG Analysis Pipeline ===\n")
            
#             # 1. Load data
#             print("[1/7] Loading data...")
#             healthy_data, healthy_labels, h_sfreq = self.load_dataset(
#                 config['healthy_path'],
#                 max_files=config.get('max_files'),
#                 max_duration=config.get('max_duration', None))
            
#             patient_data, patient_labels, p_sfreq = self.load_dataset(
#                 config['patient_path'],
#                 is_patient=True,
#                 max_files=config.get('max_files'),
#                 max_duration=config.get('max_duration', None))
            
#             # Verify data was loaded
#             if healthy_data is None or len(healthy_data) == 0:
#                 raise ValueError("No healthy data loaded - check file formats and paths")
#             if patient_data is None or len(patient_data) == 0:
#                 raise ValueError("No patient data loaded - check file formats and paths")
            
#             # 2. Channel matching
#             print("[2/7] Finding common channels...")
#             if not config.get('skip_channel_matching', False):
#                 healthy_channels = self.collect_channel_names(config['healthy_path'])
#                 patient_channels = self.collect_channel_names(config['patient_path'], is_patient=True)
#                 self.common_channels = self.find_common_channels(healthy_channels, patient_channels)
            
#             # 3. Preprocessing
#             print("[3/7] Preprocessing data...")
#             X_healthy, y_healthy = self.preprocess_data(
#                 healthy_data, healthy_labels,
#                 sfreq=h_sfreq,
#                 common_channels=self.common_channels)
            
#             X_patients, y_patients = self.preprocess_data(
#                 patient_data, patient_labels,
#                 sfreq=p_sfreq,
#                 common_channels=self.common_channels)
            
#             # Verify preprocessing succeeded
#             if X_healthy is None or X_patients is None:
#                 raise ValueError("Preprocessing failed - check data compatibility")
            
#             # 4. Load artifacts
#             print("[4/7] Loading preprocessing artifacts...")
#             self.load_preprocessing_artifacts(
#                 config['scaler_path'],
#                 config['label_encoder_path'])
            
#             # 5. Apply transformations
#             print("[5/7] Applying transformations...")
#             if self.scaler:
#                 expected_features = self.scaler.n_features_in_
#                 X_healthy = self.match_features(X_healthy, expected_features)
#                 X_patients = self.match_features(X_patients, expected_features)
#                 X_healthy = self.scaler.transform(X_healthy)
#                 X_patients = self.scaler.transform(X_patients)
            
#             if self.label_encoder is not None and y_patients is not None:
#                 y_patients = self.label_encoder.transform(y_patients)
            
#             # 6. Load models
#             print("[6/7] Loading models...")
#             self.load_models(config['model_paths'])
            
#             # 7. Evaluate models (only on patient data)
#             print("[7/7] Evaluating models...")
#             self.evaluate_models(X_patients, y_patients, data_type='Patient')
            
#             # Generate visualizations
#             print("\n=== Generating Visualizations ===\n")
#             try:
#                 plt.style.use('seaborn-v0_8-whitegrid')
#             except:
#                 plt.style.use('ggplot')
                
#             vis_funcs = [
#                 ('EEG Comparison', self.plot_eeg_comparison, [healthy_data, patient_data, h_sfreq, p_sfreq]),
#                 ('Patient Response', self.plot_patient_response_categories, []),
#                 ('Confusion Matrix', self.plot_confusion_matrices, ['Patient']),
#                 ('Model Performance', self.plot_model_performance_comparison, []),
#                 ('ROC Curve', self.plot_roc_curves_comparison, [])
#             ]
            
#             for name, func, args in vis_funcs:
#                 try:
#                     print(f"Generating {name} visualization...")
#                     func(*args)
#                     plt.close('all')
#                 except Exception as e:
#                     error_msg = f"Failed to generate {name}: {str(e)}"
#                     self._error_messages.append(error_msg)
#                     print(error_msg)
            
#             print("\n=== Pipeline Completed ===\n")
#             if self._error_messages:
#                 print("Completed with warnings/errors (see above messages)")
#             else:
#                 print("Completed successfully!")
                
#             results = {
#                 'healthy_data': (X_healthy, y_healthy),
#                 'patient_data': (X_patients, y_patients),
#                 'results': self.results,
#                 'metadata': {
#                     'healthy_samples': len(X_healthy),
#                     'patient_samples': len(X_patients),
#                     'common_channels': self.common_channels,
#                     'features_per_channel': self.expected_features_per_channel,
#                     'errors': self._error_messages
#                 }
#             }
            
#         except Exception as e:
#             print(f"\n!!! Pipeline Failed !!!\nError: {str(e)}")
#             import traceback
#             traceback.print_exc()
            
#         finally:
#             if 'healthy_data' in locals():
#                 del healthy_data
#             if 'patient_data' in locals():
#                 del patient_data
#             gc.collect()
            
#         return results

# if __name__ == "__main__":
#     channel_mapping = {
#         'FP1': 'Fp1',
#         'FP2': 'Fp2',
#         'F3': 'F3',
#         'F4': 'F4',
#     }
#     config = {
#         'healthy_path': "F:/shivani/VSCode/ml/worked on dataset/3(final)/Healthy",
#         'patient_path': "F:/shivani/VSCode/ml/worked on dataset/4/dataset/Patients",
#         'model_paths': {
#             "SVM": "F:/shivani/VSCode/ml/worked on dataset/3(final)/svm_model.pkl",
#             "RF": "F:/shivani/VSCode/ml/worked on dataset/3(final)/rf_model.pkl",
#             "XGB": "F:/shivani/VSCode/ml/worked on dataset/3(final)/xgb_model.pkl",
#             "LSTM": "F:/shivani/VSCode/ml/worked on dataset/3(final)/lstm_model.h5",
#             "CNN": "F:/shivani/VSCode/ml/worked on dataset/3(final)/cnn_model.h5"
#         },
#         'scaler_path': "F:/shivani/VSCode/ml/worked on dataset/3(final)/preprocessing_artifacts/eeg_scaler.joblib",
#         'label_encoder_path': "F:/shivani/VSCode/ml/worked on dataset/3(final)/preprocessing_artifacts/label_encoder.joblib",
#         'max_files': 8,
#     }
#     pipeline = EEGPipeline()
#     pipeline.set_channel_mapping(channel_mapping)
#     results = pipeline.run_pipeline(config)

In [None]:
def get_channel_names_from_mat(mat_path):
    """Read the channel names stored in a MATLAB ``.mat`` file.

    The file is loaded with ``loadmat``. The names are taken from the
    ``channels`` field of the struct found at ``data[0][0]``: every entry in
    the first row of that field is indexed with ``[0]`` to obtain the name.

    Args:
        mat_path: Path of the ``.mat`` file, passed straight to ``loadmat``.

    Returns:
        A list with one name per channel entry. An empty list is returned when
        the file has no ``data`` variable, when the struct has no ``channels``
        field, or when any exception is raised while reading (the error is
        printed, not re-raised).
    """
    try:
        contents = loadmat(mat_path)
        if 'data' not in contents:
            return []
        record = contents['data'][0][0]
        if 'channels' not in record.dtype.names:
            return []
        names = []
        for entry in record['channels'][0]:
            # Each entry is itself indexable; its first element is the name.
            names.append(entry[0])
        return names
    except Exception as exc:
        print(f"Error loading {mat_path}: {str(exc)}")
        return []

In [None]:
def get_channel_names_from_vhdr(vhdr_path):
    """Return the channel names reported for a BrainVision ``.vhdr`` file.

    The file is opened through ``mne.io.read_raw_brainvision`` with
    ``preload=False`` and ``verbose=False``, and the ``ch_names`` of the
    resulting object is returned.

    Args:
        vhdr_path: Path of the ``.vhdr`` header file.

    Returns:
        The ``ch_names`` value of the opened recording, or an empty list when
        any exception is raised (the error is printed, not re-raised).
    """
    channel_names = []
    try:
        channel_names = mne.io.read_raw_brainvision(
            vhdr_path, preload=False, verbose=False
        ).ch_names
    except Exception as exc:
        print(f"Error loading {vhdr_path}: {str(exc)}")
    return channel_names

In [None]:
from scipy.io import loadmat

def inspect_mat_file(mat_path):
    """Print a summary of the variables stored in a ``.mat`` file.

    Prints the full list of keys returned by ``loadmat``. It then prints one
    line for every key that does not start with ``__``, giving the value's
    type and its ``shape`` (``N/A`` when the value has no ``shape``
    attribute). Any exception is caught and reported with a printed message.

    Args:
        mat_path: Path of the ``.mat`` file to inspect.
    """
    try:
        contents = loadmat(mat_path)
        print(f"Keys in {mat_path}: {list(contents.keys())}")
        for name, value in contents.items():
            # Double-underscore keys are skipped in the per-variable listing.
            if name.startswith('__'):
                continue
            shape = value.shape if hasattr(value, 'shape') else 'N/A'
            print(f"{name}: {type(value)}, shape: {shape}")
    except Exception as exc:
        print(f"Error loading {mat_path}: {str(exc)}")

# Example usage: this call runs as soon as the cell executes. The argument is a
# hard-coded absolute path; if the file cannot be read, inspect_mat_file
# catches the exception and prints an error message instead of raising.
inspect_mat_file("F:/shivani/VSCode/ml/worked on dataset/4/dataset/Patients/A01.mat")

In [None]:
def find_common_channels(healthy_channels, patient_channels):
    """Return the channel names that appear in both collections.

    Names are compared for exact equality (so string matching is
    case-sensitive) and duplicates are collapsed.

    Args:
        healthy_channels: Iterable of channel names from the healthy dataset.
        patient_channels: Iterable of channel names from the patient dataset.

    Returns:
        A sorted list of the names present in both inputs.
    """
    shared = set(healthy_channels) & set(patient_channels)
    ordered = list(shared)
    ordered.sort()  # deterministic ordering regardless of set iteration order
    return ordered

In [None]:
def collect_channel_names(folder_path, is_patient=False):
    """Gather every distinct channel name found in the files of a folder.

    Only files directly inside ``folder_path`` whose name ends with ``.mat``
    (when ``is_patient`` is true) or ``.vhdr`` (otherwise) are read. Each one
    is passed to ``get_channel_names_from_mat`` or
    ``get_channel_names_from_vhdr`` respectively, with a ``tqdm`` progress bar
    over the files.

    Args:
        folder_path: Directory to scan with ``os.listdir``.
        is_patient: Selects the ``.mat`` reader when true, the ``.vhdr``
            reader when false.

    Returns:
        A sorted list of the unique channel names collected from all files.
    """
    extension = '.mat' if is_patient else '.vhdr'
    label = 'patient' if is_patient else 'healthy'
    matching_files = [name for name in os.listdir(folder_path) if name.endswith(extension)]

    unique_names = set()
    for filename in tqdm(matching_files, desc=f"Collecting {label} channels"):
        full_path = os.path.join(folder_path, filename)
        reader = get_channel_names_from_mat if is_patient else get_channel_names_from_vhdr
        unique_names.update(reader(full_path))
    return sorted(unique_names)

In [None]:
# Gather the channel names found in each dataset folder; the folder paths are
# read from the `config` dict.
# NOTE(review): the only definition of `config` visible in this part of the
# notebook sits inside commented-out code - confirm an earlier live cell
# defines it, otherwise this cell raises NameError on a fresh kernel.
healthy_channels = collect_channel_names(config['healthy_path'])
patient_channels = collect_channel_names(config['patient_path'], is_patient=True)

# Find the channels shared by both datasets and print them
common_channels = find_common_channels(healthy_channels, patient_channels)
print(f"Common channels: {common_channels}")