In [None]:
import numpy as np
# Notebook housekeeping: clearing cell output and forcing garbage collection.
from IPython.display import clear_output
import gc


def free_notebook_memory(clear_cell_output=True):
    """Optionally clear the current cell's output, then run the garbage collector.

    Meant to be called at the end of a memory-heavy cell. Calling it from the
    import cell, before anything has been computed, has little to free.

    Parameters
    ----------
    clear_cell_output : bool, default True
        If True, call ``clear_output(wait=True)`` first so the next output
        replaces the current one instead of appending to it.

    Returns
    -------
    int
        The value returned by ``gc.collect()`` (number of unreachable objects found).
    """
    if clear_cell_output:
        clear_output(wait=True)
    return gc.collect()
import pickle
import os
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from itertools import combinations
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split, cross_val_score, StratifiedKFold
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, roc_curve, auc
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from scipy import stats
from scipy.signal import welch

# ---- Configuration ----
RESULT_FOLDER = "result"
MODEL_FOLDER = "model"
model_names = ['Wavenet']  # other options: 'CNN1D', 'Wavenet', 'S4', 'Resnet'

# Hyperparameter grid for the batch analysis
seizures = [1, 2, 3, 5, 7]
thresholds = [0.8]
smooth_windows = [80]

patientID = 'P65'
seizureID = f'{patientID}SZ'

# Seed numpy's global RNG: window subsampling and class balancing later in the
# notebook draw from np.random.choice, so runs are not reproducible without it.
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)

In [None]:
# Load the combined seizure data for the selected patient.
# (The variable name `p66_data` is historical; it holds data for `patientID`.)
# NOTE: pickle.load can execute arbitrary code -- only load files you trust.
with open(f'data/{patientID}/seizure_All_combined.pkl', "rb") as fh:
    p66_data = pickle.load(fh)

In [None]:
def extract_grey_matter_channels(matter: pd.DataFrame, matter_types=('G', 'A')) -> np.ndarray:
    """Return the ``ChannelNumber`` values of rows whose ``MatterType`` is selected.

    Parameters
    ----------
    matter : pd.DataFrame
        Table with at least ``MatterType`` and ``ChannelNumber`` columns
        (in this notebook, read from ``matter.csv``).
    matter_types : iterable of str, default ('G', 'A')
        MatterType codes counted as grey matter. The default reproduces the
        original selection of both 'G' and 'A' rows.

    Returns
    -------
    np.ndarray
        ``ChannelNumber`` values of the selected rows, in table order and
        unconverted (callers subtract 1 to get 0-based indices).
    """
    selected_matter = matter[matter['MatterType'].isin(list(matter_types))]
    return selected_matter['ChannelNumber'].values

p66_data.matter = pd.read_csv(f'data/{patientID}/matter.csv')
all_channels = np.arange(0, p66_data.channelNumber)

# matter.csv ChannelNumber is treated as 1-based, hence the "- 1" to get array indices.
grey_channel = extract_grey_matter_channels(p66_data.matter) - 1

# A ChannelNumber of 0 (or above channelNumber) would become an invalid index here;
# -1 in particular silently selects the LAST channel via negative indexing.
out_of_range = grey_channel[(grey_channel < 0) | (grey_channel >= p66_data.channelNumber)]
if out_of_range.size:
    raise ValueError(
        f"matter.csv ChannelNumber values outside 1..{p66_data.channelNumber}: "
        f"{(out_of_range + 1).tolist()}"
    )

# Every channel not selected as grey (including channels absent from matter.csv)
# is treated as white matter.
white_channel = np.setdiff1d(all_channels, grey_channel)

# NOTE(review): ictal is indexed as 3-D with channels on the last axis,
# presumably [segment, time, channel] -- confirm.
seizure_data_grey = p66_data.ictal[:,:,grey_channel]
seizure_data_white = p66_data.ictal[:,:,white_channel]

In [None]:
# Presumably the class of the pickled object; keep it importable for unpickling.
from datasetConstruct import EDFData

# Load the single-seizure recording (seizure_SZ1.pkl).
# NOTE: pickle.load can execute arbitrary code -- only load files you trust.
with open(f'data/{patientID}/seizure_SZ1.pkl', "rb") as fh:
    p66_raw = pickle.load(fh)

In [None]:
# Select grey- and white-matter channels (axis 1) from the single-seizure recording.
raw_grey, raw_white = (p66_raw.ictal[:, cols] for cols in (grey_channel, white_channel))

In [None]:
# Electrode label per row of matter.csv.
# NOTE(review): this follows CSV row order; it lines up with channel indices only
# if matter.csv is sorted by ChannelNumber -- confirm before indexing by channel.
channel_names = p66_data.matter.loc[:, 'ElectrodeName'].values

In [None]:
def stack_segments(segmented):
    """Join the sub-arrays along the first axis end-to-end into one array.

    For the 3-D ictal selections above this turns [segment, time, channel]
    (assumed layout) into [segment * time, channel].
    """
    return np.concatenate([segment for segment in segmented], axis=0)


seizure_data_grey_new = stack_segments(seizure_data_grey)
seizure_data_white_new = stack_segments(seizure_data_white)

In [None]:
from sklearn.model_selection import (train_test_split, StratifiedKFold, GroupKFold, 
                                   cross_val_score)
from sklearn.svm import SVC
from sklearn.naive_bayes import GaussianNB
from sklearn.metrics import (accuracy_score, f1_score, precision_score, recall_score,
                           balanced_accuracy_score, roc_auc_score, confusion_matrix,
                           classification_report, roc_curve, auc)


def improved_extract_and_classify_features_complete(grey_matter_data, white_matter_data, 
                                         plot_folder='result/complete_strict_validation', 
                                         fs=250, 
                                         use_windowing=True, 
                                         n_windows_per_channel=None,
                                         window_overlap=None,
                                         min_window_length=None,
                                         validation_type='strict',
                                         max_samples_per_channel=20,
                                         balance_method='hybrid',
                                         class_weight='balanced'):
    """
    完整的grey/white matter分类系统，包含strict validation和channel-level分析
    
    Parameters:
    -----------
    grey_matter_data : np.ndarray
        Grey matter data with shape [time, channels]
    white_matter_data : np.ndarray
        White matter data with shape [time, channels]
    plot_folder : str
        Folder path to save all plots
    fs : int
        Sampling frequency in Hz
    use_windowing : bool
        Whether to use time windowing to increase sample size
    n_windows_per_channel : int, optional
        Maximum number of windows per channel
    window_overlap : float, optional
        Overlap between windows (0-1)
    min_window_length : int, optional
        Minimum window length in samples
    validation_type : str
        'strict' = channel-wise split, 'normal' = random split
    max_samples_per_channel : int
        Maximum number of samples per channel
    balance_method : str
        'downsample', 'upsample', 'hybrid'
    class_weight : str or dict
        Class weighting for classifiers
    
    Returns:
    --------
    dict
        Complete results including sample-level and channel-level analysis
    """
    
    # Create output folder
    os.makedirs(plot_folder, exist_ok=True)
    
    # Get dimensions
    grey_time, grey_channels = grey_matter_data.shape
    white_time, white_channels = white_matter_data.shape
    
    print(f"COMPLETE GREY/WHITE MATTER CLASSIFICATION SYSTEM")
    print(f"=" * 60)
    print(f"VALIDATION MODE: {validation_type}")
    print(f"Grey matter data shape: [{grey_time}, {grey_channels}]")
    print(f"White matter data shape: [{white_time}, {white_channels}]")
    print(f"Windowing approach: {use_windowing}")
    print(f"Balance method: {balance_method}")
    print(f"Class weight: {class_weight}")
    print(f"Max samples per channel: {max_samples_per_channel}")
    
    # Calculate windowing parameters
    if use_windowing:
        if min_window_length is None:
            window_duration = 0.05  # 50ms
            min_window_length = int(window_duration * fs)
        
        if window_overlap is None:
            window_step = 0.025     # 25ms step
            step_samples = int(window_step * fs)
        else:
            step_samples = int(min_window_length * (1 - window_overlap))
        
        if n_windows_per_channel is None:
            max_windows_grey = (grey_time - min_window_length) // step_samples + 1
            max_windows_white = (white_time - min_window_length) // step_samples + 1
            n_windows_per_channel = min(max_windows_grey, max_windows_white)
            
        print(f"Window parameters:")
        print(f"  Window size: {min_window_length} samples ({min_window_length/fs:.3f}s)")
        print(f"  Step size: {step_samples} samples ({step_samples/fs:.3f}s)")
        print(f"  Expected windows per channel: {n_windows_per_channel}")
    
    # Define frequency bands
    freq_bands = {
        'delta': (0.5, 4),
        'theta': (4, 8),        
        'alpha': (8, 13),       
        'beta': (13, 30),       
        'gamma': (30, 100),     
        'high_gamma': (100, min(200, fs//2))
    }
    
    def extract_half_wave_features(signal, amplitude_threshold=0.1):
        """Summarise the half-waves (runs between sign changes) of a signal.

        ``np.diff(np.signbit(signal))`` is True at index ``i`` when the sign
        changes between samples ``i`` and ``i + 1``. The half-wave between two
        consecutive boundaries ``a`` and ``b`` therefore spans samples
        ``a + 1 .. b`` (inclusive). Half-waves whose peak |amplitude| is below
        ``amplitude_threshold`` are ignored.

        Returns
        -------
        dict
            'hw_count' (number of kept half-waves), 'hw_mean_amp' (mean peak
            |amplitude|, signal units) and 'hw_mean_duration' (mean length in
            seconds, using the enclosing ``fs``); all 0 when fewer than two sign
            changes exist or no half-wave passes the threshold.
        """
        zero_crossings = np.flatnonzero(np.diff(np.signbit(signal)))

        if len(zero_crossings) <= 1:
            return {
                'hw_count': 0,
                'hw_mean_amp': 0,
                'hw_mean_duration': 0
            }

        half_wave_amps = []
        half_wave_durations = []

        for start_idx, end_idx in zip(zero_crossings[:-1], zero_crossings[1:]):
            # Samples strictly after the first boundary up to and including the
            # second; the previous slice [start:end] leaked the prior half-wave's
            # last sample into the amplitude and dropped this one's final sample.
            segment = signal[start_idx + 1:end_idx + 1]
            amplitude = np.max(np.abs(segment))
            if amplitude >= amplitude_threshold:
                half_wave_amps.append(amplitude)
                half_wave_durations.append(end_idx - start_idx)

        return {
            'hw_count': len(half_wave_amps),
            'hw_mean_amp': np.mean(half_wave_amps) if half_wave_amps else 0,
            'hw_mean_duration': np.mean(half_wave_durations) / fs if half_wave_durations else 0
        }
    
    def extract_comprehensive_features_from_window(signal):
        """Compute time-domain, half-wave and spectral features for one window.

        Parameters
        ----------
        signal : np.ndarray
            1-D window of samples from a single channel.

        Returns
        -------
        dict or None
            Feature name -> value. None when the window has fewer than 10
            samples, is all zeros, or feature extraction raises.

        Band powers use ``freq_bands`` and ``fs`` from the enclosing function.
        The spectrum covers FFT bins 0 .. len(signal)//2 - 1 of the
        Hamming-windowed signal. Note that short windows give coarse frequency
        resolution (fs / len(signal)), so low bands may contain no bins.
        """
        if len(signal) < 10 or np.all(signal == 0):
            return None

        features = {}
        try:
            # Time domain statistical features
            features['mean'] = np.mean(signal)
            features['std'] = np.std(signal)
            features['median'] = np.median(signal)
            features['iqr'] = np.percentile(signal, 75) - np.percentile(signal, 25)

            try:
                features['skew'] = stats.skew(signal)
                features['kurtosis'] = stats.kurtosis(signal)
            except Exception:
                # Keep the window; only the higher moments are unavailable.
                features['skew'] = 0
                features['kurtosis'] = 0

            features['range'] = np.max(signal) - np.min(signal)
            features['rms'] = np.sqrt(np.mean(signal**2))
            features['zero_crossings'] = np.sum(np.diff(np.signbit(signal).astype(int)) != 0)

            # Half-wave features
            features.update(extract_half_wave_features(signal))

            # Line length and area features
            features['line_length'] = np.sum(np.abs(np.diff(signal)))
            features['area'] = np.sum(np.abs(signal))

            # Frequency domain features
            try:
                n_samples = len(signal)
                n_half = n_samples // 2
                windowed_signal = signal * np.hamming(n_samples)
                fft_abs = np.abs(np.fft.fft(windowed_signal)[:n_half]) / n_samples
                freq_bins = np.fft.fftfreq(n_samples, 1 / fs)[:n_half]
                power_spectrum = fft_abs ** 2

                # Band powers (bins with low <= f <= high)
                for band_name, (low_freq, high_freq) in freq_bands.items():
                    band_mask = (freq_bins >= low_freq) & (freq_bins <= high_freq)
                    features[f'power_{band_name}'] = (
                        np.sum(power_spectrum[band_mask]) if np.any(band_mask) else 0
                    )

                # Total power = sum of the named bands only
                features['total_power'] = sum(features[f'power_{band}'] for band in freq_bands)

                # Edge frequency and entropy are computed from the FULL spectrum,
                # so gate them on the full-spectrum power rather than band power.
                spectrum_power = np.sum(power_spectrum)
                if spectrum_power > 0:
                    cumulative_power = np.cumsum(power_spectrum)
                    # First bin at which 95% of the power is reached. Bin 0 (DC)
                    # is a valid answer and must not be remapped to the top bin.
                    edge_95_idx = np.argmax(cumulative_power >= 0.95 * spectrum_power)
                    features['spectral_edge_freq'] = freq_bins[edge_95_idx]

                    pxx_norm = power_spectrum / spectrum_power
                    features['spectral_entropy'] = -np.sum(pxx_norm * np.log2(pxx_norm + 1e-10))
                else:
                    features['spectral_edge_freq'] = 0
                    features['spectral_entropy'] = 0

            except Exception as e:
                print(f"Warning: Spectral analysis failed: {str(e)}")
                for band_name in freq_bands.keys():
                    features[f'power_{band_name}'] = 0
                features['total_power'] = 0
                features['spectral_edge_freq'] = 0
                features['spectral_entropy'] = 0

        except Exception as e:
            print(f"Warning: Feature extraction failed: {str(e)}")
            return None

        return features
    
    def add_contact_depth_features(samples, channel_info, n_total_channels):
        """Return copies of ``samples`` extended with channel-position features.

        For each sample, the position is its channel index scaled to [0, 1]
        (``index / (n_total_channels - 1)``, or 0.5 when there is at most one
        channel). The added keys are that value, its square, and three 0/1 flags:
        surface (< 0.2), middle (0.3..0.7 inclusive) and deep (> 0.8).
        The input dicts are not modified.

        NOTE(review): ``channel_info`` holds indices into the grey-only or
        white-only channel list, so this is a position within that subset, not
        a physical contact depth along an electrode -- confirm this is intended
        (it may also let the model identify individual channels).
        """
        def _scaled_position(index):
            if n_total_channels > 1:
                return index / (n_total_channels - 1)
            return 0.5

        enriched = []
        for row, index in zip(samples, channel_info):
            depth = _scaled_position(index)
            new_row = row.copy()
            new_row.update({
                'contact_depth': depth,
                'contact_depth_squared': depth ** 2,
                'is_surface_contact': int(depth < 0.2),
                'is_middle_contact': int(0.3 <= depth <= 0.7),
                'is_deep_contact': int(depth > 0.8),
            })
            enriched.append(new_row)
        return enriched
    
    def create_channel_separated_samples(data, tissue_type):
        """Turn a 2-D [time, channel] array into per-window feature dicts.

        With windowing enabled (enclosing ``use_windowing``), each channel is
        cut into windows of ``min_window_length`` samples advanced by
        ``step_samples``. Windows are scanned from the start until
        ``n_windows_per_channel`` valid feature dicts exist. If a channel then
        has more than ``max_samples_per_channel`` dicts, a random subset of that
        size is kept in time order (uses the global numpy RNG). Without
        windowing, each whole channel yields at most one sample.

        Returns
        -------
        (list of dict, list of int)
            Feature dicts and, in parallel, the 0-based column index each came from.
        """
        _, n_channels = data.shape
        feature_rows, row_channels = [], []

        if not use_windowing:
            # One sample per channel: the whole trace is treated as a single window.
            for ch in range(n_channels):
                row = extract_comprehensive_features_from_window(data[:, ch])
                if row is not None:
                    feature_rows.append(row)
                    row_channels.append(ch)
            return feature_rows, row_channels

        window_samples = min_window_length
        print(f"{tissue_type} - Window: {window_samples} samples, Step: {step_samples} samples")

        for ch in range(n_channels):
            trace = data[:, ch]
            per_channel = []
            last_start = len(trace) - window_samples
            for start in range(0, last_start + 1, step_samples):
                row = extract_comprehensive_features_from_window(trace[start:start + window_samples])
                if row is not None:
                    per_channel.append(row)
                if len(per_channel) >= n_windows_per_channel:
                    break

            too_many = max_samples_per_channel and len(per_channel) > max_samples_per_channel
            if too_many:
                keep = sorted(np.random.choice(len(per_channel), max_samples_per_channel, replace=False))
                per_channel = [per_channel[i] for i in keep]

            feature_rows.extend(per_channel)
            row_channels.extend([ch] * len(per_channel))

        return feature_rows, row_channels
    
    def balance_dataset(grey_df, white_df, grey_channel_info, white_channel_info, 
                       balance_method='hybrid', max_samples_total=2000):
        """Resample the grey and white matter samples to reduce class imbalance.

        Parameters
        ----------
        grey_df, white_df : pd.DataFrame
            Feature rows for each class.
        grey_channel_info, white_channel_info : list
            Channel id for each row, parallel to the corresponding DataFrame.
        balance_method : {'downsample', 'upsample', 'hybrid'}
            'downsample': randomly drop rows so each class has
                ``min(len(grey), len(white), max_samples_total // 2)`` rows.
            'upsample': duplicate random rows of a class up to
                ``min(max(len(grey), len(white)), max_samples_total // 2)``;
                a class already above that cap is left untouched.
            'hybrid': bring both classes to
                ``min((len(grey) + len(white)) // 2, max_samples_total // 2)``,
                downsampling the larger and upsampling the smaller.
            Any other value leaves the data unchanged (a warning is printed).
        max_samples_total : int
            Upper bound on the combined number of rows targeted.

        Returns
        -------
        tuple
            (grey_df, white_df, grey_channel_info, white_channel_info) as NEW
            objects; the caller's lists are no longer mutated.

        Uses the global numpy RNG. Note that upsampling duplicates rows before
        any train/test split, so with a random (non channel-wise) split the same
        row can appear in both train and test.
        """
        def _resample(df, channel_info, target, allow_down=True, allow_up=True):
            """Down- or up-sample ``df`` and its parallel channel list to ``target`` rows."""
            channel_info = list(channel_info)  # never touch the caller's list
            n_rows = len(df)
            if allow_down and n_rows > target:
                idx = np.random.choice(n_rows, target, replace=False)
                return df.iloc[idx].reset_index(drop=True), [channel_info[i] for i in idx]
            if allow_up and n_rows < target:
                idx = np.random.choice(n_rows, target - n_rows, replace=True)
                extra = df.iloc[idx].reset_index(drop=True)
                return (pd.concat([df, extra], ignore_index=True),
                        channel_info + [channel_info[i] for i in idx])
            return df, channel_info

        print(f"Original data - Grey: {len(grey_df)}, White: {len(white_df)}")
        cap = max_samples_total // 2

        if balance_method == 'downsample':
            target = min(len(grey_df), len(white_df), cap)
            grey_df, grey_channel_info = _resample(grey_df, grey_channel_info, target, allow_up=False)
            white_df, white_channel_info = _resample(white_df, white_channel_info, target, allow_up=False)

        elif balance_method == 'upsample':
            target = min(max(len(grey_df), len(white_df)), cap)
            grey_df, grey_channel_info = _resample(grey_df, grey_channel_info, target, allow_down=False)
            white_df, white_channel_info = _resample(white_df, white_channel_info, target, allow_down=False)

        elif balance_method == 'hybrid':
            target = min((len(grey_df) + len(white_df)) // 2, cap)
            grey_df, grey_channel_info = _resample(grey_df, grey_channel_info, target)
            white_df, white_channel_info = _resample(white_df, white_channel_info, target)

        else:
            print(f"Warning: unknown balance_method {balance_method!r}; data left unbalanced")
            grey_channel_info = list(grey_channel_info)
            white_channel_info = list(white_channel_info)

        print(f"Balanced data - Grey: {len(grey_df)}, White: {len(white_df)}")
        return grey_df, white_df, grey_channel_info, white_channel_info
    
    # ===============================
    # MAIN FEATURE EXTRACTION
    # ===============================
    
    print("\nExtracting features with channel separation...")
    grey_samples, grey_channel_info = create_channel_separated_samples(
        grey_matter_data, "Grey Matter"
    )
    white_samples, white_channel_info = create_channel_separated_samples(
        white_matter_data, "White Matter"
    )
    
    print(f"Grey matter samples created: {len(grey_samples)} from {len(set(grey_channel_info))} channels")
    print(f"White matter samples created: {len(white_samples)} from {len(set(white_channel_info))} channels")
    
    if len(grey_samples) == 0 or len(white_samples) == 0:
        return {'error': 'No samples created'}
    
    # Add contact depth features
    print("Adding contact depth features...")
    grey_samples = add_contact_depth_features(
        grey_samples, grey_channel_info, grey_matter_data.shape[1]
    )
    white_samples = add_contact_depth_features(
        white_samples, white_channel_info, white_matter_data.shape[1]
    )
    
    # Convert to DataFrames
    grey_df = pd.DataFrame(grey_samples)
    white_df = pd.DataFrame(white_samples)
    
    # Clean data
    def clean_dataframe(df, name):
        """Replace NaN and infinite values in the numeric columns of ``df`` in place.

        NaNs are filled with each column's median. Then, in every column that
        still contains +/-inf, those entries are replaced by the median of the
        finite entries, or the whole column is set to 0 when nothing is finite.
        Progress is printed using ``name``. The same DataFrame object is returned.
        """
        shape_before = df.shape
        numeric = df.select_dtypes(include=[np.number]).columns

        if df[numeric].isna().any().any():
            print(f"Found NaN values in {name}, filling with column medians")
            df[numeric] = df[numeric].fillna(df[numeric].median())

        for column in numeric:
            values = df[column]
            if not np.any(np.isinf(values)):
                continue
            print(f"Found infinite values in {name} column {column}")
            finite = np.isfinite(values)
            if np.any(finite):
                df.loc[~finite, column] = df.loc[finite, column].median()
            else:
                df[column] = 0

        print(f"Cleaned {name}: {shape_before} -> {df.shape}")
        return df
    
    grey_df = clean_dataframe(grey_df, "grey matter features")
    white_df = clean_dataframe(white_df, "white matter features")
    
    # Balance dataset
    print("Balancing dataset...")
    grey_df, white_df, grey_channel_info, white_channel_info = balance_dataset(
        grey_df, white_df, grey_channel_info, white_channel_info,
        balance_method=balance_method, max_samples_total=2000
    )
    
    # Ensure common features
    feature_cols = [col for col in grey_df.columns]
    common_features = list(set(feature_cols).intersection(set(white_df.columns)))
    
    if len(common_features) == 0:
        return {'error': 'No common features'}
    
    print(f"Using {len(common_features)} common features")
    
    # Add labels and channel info
    grey_df['channel_id'] = grey_channel_info
    white_df['channel_id'] = [ch + grey_channels for ch in white_channel_info]
    grey_df['Matter'] = 'Grey'
    white_df['Matter'] = 'White'
    
    # Combine data
    combined_df = pd.concat([grey_df, white_df], ignore_index=True)
    
    # Prepare features and labels
    X = combined_df[common_features]
    y = combined_df['Matter'].map({'Grey': 1, 'White': 0})
    channel_ids = combined_df['channel_id'].values
    
    print(f"Final dataset: {len(X)} samples × {len(common_features)} features")
    print(f"Sample-to-feature ratio: {len(X) / len(common_features):.1f}:1")
    
    # Check for perfect separability
    def check_perfect_separability(X, y, feature_names):
        """List features whose grey (y == 1) and white (y == 0) value ranges do not overlap.

        ``feature_names[i]`` is taken to label column ``i`` of ``X``. For each
        flagged feature, both value ranges are printed. Returns the flagged
        names in column order.
        """
        is_grey = y == 1
        is_white = y == 0
        flagged = []

        for col_pos, feature in enumerate(feature_names):
            column = X.iloc[:, col_pos]
            grey_min, grey_max = column[is_grey].min(), column[is_grey].max()
            white_min, white_max = column[is_white].min(), column[is_white].max()

            separated = grey_max < white_min or white_max < grey_min
            if not separated:
                continue

            flagged.append(feature)
            print(f"⚠️  PERFECT SEPARATION found in feature '{feature}':")
            print(f"   Grey range: [{grey_min:.3f}, {grey_max:.3f}]")
            print(f"   White range: [{white_min:.3f}, {white_max:.3f}]")

        return flagged
    
    perfect_features = check_perfect_separability(X, y, common_features)
    if perfect_features:
        print(f"\n🚨 WARNING: Found {len(perfect_features)} features with perfect separation!")
    
    # Scale features
    scaler = StandardScaler()
    X_scaled = scaler.fit_transform(X)
    
    # Train/test split
    if validation_type == 'strict':
        print("\n🔒 STRICT VALIDATION: Channel-wise train/test split")
        
        grey_channels_used = combined_df[combined_df['Matter'] == 'Grey']['channel_id'].unique()
        white_channels_used = combined_df[combined_df['Matter'] == 'White']['channel_id'].unique()
        
        print(f"Grey matter: {len(grey_channels_used)} unique channels")
        print(f"White matter: {len(white_channels_used)} unique channels")
        
        grey_train_channels, grey_test_channels = train_test_split(
            grey_channels_used, test_size=0.3, random_state=42
        )
        white_train_channels, white_test_channels = train_test_split(
            white_channels_used, test_size=0.3, random_state=42
        )
        
        train_mask = (
            (combined_df['Matter'] == 'Grey') & (combined_df['channel_id'].isin(grey_train_channels)) |
            (combined_df['Matter'] == 'White') & (combined_df['channel_id'].isin(white_train_channels))
        )
        test_mask = ~train_mask
        
        X_train, X_test = X_scaled[train_mask], X_scaled[test_mask]
        y_train, y_test = y[train_mask], y[test_mask]
        test_channel_ids = channel_ids[test_mask]
        
        print(f"Channel-wise split: Train={len(X_train)}, Test={len(X_test)}")
        
    else:
        X_train, X_test, y_train, y_test, train_idx, test_idx = train_test_split(
            X_scaled, y, np.arange(len(X)), test_size=0.3, random_state=42, stratify=y
        )
        test_channel_ids = channel_ids[test_idx]
    
    # ===============================
    # CLASSIFICATION
    # ===============================
    
    # Define classifiers
    classifiers = {
        'Logistic Regression': LogisticRegression(
            max_iter=1000, random_state=42, class_weight=class_weight
        ),
        'SVM (RBF kernel)': SVC(
            probability=True, random_state=42, C=1.0, class_weight=class_weight
        ),
        'Random Forest': RandomForestClassifier(
            n_estimators=100, random_state=42, max_depth=10, class_weight=class_weight
        ),
        'MLP Neural Network': MLPClassifier(
            max_iter=1000, random_state=42, hidden_layer_sizes=(50,), alpha=0.01
        ),
        'K-Nearest Neighbors': KNeighborsClassifier(n_neighbors=5),
        'LDA': LDA(),
        'Naive Bayes': GaussianNB()
    }
    
    results = {}
    
    # Cross-validation setup
    if validation_type == 'strict':
        print("\n🔒 Using channel-wise cross-validation...")
        cv = GroupKFold(n_splits=min(5, len(set(channel_ids))))
        groups = channel_ids
    else:
        cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
        groups = None
    
    plt.figure(figsize=(12, 10))
    
    print(f"\nTraining classifiers with {class_weight} class weighting...")
    
    for name, clf in classifiers.items():
        print(f"\nTraining {name}...")
        
        try:
            # Cross-validation
            if validation_type == 'strict' and len(set(channel_ids)) >= 3:
                cv_scores_f1 = cross_val_score(clf, X_scaled, y, cv=cv, groups=groups, scoring='f1')
                cv_scores_bal = cross_val_score(clf, X_scaled, y, cv=cv, groups=groups, scoring='balanced_accuracy')
            else:
                cv_scores_f1 = cross_val_score(clf, X_scaled, y, cv=StratifiedKFold(n_splits=3), scoring='f1')
                cv_scores_bal = cross_val_score(clf, X_scaled, y, cv=StratifiedKFold(n_splits=3), scoring='balanced_accuracy')
            
            # Train and test
            clf.fit(X_train, y_train)
            y_pred = clf.predict(X_test)
            
            # Calculate metrics
            test_f1 = f1_score(y_test, y_pred)
            test_balanced_acc = balanced_accuracy_score(y_test, y_pred)
            test_precision = precision_score(y_test, y_pred)
            test_recall = recall_score(y_test, y_pred)
            test_accuracy = accuracy_score(y_test, y_pred)
            
            # ROC curve
            if hasattr(clf, "predict_proba"):
                y_proba = clf.predict_proba(X_test)[:, 1]
                fpr, tpr, _ = roc_curve(y_test, y_proba)
                roc_auc = auc(fpr, tpr)
                plt.plot(fpr, tpr, label=f'{name} (AUC = {roc_auc:.3f})')
            else:
                y_proba = np.zeros_like(y_pred)
                roc_auc = 0.5
            
            # Store results
            results[name] = {
                'cv_f1': cv_scores_f1.mean(),
                'cv_f1_std': cv_scores_f1.std(),
                'cv_balanced_acc': cv_scores_bal.mean(),
                'cv_balanced_acc_std': cv_scores_bal.std(),
                'test_f1': test_f1,
                'test_balanced_acc': test_balanced_acc,
                'test_precision': test_precision,
                'test_recall': test_recall,
                'test_accuracy': test_accuracy,
                'roc_auc': roc_auc,
                'y_pred': y_pred,
                'y_proba': y_proba,
                'confusion_matrix': confusion_matrix(y_test, y_pred),
                'classification_report': classification_report(y_test, y_pred, output_dict=True)
            }
            
            print(f"  CV F1: {cv_scores_f1.mean():.3f} ± {cv_scores_f1.std():.3f}")
            print(f"  CV Balanced Acc: {cv_scores_bal.mean():.3f} ± {cv_scores_bal.std():.3f}")
            print(f"  Test F1: {test_f1:.3f}")
            print(f"  Test Precision: {test_precision:.3f}")
            print(f"  Test Recall: {test_recall:.3f}")
            print(f"  ROC AUC: {roc_auc:.3f}")
            
        except Exception as e:
            print(f"  Error: {str(e)}")
            results[name] = {'error': str(e)}
    
    # Finalize ROC plot
    plt.plot([0, 1], [0, 1], 'k--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title(f'ROC Curves - {validation_type.upper()} Validation')
    plt.legend(loc="lower right")
    plt.grid(True, alpha=0.3)
    plt.savefig(os.path.join(plot_folder, f'roc_curves_{validation_type}.png'))
    plt.close()
    
    # ===============================
    # CHANNEL-LEVEL ANALYSIS
    # ===============================
    
    def aggregate_samples_to_channels(y_true, y_pred, y_proba, channel_ids, 
                                     aggregation_method='majority_vote'):
        """Aggregate sample-level predictions into one prediction per channel.

        Parameters
        ----------
        y_true, y_pred : array-like of non-negative int
            Per-sample true and predicted labels (1 = Grey, 0 = White here).
        y_proba : array-like of float
            Per-sample score for class 1.
        channel_ids : array-like
            Channel identifier of each sample.
        aggregation_method : str
            'majority_vote', 'average_probability', 'weighted_vote' or
            'confidence_threshold'.

        Returns
        -------
        dict
            Arrays ordered by sorted channel id: 'channel_ids', 'true_labels',
            'pred_labels', 'probabilities', 'confidence', 'sample_counts', plus
            'aggregation_method'.

        Raises
        ------
        ValueError
            If ``aggregation_method`` is not supported (previously this surfaced
            as an UnboundLocalError).
        """
        valid_methods = ('majority_vote', 'average_probability',
                         'weighted_vote', 'confidence_threshold')
        if aggregation_method not in valid_methods:
            raise ValueError(
                f"Unknown aggregation_method {aggregation_method!r}; "
                f"expected one of {valid_methods}"
            )

        # Plain arrays so boolean masks are purely positional (no pandas index semantics).
        y_true = np.asarray(y_true)
        y_pred = np.asarray(y_pred)
        y_proba = np.asarray(y_proba, dtype=float)
        channel_ids = np.asarray(channel_ids)
        unique_channels = np.unique(channel_ids)

        channel_true_labels = []
        channel_pred_labels = []
        channel_probabilities = []
        channel_confidence = []
        channel_sample_counts = []

        for ch_id in unique_channels:
            ch_mask = channel_ids == ch_id
            ch_true = y_true[ch_mask]
            ch_pred = y_pred[ch_mask]
            ch_proba = y_proba[ch_mask]

            # All samples of a channel come from one tissue class; majority is a safeguard.
            true_label = np.bincount(ch_true).argmax()

            if aggregation_method == 'majority_vote':
                pred_label = np.bincount(ch_pred).argmax()
                avg_proba = np.mean(ch_proba)
                confidence = np.sum(ch_pred == pred_label) / len(ch_pred)

            elif aggregation_method == 'average_probability':
                avg_proba = np.mean(ch_proba)
                pred_label = 1 if avg_proba > 0.5 else 0
                confidence = abs(avg_proba - 0.5) * 2

            elif aggregation_method == 'weighted_vote':
                # Each vote is weighted by its distance from the 0.5 decision point.
                weights = np.abs(ch_proba - 0.5) * 2
                weighted_votes_0 = np.sum(weights[ch_pred == 0])
                weighted_votes_1 = np.sum(weights[ch_pred == 1])
                total_weight = weighted_votes_0 + weighted_votes_1

                if total_weight == 0:
                    # Every probability is exactly 0.5: fall back to an unweighted
                    # majority vote instead of dividing 0/0 into NaN.
                    pred_label = np.bincount(ch_pred).argmax()
                    confidence = np.sum(ch_pred == pred_label) / len(ch_pred)
                elif weighted_votes_1 > weighted_votes_0:
                    pred_label = 1
                    confidence = weighted_votes_1 / total_weight
                else:
                    pred_label = 0
                    confidence = weighted_votes_0 / total_weight

                avg_proba = np.mean(ch_proba)

            else:  # 'confidence_threshold'
                confidence_threshold = 0.7
                high_conf_mask = (ch_proba > confidence_threshold) | (ch_proba < (1 - confidence_threshold))

                if np.any(high_conf_mask):
                    high_conf_pred = ch_pred[high_conf_mask]
                    high_conf_proba = ch_proba[high_conf_mask]
                    pred_label = np.bincount(high_conf_pred).argmax()
                    avg_proba = np.mean(high_conf_proba)
                    confidence = np.sum(high_conf_mask) / len(ch_pred)
                else:
                    avg_proba = np.mean(ch_proba)
                    pred_label = 1 if avg_proba > 0.5 else 0
                    confidence = 0.5

            channel_true_labels.append(true_label)
            channel_pred_labels.append(pred_label)
            channel_probabilities.append(avg_proba)
            channel_confidence.append(confidence)
            channel_sample_counts.append(len(ch_pred))

        return {
            'channel_ids': unique_channels,
            'true_labels': np.array(channel_true_labels),
            'pred_labels': np.array(channel_pred_labels),
            'probabilities': np.array(channel_probabilities),
            'confidence': np.array(channel_confidence),
            'sample_counts': np.array(channel_sample_counts),
            'aggregation_method': aggregation_method
        }
    
    def evaluate_channel_level_performance(channel_results, plot_folder=None, validation_type='strict'):
        """Score channel-level grey/white-matter predictions and optionally save plots.

        Args:
            channel_results: dict as built by ``aggregate_samples_to_channels``; this
                function reads its 'true_labels', 'pred_labels', 'probabilities',
                'confidence' and 'aggregation_method' entries. Label 1 is counted
                as grey matter and 0 as white matter.
            plot_folder: if truthy, a confusion-matrix heatmap and a
                probability/confidence figure are written into this folder.
            validation_type: tag used in plot titles and output file names.

        Returns:
            dict with 'accuracy', 'f1_score', 'precision', 'recall',
            'balanced_accuracy', 'roc_auc' (NaN when only one true class is
            present), a 2x2 'confusion_matrix' ordered [white, grey], and the
            channel counts 'n_channels', 'n_grey_channels', 'n_white_channels'.
        """
        # Local import keeps this helper self-contained; several of these scorers
        # are not part of the notebook's top-level sklearn.metrics import.
        from sklearn.metrics import (
            accuracy_score, balanced_accuracy_score, confusion_matrix,
            f1_score, precision_score, recall_score, roc_auc_score,
        )

        y_true_ch = np.asarray(channel_results['true_labels'])
        y_pred_ch = np.asarray(channel_results['pred_labels'])
        y_proba_ch = np.asarray(channel_results['probabilities'])
        method = channel_results['aggregation_method']

        # Force a full 2x2 matrix: with only a handful of held-out channels they may
        # all share one label, and confusion_matrix would then return a 1x1 array,
        # breaking the [i, j] indexing below.
        cm = confusion_matrix(y_true_ch, y_pred_ch, labels=[0, 1])

        # ROC AUC is undefined (sklearn raises ValueError) for a single true class.
        if np.unique(y_true_ch).size > 1:
            roc_auc = roc_auc_score(y_true_ch, y_proba_ch)
        else:
            roc_auc = float('nan')

        metrics = {
            'accuracy': accuracy_score(y_true_ch, y_pred_ch),
            # zero_division=0 matches sklearn's default result, minus the warning.
            'f1_score': f1_score(y_true_ch, y_pred_ch, zero_division=0),
            'precision': precision_score(y_true_ch, y_pred_ch, zero_division=0),
            'recall': recall_score(y_true_ch, y_pred_ch, zero_division=0),
            'balanced_accuracy': balanced_accuracy_score(y_true_ch, y_pred_ch),
            'roc_auc': roc_auc,
            'confusion_matrix': cm,
            'n_channels': len(y_true_ch),
            'n_grey_channels': np.sum(y_true_ch == 1),
            'n_white_channels': np.sum(y_true_ch == 0)
        }

        print(f"\n{'='*50}")
        print(f"CHANNEL-LEVEL PERFORMANCE ({method})")
        print(f"{'='*50}")
        print(f"Total channels: {metrics['n_channels']}")
        print(f"Grey matter channels: {metrics['n_grey_channels']}")
        print(f"White matter channels: {metrics['n_white_channels']}")
        print(f"Accuracy: {metrics['accuracy']:.3f}")
        print(f"F1 Score: {metrics['f1_score']:.3f}")
        print(f"Precision: {metrics['precision']:.3f}")
        print(f"Recall: {metrics['recall']:.3f}")
        print(f"Balanced Accuracy: {metrics['balanced_accuracy']:.3f}")
        print(f"ROC AUC: {metrics['roc_auc']:.3f}")
        print("Confusion Matrix:")
        print(f"  [[TN={cm[0, 0]}, FP={cm[0, 1]}],")
        print(f"   [FN={cm[1, 0]}, TP={cm[1, 1]}]]")

        if plot_folder:
            class_names = ['White Matter', 'Grey Matter']

            # Confusion matrix heatmap
            fig, ax = plt.subplots(figsize=(8, 6))
            sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                        xticklabels=class_names, yticklabels=class_names, ax=ax)
            ax.set_title(f'Channel-Level Confusion Matrix\n({method}, {validation_type})')
            ax.set_ylabel('True Label')
            ax.set_xlabel('Predicted Label')
            fig.tight_layout()
            fig.savefig(os.path.join(plot_folder, f'channel_confusion_matrix_{validation_type}_{method}.png'))
            plt.close(fig)

            fig, (ax_hist, ax_conf) = plt.subplots(1, 2, figsize=(12, 5))

            # Probability distribution per true class
            ax_hist.hist(y_proba_ch[y_true_ch == 0], bins=20, alpha=0.7,
                         label='White Matter Channels', color='red')
            ax_hist.hist(y_proba_ch[y_true_ch == 1], bins=20, alpha=0.7,
                         label='Grey Matter Channels', color='blue')
            ax_hist.axvline(x=0.5, color='black', linestyle='--', label='Decision Threshold')
            ax_hist.set_xlabel('Predicted Probability (Grey Matter)')
            ax_hist.set_ylabel('Number of Channels')
            ax_hist.set_title('Channel-Level Probability Distribution')
            ax_hist.legend()
            ax_hist.grid(True, alpha=0.3)

            # Confidence vs. correctness
            confidence = np.asarray(channel_results['confidence'])
            correct_pred = y_true_ch == y_pred_ch
            ax_conf.scatter(confidence[correct_pred], y_proba_ch[correct_pred],
                            alpha=0.6, c='green', label='Correct Predictions', s=50)
            ax_conf.scatter(confidence[~correct_pred], y_proba_ch[~correct_pred],
                            alpha=0.6, c='red', label='Incorrect Predictions', s=50)
            ax_conf.set_xlabel('Confidence Score')
            ax_conf.set_ylabel('Predicted Probability')
            ax_conf.set_title('Confidence vs Accuracy')
            ax_conf.legend()
            ax_conf.grid(True, alpha=0.3)

            fig.tight_layout()
            fig.savefig(os.path.join(plot_folder, f'channel_analysis_{validation_type}_{method}.png'))
            plt.close(fig)

        return metrics
    
    def compare_aggregation_methods(y_true, y_pred, y_proba, channel_ids, plot_folder=None, validation_type='strict'):
        """Run every channel-level aggregation strategy and pick the best by F1.

        Each strategy is applied to the same sample-level predictions through
        ``aggregate_samples_to_channels`` and scored with
        ``evaluate_channel_level_performance``. That scorer prints a report and,
        when ``plot_folder`` is set, saves per-method figures. A summary table is
        printed and, with ``plot_folder``, saved as CSV plus a grouped bar chart.

        Returns:
            (comparison_results, best_method): ``comparison_results`` maps each
            method name to {'channel_results': ..., 'metrics': ...}; ``best_method``
            is the method with the highest F1 after rounding to 3 decimals (the
            first listed method wins a tie).
        """
        methods = ['majority_vote', 'average_probability', 'weighted_vote', 'confidence_threshold']
        banner = '=' * 60

        print(f"\n{banner}")
        print("COMPARING CHANNEL-LEVEL AGGREGATION METHODS")
        print(banner)

        comparison_results = {}
        for method in methods:
            print(f"\n--- {method.upper()} ---")
            ch_res = aggregate_samples_to_channels(
                y_true, y_pred, y_proba, channel_ids, aggregation_method=method
            )
            ch_metrics = evaluate_channel_level_performance(ch_res, plot_folder, validation_type)
            comparison_results[method] = {'channel_results': ch_res, 'metrics': ch_metrics}

        # Table row label -> key in the metrics dict returned by the scorer.
        table_rows = {
            'F1 Score': 'f1_score',
            'Accuracy': 'accuracy',
            'Balanced Acc': 'balanced_accuracy',
            'Precision': 'precision',
            'Recall': 'recall',
            'ROC AUC': 'roc_auc',
        }
        comparison_df = pd.DataFrame({
            name: {label: entry['metrics'][key] for label, key in table_rows.items()}
            for name, entry in comparison_results.items()
        }).round(3)

        print(f"\n{banner}")
        print("AGGREGATION METHODS COMPARISON")
        print(banner)
        print(comparison_df.to_string())

        if plot_folder:
            comparison_df.to_csv(os.path.join(plot_folder, f'channel_aggregation_comparison_{validation_type}.csv'))

            plotted_metrics = ['F1 Score', 'Balanced Acc', 'Precision', 'Recall', 'ROC AUC']
            bar_width = 0.15
            positions = np.arange(len(methods))

            plt.figure(figsize=(12, 8))
            for offset, metric in enumerate(plotted_metrics):
                heights = [comparison_df.at[metric, name] for name in methods]
                plt.bar(positions + offset * bar_width, heights, bar_width, label=metric, alpha=0.8)

            tick_labels = [name.replace('_', ' ').title() for name in methods]
            plt.xlabel('Aggregation Methods')
            plt.ylabel('Performance Score')
            plt.title(f'Channel-Level Aggregation Methods Comparison ({validation_type})')
            # Centre each tick under the middle (3rd) of the five bars.
            plt.xticks(positions + bar_width * 2, tick_labels, rotation=15)
            plt.legend()
            plt.grid(True, alpha=0.3)
            plt.tight_layout()
            plt.savefig(os.path.join(plot_folder, f'aggregation_methods_comparison_{validation_type}.png'))
            plt.close()

        f1_row = comparison_df.loc['F1 Score']
        best_method = f1_row.idxmax()
        best_f1 = f1_row[best_method]

        print(f"\n🏆 RECOMMENDED METHOD: {best_method}")
        print(f"   Best F1 Score: {best_f1:.3f}")

        return comparison_results, best_method
    
    # ===============================
    # EXECUTE CHANNEL-LEVEL ANALYSIS
    # ===============================
    
    # Find best classifier
    valid_results = {name: result for name, result in results.items() if 'error' not in result}
    if not valid_results:
        return {'error': 'No valid classifier results'}
    
    best_classifier_name = max(valid_results.items(), key=lambda x: x[1]['test_f1'])[0]
    best_result = valid_results[best_classifier_name]
    
    print(f"\n{'='*60}")
    print(f"CHANNEL-LEVEL ANALYSIS USING BEST CLASSIFIER: {best_classifier_name}")
    print(f"Best Sample-Level F1: {best_result['test_f1']:.3f}")
    print(f"{'='*60}")
    
    # Get predictions from best classifier
    y_pred_best = best_result['y_pred']
    y_proba_best = best_result['y_proba']
    
    # Compare aggregation methods
    channel_comparison, best_aggregation_method = compare_aggregation_methods(
        y_test, y_pred_best, y_proba_best, test_channel_ids, 
        plot_folder=plot_folder, validation_type=validation_type
    )
    
    # Final channel-level results with best method
    final_channel_results = aggregate_samples_to_channels(
        y_test, y_pred_best, y_proba_best, test_channel_ids, 
        aggregation_method=best_aggregation_method
    )
    
    final_channel_metrics = evaluate_channel_level_performance(
        final_channel_results, plot_folder, validation_type
    )
    
    # ===============================
    # CREATE PERFORMANCE SUMMARY
    # ===============================
    
    # Sample-level summary
    sample_performance = []
    for name, result in valid_results.items():
        sample_performance.append({
            'Classifier': name,
            'CV F1': f"{result['cv_f1']:.3f} ± {result['cv_f1_std']:.3f}",
            'Test F1': result['test_f1'],
            'Test Precision': result['test_precision'],
            'Test Recall': result['test_recall'],
            'Test Balanced Acc': result['test_balanced_acc'],
            'ROC AUC': result['roc_auc']
        })
    
    sample_performance_df = pd.DataFrame(sample_performance)
    sample_performance_df.to_csv(os.path.join(plot_folder, f'sample_performance_{validation_type}.csv'), index=False)
    
    # Final summary
    print(f"\n{'='*80}")
    print(f"FINAL SUMMARY - {validation_type.upper()} VALIDATION")
    print(f"{'='*80}")
    print(f"Perfect separation features: {len(perfect_features)}")
    print(f"Data shape: {X.shape}")
    print(f"Sample-to-feature ratio: {len(X) / len(common_features):.1f}:1")
    print(f"Class distribution: Grey={np.sum(y==1)}, White={np.sum(y==0)}")
    
    print(f"\nSAMPLE-LEVEL PERFORMANCE:")
    print(f"Best classifier: {best_classifier_name}")
    print(f"Best F1 score: {best_result['test_f1']:.3f}")
    print(f"Best ROC AUC: {best_result['roc_auc']:.3f}")
    
    print(f"\nCHANNEL-LEVEL PERFORMANCE:")
    print(f"Best aggregation method: {best_aggregation_method}")
    print(f"Channel-level F1 score: {final_channel_metrics['f1_score']:.3f}")
    print(f"Channel-level ROC AUC: {final_channel_metrics['roc_auc']:.3f}")
    print(f"Total channels analyzed: {final_channel_metrics['n_channels']}")
    
    # Performance improvement
    improvement = final_channel_metrics['f1_score'] - best_result['test_f1']
    print(f"F1 improvement (channel vs sample): {improvement:+.3f}")
    
    if best_result['test_f1'] > 0.95:
        print("⚠️  WARNING: Sample-level accuracy > 95% may indicate data leakage!")
    elif best_result['test_f1'] > 0.90:
        print("⚠️  CAUTION: Sample-level accuracy > 90% - verify results")
    else:
        print("✅ Sample-level accuracy seems reasonable")
    
    if final_channel_metrics['f1_score'] > 0.8:
        print("🎉 Excellent channel-level performance!")
    elif final_channel_metrics['f1_score'] > 0.6:
        print("✅ Good channel-level performance")
    else:
        print("⚠️  Channel-level performance could be improved")
    
    # ===============================
    # RETURN COMPREHENSIVE RESULTS
    # ===============================
    
    return {
        'validation_type': validation_type,
        'balance_method': balance_method,
        'class_weight': class_weight,
        
        # Data info
        'n_samples': len(X),
        'n_features': len(common_features),
        'sample_feature_ratio': len(X) / len(common_features),
        'perfect_features': perfect_features,
        'feature_names': common_features,
        
        # Sample-level results
        'sample_level': {
            'classifier_results': valid_results,
            'performance_summary': sample_performance_df,
            'best_classifier': best_classifier_name,
            'best_metrics': {
                'f1_score': best_result['test_f1'],
                'balanced_accuracy': best_result['test_balanced_acc'],
                'roc_auc': best_result['roc_auc'],
                'precision': best_result['test_precision'],
                'recall': best_result['test_recall']
            }
        },
        
        # Channel-level results
        'channel_level': {
            'comparison_results': channel_comparison,
            'best_aggregation_method': best_aggregation_method,
            'final_results': final_channel_results,
            'final_metrics': final_channel_metrics,
            'improvement_over_sample': improvement
        },
        
        # Files saved
        'output_folder': plot_folder,
        'files_generated': [
            f'roc_curves_{validation_type}.png',
            f'sample_performance_{validation_type}.csv',
            f'channel_aggregation_comparison_{validation_type}.csv',
            f'aggregation_methods_comparison_{validation_type}.png',
            f'channel_confusion_matrix_{validation_type}_{best_aggregation_method}.png',
            f'channel_analysis_{validation_type}_{best_aggregation_method}.png'
        ]
    }

In [None]:
# Strict (channel-wise) validation run; outputs are written under the configured
# RESULT_FOLDER rather than a hard-coded "result/" prefix.
results_3 = improved_extract_and_classify_features_complete(
    seizure_data_grey_new, seizure_data_white_new,
    plot_folder=os.path.join(RESULT_FOLDER, seizureID, 'wg_classification_strict_ictal_data'),
    fs=p66_data.samplingRate,
    use_windowing=True,  # Use windowing for better sample size
    window_overlap=0.5,
    min_window_length=1000,
    validation_type='strict',  # Strict channel-wise validation
)