In [None]:
import numpy as np
# At the start of your notebook
from IPython.display import clear_output
import gc

# After heavy computations
clear_output(wait=True)
gc.collect()
import pickle
import os
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from itertools import combinations
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split, cross_val_score, StratifiedKFold
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, roc_curve, auc
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from scipy import stats
from scipy.signal import welch

# Notebook-wide configuration.
# NOTE(review): none of these constants are referenced in the cells visible in this
# part of the notebook — confirm they are still used further down before relying on them.
RESULT_FOLDER = "result"  # presumably the root folder for generated outputs — confirm
MODEL_FOLDER = "model"  # presumably the folder holding trained models — confirm
model_names = ['Wavenet']  # other names listed by the author: 'CNN1D', 'Wavenet', 'S4', 'Resnet'
# Value grids for a batch analysis meant to find the best hyperparameters
seizures = [1, 2, 3, 5, 7]
thresholds = [0.8]
smooth_windows = [80]

In [None]:
# Get the data from one patient.
# The context manager closes the file handle as soon as unpickling finishes;
# a bare open() inside pickle.load() leaves the handle open until it is
# garbage-collected.
# NOTE: pickle can execute arbitrary code while loading — only load trusted files.
with open('data/P66/seizure_All_combined.pkl', "rb") as pkl_file:
    p66_data = pickle.load(pkl_file)

In [None]:
def extract_grey_matter_channels(matter: pd.DataFrame):
    """Return the channel numbers of the rows whose matter type is 'G' or 'A'.

    Parameters
    ----------
    matter : pd.DataFrame
        Table with at least the columns 'MatterType' and 'ChannelNumber'.

    Returns
    -------
    np.ndarray
        The 'ChannelNumber' values of the selected rows, in table order.
    """
    is_selected = matter['MatterType'].isin(['G', 'A'])
    return matter.loc[is_selected, 'ChannelNumber'].values

# Attach the matter table (read from CSV) to the loaded patient object.
p66_data.matter = pd.read_csv('data/P66/matter.csv')
all_channels = np.arange(0, p66_data.channelNumber)
# Subtract 1 so the values can index axis 1 below and be compared with all_channels
# (presumably ChannelNumber in the CSV is 1-based — confirm).
grey_channel = extract_grey_matter_channels(p66_data.matter) - 1
# Every channel index not selected above is treated as white matter.
white_channel = np.setdiff1d(all_channels, grey_channel)

# Split the post-ictal data by channel group; channels are indexed on axis 1.
seizure_data_grey = p66_data.postictal_transformed[:,grey_channel,:,:]
seizure_data_white = p66_data.postictal_transformed[:,white_channel,:,:]

In [None]:
def visualize_features(seizure_data_grey, seizure_data_white, feature_names, time_aggregation='mean', plot_folder='result/wgfeatures'):
    """
    Visualize features from grey and white matter data to find best separating features,
    with outlier removal and proper time-axis aggregation.

    All figures and a CSV summary are written to ``plot_folder``; nothing is shown inline.

    Parameters:
    -----------
    seizure_data_grey : np.ndarray
        Grey matter data with shape [samples, channels, time, features]
    seizure_data_white : np.ndarray
        White matter data with shape [samples, channels, time, features]
    feature_names : list
        List of feature names. If its length does not match the feature axis,
        generic names "Feature 1..N" are used instead (a warning is printed).
    time_aggregation : str, optional
        Method for aggregating across time axis ('mean' or 'median'), default is 'mean'
    plot_folder : str, optional
        Folder path to save all plots, default is 'result/wgfeatures'

    Returns:
    --------
    dict
        Dictionary containing the best features and their metrics:
        'best_single_feature', 'best_single_feature_auc', 'best_feature_pair'
        (None when fewer than two features are available), 'best_feature_pair_auc',
        'all_features_auc', 'feature_importance', 'auc_values', 'pair_auc_matrix',
        'grey_data_clean' and 'white_data_clean' (both include a 'Matter' column).

    Raises:
    -------
    ValueError
        If time_aggregation is neither 'mean' nor 'median'.

    Notes:
    ------
    Single-feature AUCs use the raw feature value as the score with Grey = 1, so a
    feature that is systematically *lower* in grey matter yields AUC < 0.5 and is
    ranked low even though it separates the classes well.
    NOTE(review): consider ranking by abs(AUC - 0.5) if that is not intended.
    """
    # Create the plot folder if it doesn't exist
    os.makedirs(plot_folder, exist_ok=True)

    # Get the number of features
    n_features = seizure_data_grey.shape[3]

    # Ensure feature_names has right length
    if len(feature_names) != n_features:
        print(f"Warning: Feature names length ({len(feature_names)}) doesn't match number of features ({n_features})")
        feature_names = [f"Feature {i+1}" for i in range(n_features)]

    # 1. Aggregate data across time dimension
    print(f"Aggregating data across time using {time_aggregation}...")

    if time_aggregation == 'mean':
        # Calculate mean across time dimension
        grey_agg = np.mean(seizure_data_grey, axis=2)  # [samples, channels, features]
        white_agg = np.mean(seizure_data_white, axis=2)
    elif time_aggregation == 'median':
        # Calculate median across time dimension
        grey_agg = np.median(seizure_data_grey, axis=2)
        white_agg = np.median(seizure_data_white, axis=2)
    else:
        raise ValueError("time_aggregation must be either 'mean' or 'median'")

    # Reshape to 2D: [samples*channels, features]
    grey_reshaped = grey_agg.reshape(-1, n_features)
    white_reshaped = white_agg.reshape(-1, n_features)

    print(f"Grey matter data shape after time aggregation: {grey_reshaped.shape}")
    print(f"White matter data shape after time aggregation: {white_reshaped.shape}")

    # Create DataFrames
    df_grey = pd.DataFrame(grey_reshaped, columns=feature_names)
    df_white = pd.DataFrame(white_reshaped, columns=feature_names)

    # 2. Remove outliers from each feature (using IQR method)
    print("Removing outliers...")

    def remove_outliers(df):
        """Drop rows outside [Q1 - 1.5*IQR, Q3 + 1.5*IQR], one feature after another.

        The filtering is sequential: quartiles for a feature are computed on the
        rows that survived the previous features.
        """
        df_clean = df.copy()
        for feature in feature_names:
            Q1 = df_clean[feature].quantile(0.25)
            Q3 = df_clean[feature].quantile(0.75)
            IQR = Q3 - Q1
            lower_bound = Q1 - 1.5 * IQR
            upper_bound = Q3 + 1.5 * IQR

            # Count outliers before removal
            outliers = df_clean[(df_clean[feature] < lower_bound) | (df_clean[feature] > upper_bound)]
            outlier_count = len(outliers)

            # Remove outliers
            df_clean = df_clean[(df_clean[feature] >= lower_bound) & (df_clean[feature] <= upper_bound)]

            if outlier_count > 0:
                print(f"  Removed {outlier_count} outliers from {feature} ({outlier_count/len(df)*100:.2f}%)")

        # Return an independent frame so the caller can add columns without
        # pandas raising SettingWithCopyWarning on a filtered view.
        return df_clean.copy()

    # Apply outlier removal
    df_grey_clean = remove_outliers(df_grey)
    df_white_clean = remove_outliers(df_white)

    print(f"After outlier removal: Grey matter: {len(df_grey_clean)} samples, White matter: {len(df_white_clean)} samples")

    # Add matter labels for combined dataframe
    df_grey_clean['Matter'] = 'Grey'
    df_white_clean['Matter'] = 'White'
    df_clean = pd.concat([df_grey_clean, df_white_clean])

    # 3. Create boxplots for each feature
    nrows = (n_features + 2) // 3  # Adjust rows for better layout
    fig_width = min(20, n_features * 6)
    plt.figure(figsize=(fig_width, nrows * 5))

    for i, feature in enumerate(feature_names):
        plt.subplot(nrows, 3, i+1)
        sns.boxplot(x='Matter', y=feature, data=df_clean)
        plt.title(feature)
        plt.xticks(rotation=45)

    # Lay out once after all panels exist instead of on every iteration
    plt.tight_layout()
    plt.savefig(os.path.join(plot_folder, 'feature_boxplots.png'))
    plt.close()

    # 4. Feature distributions
    plt.figure(figsize=(fig_width, nrows * 5))

    for i, feature in enumerate(feature_names):
        plt.subplot(nrows, 3, i+1)
        sns.histplot(df_grey_clean[feature], color='blue', label='Grey', alpha=0.5, kde=True)
        sns.histplot(df_white_clean[feature], color='red', label='White', alpha=0.5, kde=True)
        plt.title(f'{feature} Distribution')
        plt.legend()

    plt.tight_layout()
    plt.savefig(os.path.join(plot_folder, 'feature_distributions.png'))
    plt.close()

    # 5. Calculate ROC curves and AUC for each feature
    def calculate_roc(feature_grey, feature_white):
        """ROC of the raw feature value used as a score (Grey = 1, White = 0)."""
        X = np.concatenate([feature_grey, feature_white])
        # Grey = 1, White = 0
        y = np.concatenate([np.ones(len(feature_grey)), np.zeros(len(feature_white))])
        fpr, tpr, _ = roc_curve(y, X)
        roc_auc = auc(fpr, tpr)
        return fpr, tpr, roc_auc

    plt.figure(figsize=(12, 10))
    auc_values = []

    for i, feature in enumerate(feature_names):
        grey_feature = df_grey_clean[feature].values
        white_feature = df_white_clean[feature].values
        fpr, tpr, roc_auc = calculate_roc(grey_feature, white_feature)
        plt.plot(fpr, tpr, label=f'{feature} (AUC = {roc_auc:.3f})')
        auc_values.append(roc_auc)

    plt.plot([0, 1], [0, 1], 'k--')  # Random classifier line
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC Curves for Individual Features')
    plt.legend(loc="lower right", fontsize=8)
    plt.savefig(os.path.join(plot_folder, 'roc_curves.png'))
    plt.close()

    # Identify the best feature based on AUC
    best_feature_idx = np.argmax(auc_values)
    best_feature_name = feature_names[best_feature_idx]
    print(f"Best single feature: {best_feature_name} with AUC = {auc_values[best_feature_idx]:.3f}")

    # 6. Create scatter plots for feature pairs
    # Create pair-wise AUC matrix (pairs that are not evaluated stay at 0)
    pair_auc_matrix = np.zeros((n_features, n_features))
    best_pair_auc = 0
    best_pair = None

    # Create a subfolder for scatter plots
    scatter_folder = os.path.join(plot_folder, 'scatter_plots')
    os.makedirs(scatter_folder, exist_ok=True)

    # Only do top feature combinations if there are many features
    feature_indices = range(n_features)
    if n_features > 10:  # If more than 10 features, limit pairs to top 5 individual features
        top_indices = np.argsort(auc_values)[-5:]
        print("Too many features, limiting pairs to combinations of top 5 individual features")
        feature_pairs = list(combinations(top_indices, 2))
    else:
        feature_pairs = list(combinations(feature_indices, 2))

    for i, j in feature_pairs:
        feature_i = feature_names[i]
        feature_j = feature_names[j]

        # Create scatter plot
        plt.figure(figsize=(8, 6))
        plt.scatter(df_grey_clean[feature_i], df_grey_clean[feature_j],
                  alpha=0.5, label='Grey Matter', color='blue')
        plt.scatter(df_white_clean[feature_i], df_white_clean[feature_j],
                  alpha=0.5, label='White Matter', color='red')
        plt.xlabel(feature_i)
        plt.ylabel(feature_j)
        plt.title(f'{feature_i} vs {feature_j}')
        plt.legend()
        plt.savefig(os.path.join(scatter_folder, f'feature_scatter_{i+1}_{j+1}.png'))
        plt.close()

        # Calculate discrimination power using LDA
        X_grey = df_grey_clean[[feature_i, feature_j]].values
        X_white = df_white_clean[[feature_i, feature_j]].values
        X = np.vstack([X_grey, X_white])
        y = np.concatenate([np.ones(len(X_grey)), np.zeros(len(X_white))])

        lda = LDA(n_components=1)
        X_lda = lda.fit_transform(X, y)

        # Calculate ROC for the LDA projection (fit and scored on the same rows)
        fpr, tpr, _ = roc_curve(y, X_lda.ravel())
        roc_auc = auc(fpr, tpr)

        pair_auc_matrix[i, j] = roc_auc
        pair_auc_matrix[j, i] = roc_auc  # Mirror the matrix

        if roc_auc > best_pair_auc:
            best_pair_auc = roc_auc
            best_pair = (feature_i, feature_j)

    # Fill diagonal with single feature AUCs
    for i in range(n_features):
        pair_auc_matrix[i, i] = auc_values[i]

    # With fewer than two features no pair exists and best_pair stays None;
    # indexing it unconditionally would raise TypeError.
    if best_pair is not None:
        print(f"Best feature pair: {best_pair[0]} and {best_pair[1]} with AUC = {best_pair_auc:.3f}")
    else:
        print("No feature pairs were evaluated (fewer than two features)")

    # 7. Feature pair heatmap
    plt.figure(figsize=(12, 10))
    # Create shorter feature names for the heatmap if needed
    short_names = [name[:10] + '...' if len(name) > 10 else name for name in feature_names]

    sns.heatmap(pair_auc_matrix, annot=True, cmap='YlGnBu', fmt='.3f',
                xticklabels=short_names, yticklabels=short_names)
    plt.title('AUC Values for Feature Pairs (diagonal = single feature)')
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.savefig(os.path.join(plot_folder, 'feature_pair_auc_heatmap.png'))
    plt.close()

    # 8. LDA for all features
    X_grey_all = df_grey_clean[feature_names].values
    X_white_all = df_white_clean[feature_names].values
    X_all = np.vstack([X_grey_all, X_white_all])
    y_all = np.concatenate([np.ones(len(X_grey_all)), np.zeros(len(X_white_all))])

    lda_all = LDA(n_components=1)
    X_lda_all = lda_all.fit_transform(X_all, y_all)
    # Flatten the (n, 1) projection: seaborn treats a 2-D array as wide-form
    # data (one hue level per column), in which case the explicit color/label
    # arguments below are not applied as intended.
    lda_scores = X_lda_all.ravel()

    # Plot LDA projection
    plt.figure(figsize=(10, 6))
    sns.histplot(lda_scores[y_all==1], color='blue', label='Grey Matter', alpha=0.5, kde=True)
    sns.histplot(lda_scores[y_all==0], color='red', label='White Matter', alpha=0.5, kde=True)
    plt.title('LDA Projection Using All Features')
    plt.legend()
    plt.savefig(os.path.join(plot_folder, 'lda_all_features.png'))
    plt.close()

    # Calculate ROC for all-feature LDA projection
    fpr_all, tpr_all, _ = roc_curve(y_all, lda_scores)
    roc_auc_all = auc(fpr_all, tpr_all)
    print(f"AUC using all features with LDA: {roc_auc_all:.3f}")

    # 9. Feature importance from LDA (absolute coefficient; features are not
    # standardized here, so magnitudes also reflect each feature's scale)
    feature_importance = np.abs(lda_all.coef_[0])
    importance_df = pd.DataFrame({'Feature': feature_names, 'Importance': feature_importance})
    importance_df = importance_df.sort_values('Importance', ascending=False)

    plt.figure(figsize=(12, 6))
    sns.barplot(x='Feature', y='Importance', data=importance_df)
    plt.title('Feature Importance from LDA')
    plt.xticks(rotation=45, ha='right')
    plt.tight_layout()
    plt.savefig(os.path.join(plot_folder, 'feature_importance.png'))
    plt.close()

    print("Top features by LDA importance:")
    print(importance_df)

    # Save results to CSV
    df_results = pd.DataFrame({
        'Feature': feature_names,
        'AUC': auc_values,
        'LDA_Importance': feature_importance
    })
    df_results = df_results.sort_values('AUC', ascending=False)
    df_results.to_csv(os.path.join(plot_folder, 'feature_results.csv'), index=False)

    # Return results
    return {
        'best_single_feature': best_feature_name,
        'best_single_feature_auc': auc_values[best_feature_idx],
        'best_feature_pair': best_pair,
        'best_feature_pair_auc': best_pair_auc,
        'all_features_auc': roc_auc_all,
        'feature_importance': importance_df,
        'auc_values': dict(zip(feature_names, auc_values)),
        'pair_auc_matrix': pd.DataFrame(pair_auc_matrix, index=feature_names, columns=feature_names),
        'grey_data_clean': df_grey_clean,
        'white_data_clean': df_white_clean
    }

In [None]:
# Compare grey vs. white matter on the post-ictal features; figures and the CSV summary go to plot_folder.
results = visualize_features(seizure_data_grey, seizure_data_white, p66_data.feature_names,plot_folder='result/wgfeatures/postictal')

In [None]:
from datasetConstruct import EDFData
# Load the raw recording for a single seizure. The context manager closes the
# file handle once unpickling is done (a bare open() leaves it open).
# NOTE(review): the variable is named p66_raw but the file lives under data/P65 —
# confirm which patient is intended.
# NOTE: pickle can execute arbitrary code while loading — only load trusted files.
with open('data/P65/seizure_SZ1.pkl', "rb") as pkl_file:
    p66_raw = pickle.load(pkl_file)

In [None]:
# Split the interictal recording into the two channel groups (channels are indexed on axis 1).
# NOTE(review): grey_channel / white_channel were derived from data/P66/matter.csv,
# whereas p66_raw was loaded from data/P65/seizure_SZ1.pkl — confirm that this
# channel map applies to this recording.
raw_grey = p66_raw.interictal[:, grey_channel]
raw_white = p66_raw.interictal[:, white_channel]

In [None]:
def analyze_raw_timeseries(grey_matter_data, white_matter_data, channel_names=None, 
                          plot_folder='result/raw_wg_comparison', fs=250):
    """
    Analyze and visualize differences between grey and white matter using full time series data,
    with special focus on frequency band differentiation.

    Per channel, the function computes summary statistics, Welch band powers and
    ratios of adjacent band powers, scores each of these features with a ROC/AUC
    (Grey = 1, White = 0) and writes several figures plus 'feature_comparison.csv'
    into ``plot_folder``.
    
    Parameters:
    -----------
    grey_matter_data : np.ndarray
        Grey matter data with shape [time, channels]
    white_matter_data : np.ndarray
        White matter data with shape [time, channels]
    channel_names : list, optional
        List of channel names, default is None (will use indices).
        Names are looked up by column position *within* each input array
        (the same list is used for both groups), so it is only used for
        the titles of the sample time-series figure.
    plot_folder : str, optional
        Folder path to save all plots
    fs : int, optional
        Sampling frequency in Hz, default is 250
    
    Returns:
    --------
    dict
        Dictionary containing analysis results with keys 'feature_comparison'
        (DataFrame sorted by AUC, descending), 'roc_results', 'grey_band_powers',
        'white_band_powers', 'grey_band_ratios' and 'white_band_ratios'.

    Notes:
    ------
    The raw feature value is used as the ROC score, so a feature that is lower
    in grey matter gets AUC < 0.5 and sorts towards the bottom of the ranking.
    The channels shown in the sample figure are drawn with np.random.choice
    without a seed set in this function, so they can differ between runs.
    """
    # Create the plot folder if it doesn't exist
    os.makedirs(plot_folder, exist_ok=True)
    
    # Get dimensions
    grey_time, grey_channels = grey_matter_data.shape
    white_time, white_channels = white_matter_data.shape
    
    print(f"Grey matter data shape: [{grey_time}, {grey_channels}]")
    print(f"White matter data shape: [{white_time}, {white_channels}]")
    
    # Create channel names if not provided
    if channel_names is None:
        channel_names = [f"Channel {i+1}" for i in range(max(grey_channels, white_channels))]
    
    # 1. Basic time series visualization
    # Sample a shorter segment for visualization (e.g., 10 seconds)
    sample_duration = min(10 * fs, min(grey_time, white_time))
    
    plt.figure(figsize=(15, 10))
    # Plot a few random channels from each group
    channels_to_plot = min(5, min(grey_channels, white_channels))
    grey_indices = np.random.choice(grey_channels, channels_to_plot, replace=False)
    white_indices = np.random.choice(white_channels, channels_to_plot, replace=False)
    
    # NOTE(review): grey_indices / white_indices are column positions inside the
    # already-split arrays; if channel_names lists all electrodes in original
    # order, the titles below will not name the plotted channel — confirm.
    for i in range(channels_to_plot):
        plt.subplot(channels_to_plot, 2, i*2+1)
        plt.plot(grey_matter_data[:sample_duration, grey_indices[i]])
        plt.title(f"Grey Matter: {channel_names[grey_indices[i]]}")
        
        plt.subplot(channels_to_plot, 2, i*2+2)
        plt.plot(white_matter_data[:sample_duration, white_indices[i]])
        plt.title(f"White Matter: {channel_names[white_indices[i]]}")
    
    plt.tight_layout()
    plt.savefig(os.path.join(plot_folder, 'raw_time_series_samples.png'))
    plt.close()
    
    # 2. Calculate summary statistics for each channel
    # (every entry is a 1-D array with one value per channel, reduced over time)
    grey_stats = {
        'mean': np.mean(grey_matter_data, axis=0),
        'std': np.std(grey_matter_data, axis=0),
        'median': np.median(grey_matter_data, axis=0),
        'iqr': np.percentile(grey_matter_data, 75, axis=0) - np.percentile(grey_matter_data, 25, axis=0),
        'skew': stats.skew(grey_matter_data, axis=0),
        'kurtosis': stats.kurtosis(grey_matter_data, axis=0),
        'range': np.max(grey_matter_data, axis=0) - np.min(grey_matter_data, axis=0),
        'rms': np.sqrt(np.mean(grey_matter_data**2, axis=0))
    }
    
    white_stats = {
        'mean': np.mean(white_matter_data, axis=0),
        'std': np.std(white_matter_data, axis=0),
        'median': np.median(white_matter_data, axis=0),
        'iqr': np.percentile(white_matter_data, 75, axis=0) - np.percentile(white_matter_data, 25, axis=0),
        'skew': stats.skew(white_matter_data, axis=0),
        'kurtosis': stats.kurtosis(white_matter_data, axis=0),
        'range': np.max(white_matter_data, axis=0) - np.min(white_matter_data, axis=0),
        'rms': np.sqrt(np.mean(white_matter_data**2, axis=0))
    }
    
    # 3. Define frequency bands as (low, high) edges in Hz
    freq_bands = {
        'delta': (0.5, 4),
        'theta': (4, 8),
        'alpha': (8, 13),
        'beta': (13, 30),
        'low_gamma': (30, 60),
        'high_gamma': (60, 100),
        'ripple': (100, 200)
    }
    
    # 4. Extract band power for each channel
    def calculate_band_powers(data, fs, bands):
        """Return {band: array of per-channel mean Welch PSD inside the band}.

        Band edges are inclusive on both sides, so a PSD bin lying exactly on a
        shared edge contributes to both neighbouring bands. A band without any
        PSD bin in range keeps its initial power of 0.
        """
        n_channels = data.shape[1]
        band_powers = {band: np.zeros(n_channels) for band in bands}
        
        for ch in range(n_channels):
            # Calculate PSD
            f, pxx = welch(data[:, ch], fs=fs, nperseg=min(1024, len(data)), scaling='density')
            
            # Calculate power in each band
            for band, (low_freq, high_freq) in bands.items():
                band_idx = np.logical_and(f >= low_freq, f <= high_freq)
                if np.any(band_idx):
                    band_powers[band][ch] = np.mean(pxx[band_idx])
        
        return band_powers
    
    # Calculate band powers for grey and white matter
    grey_band_powers = calculate_band_powers(grey_matter_data, fs, freq_bands)
    white_band_powers = calculate_band_powers(white_matter_data, fs, freq_bands)
    
    # 5. Calculate band power ratios (can be more discriminative than absolute powers)
    # Each tuple is (numerator band, denominator band)
    band_pairs = [
        ('theta', 'delta'),  # theta/delta ratio
        ('alpha', 'theta'),  # alpha/theta ratio
        ('beta', 'alpha'),   # beta/alpha ratio
        ('low_gamma', 'beta'),  # low gamma/beta ratio
        ('high_gamma', 'low_gamma'),  # high gamma/low gamma ratio
        ('ripple', 'high_gamma')  # ripple/high gamma ratio
    ]
    
    grey_band_ratios = {}
    white_band_ratios = {}
    
    for num_band, denom_band in band_pairs:
        ratio_name = f"{num_band}/{denom_band}"
        grey_band_ratios[ratio_name] = grey_band_powers[num_band] / (grey_band_powers[denom_band] + 1e-10)  # avoid div by zero
        white_band_ratios[ratio_name] = white_band_powers[num_band] / (white_band_powers[denom_band] + 1e-10)  # avoid div by zero
    
    # 6. Combine features for ROC analysis
    # all_features maps feature name -> {'grey': per-channel values, 'white': per-channel values}
    all_features = {}
    
    # Add statistical features
    for stat in grey_stats:
        all_features[stat] = {
            'grey': grey_stats[stat],
            'white': white_stats[stat]
        }
    
    # Add band power features
    for band in freq_bands:
        all_features[f"power_{band}"] = {
            'grey': grey_band_powers[band],
            'white': white_band_powers[band]
        }
    
    # Add band power ratio features
    for ratio_name in grey_band_ratios:
        all_features[f"ratio_{ratio_name}"] = {
            'grey': grey_band_ratios[ratio_name],
            'white': white_band_ratios[ratio_name]
        }
    
    # 7. Calculate ROC for each feature
    roc_results = {}
    
    for feature_name, feature_data in all_features.items():
        grey_feature = feature_data['grey']
        white_feature = feature_data['white']
        
        # Skip features with NaN or inf values
        # (skipped features are silently absent from roc_results and later outputs)
        if np.any(np.isnan(grey_feature)) or np.any(np.isnan(white_feature)) or \
           np.any(np.isinf(grey_feature)) or np.any(np.isinf(white_feature)):
            continue
            
        # One sample per channel; Grey = 1, White = 0, raw feature value as score
        X = np.concatenate([grey_feature, white_feature])
        y = np.concatenate([np.ones_like(grey_feature), np.zeros_like(white_feature)])
        
        # Calculate ROC
        fpr, tpr, _ = roc_curve(y, X)
        roc_auc = auc(fpr, tpr)
        
        roc_results[feature_name] = {
            'fpr': fpr,
            'tpr': tpr,
            'auc': roc_auc
        }
    
    # 8. Plot ROC curves for frequency band features
    plt.figure(figsize=(12, 10))
    
    # First, plot ROC curves for band powers
    for band in freq_bands:
        feature_name = f"power_{band}"
        if feature_name in roc_results:
            plt.plot(
                roc_results[feature_name]['fpr'], 
                roc_results[feature_name]['tpr'], 
                label=f"{feature_name} (AUC = {roc_results[feature_name]['auc']:.3f})"
            )
    
    plt.plot([0, 1], [0, 1], 'k--')  # random classifier line
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC Curves for Frequency Band Powers')
    plt.legend(loc="lower right")
    plt.grid(True, alpha=0.3)
    plt.savefig(os.path.join(plot_folder, 'frequency_bands_roc.png'))
    plt.close()
    
    # 9. Plot ROC curves for band power ratios
    plt.figure(figsize=(12, 10))
    
    for ratio_name in grey_band_ratios:
        feature_name = f"ratio_{ratio_name}"
        if feature_name in roc_results:
            plt.plot(
                roc_results[feature_name]['fpr'], 
                roc_results[feature_name]['tpr'], 
                label=f"{feature_name} (AUC = {roc_results[feature_name]['auc']:.3f})"
            )
    
    plt.plot([0, 1], [0, 1], 'k--')  # random classifier line
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC Curves for Frequency Band Ratios')
    plt.legend(loc="lower right")
    plt.grid(True, alpha=0.3)
    plt.savefig(os.path.join(plot_folder, 'frequency_ratios_roc.png'))
    plt.close()
    
    # 10. Plot ROC curves for best features overall
    # Sort features by AUC (descending; AUC values below 0.5 therefore rank last)
    sorted_features = sorted(roc_results.items(), key=lambda x: x[1]['auc'], reverse=True)
    
    # Plot top 10 features (or fewer if there are less than 10)
    top_n = min(10, len(sorted_features))
    plt.figure(figsize=(12, 10))
    
    for i in range(top_n):
        feature_name, feature_data = sorted_features[i]
        plt.plot(
            feature_data['fpr'], 
            feature_data['tpr'], 
            label=f"{feature_name} (AUC = {feature_data['auc']:.3f})"
        )
    
    plt.plot([0, 1], [0, 1], 'k--')  # random classifier line
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC Curves for Top Features')
    plt.legend(loc="lower right")
    plt.grid(True, alpha=0.3)
    plt.savefig(os.path.join(plot_folder, 'top_features_roc.png'))
    plt.close()
    
    # 11. Create feature comparison table (one row per scored feature, in AUC order)
    feature_comparison = []
    for feature_name, feature_data in sorted_features:
        # Calculate mean and std for grey and white matter
        if feature_name in all_features:
            grey_mean = np.mean(all_features[feature_name]['grey'])
            grey_std = np.std(all_features[feature_name]['grey'])
            white_mean = np.mean(all_features[feature_name]['white'])
            white_std = np.std(all_features[feature_name]['white'])
            
            feature_comparison.append({
                'Feature': feature_name,
                'AUC': feature_data['auc'],
                'Grey Mean': grey_mean,
                'Grey STD': grey_std,
                'White Mean': white_mean,
                'White STD': white_std,
                'Mean Diff': grey_mean - white_mean,
                # relative to the white-matter mean; the small constant avoids division by zero
                'Relative Diff (%)': 100 * (grey_mean - white_mean) / (abs(white_mean) + 1e-10)
            })
    
    comparison_df = pd.DataFrame(feature_comparison)
    comparison_df.to_csv(os.path.join(plot_folder, 'feature_comparison.csv'))
    
    # 12. Visualize the most discriminative features
    # At most 5 panels, which fit in the 2x3 grid used below
    top_features = comparison_df.head(min(5, len(comparison_df)))
    
    plt.figure(figsize=(15, 10))
    for i, (_, row) in enumerate(top_features.iterrows()):
        feature_name = row['Feature']
        
        plt.subplot(2, 3, i+1)
        sns.histplot(all_features[feature_name]['grey'], color='blue', label='Grey Matter', kde=True, alpha=0.5)
        sns.histplot(all_features[feature_name]['white'], color='red', label='White Matter', kde=True, alpha=0.5)
        plt.title(f"{feature_name} (AUC: {row['AUC']:.3f})")
        plt.legend()
    
    plt.tight_layout()
    plt.savefig(os.path.join(plot_folder, 'top_features_distribution.png'))
    plt.close()
    
    # 13. Average power spectra comparison
    # Calculate average PSD across all channels
    def calculate_avg_psd(data, fs):
        """Return (frequencies, Welch PSD averaged over all channels of data)."""
        n_channels = data.shape[1]
        all_pxx = []
        
        for ch in range(n_channels):
            f, pxx = welch(data[:, ch], fs=fs, nperseg=min(2048, len(data)))
            all_pxx.append(pxx)
        
        avg_pxx = np.mean(np.array(all_pxx), axis=0)
        return f, avg_pxx
    
    f_grey, pxx_grey = calculate_avg_psd(grey_matter_data, fs)
    f_white, pxx_white = calculate_avg_psd(white_matter_data, fs)
    
    plt.figure(figsize=(12, 8))
    plt.semilogy(f_grey, pxx_grey, label='Grey Matter', color='blue')
    plt.semilogy(f_white, pxx_white, label='White Matter', color='red')
    
    # Shade the frequency bands
    # NOTE(review): the indices below are derived from f_grey and reused to slice
    # pxx_white, which assumes both spectra share the same frequency grid — confirm.
    colors = ['lightblue', 'lightgreen', 'lightyellow', 'lightpink', 'lavender', 'peachpuff', 'mistyrose']
    i = 0
    for band, (low, high) in freq_bands.items():
        if low <= max(f_grey) and high >= min(f_grey):
            # Adjust boundaries to be within plot range
            low_adj = max(low, min(f_grey))
            high_adj = min(high, max(f_grey))
            
            # Find nearest indices
            low_idx = np.argmin(np.abs(f_grey - low_adj))
            high_idx = np.argmin(np.abs(f_grey - high_adj))
            
            # Calculate max y value at these frequencies for proper shading
            y_max = max(np.max(pxx_grey[low_idx:high_idx+1]), np.max(pxx_white[low_idx:high_idx+1]))
            
            # Add shaded region
            plt.fill_between(
                f_grey[low_idx:high_idx+1], 
                0, y_max, 
                color=colors[i % len(colors)], 
                alpha=0.3, 
                label=band
            )
            i += 1
    
    plt.title('Average Power Spectral Density')
    plt.xlabel('Frequency (Hz)')
    plt.ylabel('Power Spectral Density (log scale)')
    plt.xlim([0,150])
    plt.legend()
    plt.grid(True, which="both", ls="--", alpha=0.3)
    plt.savefig(os.path.join(plot_folder, 'average_psd_with_bands.png'))
    plt.close()
    
    # Return results summary
    return {
        'feature_comparison': comparison_df,
        'roc_results': roc_results,
        'grey_band_powers': grey_band_powers,
        'white_band_powers': white_band_powers,
        'grey_band_ratios': grey_band_ratios,
        'white_band_ratios': white_band_ratios
    }

In [None]:
# Electrode names for every row of the matter table (not restricted to the grey or white subset).
channel_names = p66_data.matter['ElectrodeName'].values

In [None]:
# Run the raw time-series comparison at the recording's own sampling rate.
# NOTE(review): analyze_raw_timeseries looks up channel_names by column position
# within raw_grey / raw_white, while channel_names covers all electrodes, so the
# titles in raw_time_series_samples.png may not name the plotted channel — confirm.
# This also rebinds `results`, replacing the value returned by visualize_features.
results = analyze_raw_timeseries(raw_grey, raw_white, channel_names, plot_folder='result/65sz1/raw_wg_comparison',fs=p66_raw.samplingRate)

In [None]:
def extract_and_classify_features(grey_matter_data, white_matter_data, 
                                  plot_folder='result/wg_classification', fs=250):
    """
    Extract features from raw time series, perform PCA, and classify grey vs white matter.

    Each channel (column) becomes one sample described by time-domain statistics,
    Welch band powers, band-power ratios, total power, 95% spectral edge frequency
    and spectral entropy. The samples are explored with PCA and then used to
    train/evaluate several scikit-learn classifiers (grey = 1, white = 0).

    Parameters:
    -----------
    grey_matter_data : np.ndarray
        Grey matter data with shape [time, channels]
    white_matter_data : np.ndarray
        White matter data with shape [time, channels]
    plot_folder : str, optional
        Folder path to save all plots and CSV summaries (created if missing),
        default is 'result/wg_classification'
    fs : int, optional
        Sampling frequency in Hz, default is 250

    Returns:
    --------
    dict
        Dictionary with keys:
        'grey_features', 'white_features' : per-channel feature DataFrames
            (including the 'Matter' label column),
        'pca' : fitted PCA keeping enough components for 95% variance,
        'pca_data' : samples projected on those components,
        'classifier_results' : per-classifier metrics and importances,
        'performance_summary' : DataFrame also written to classifier_performance.csv,
        'best_classifier' : name of the classifier with the highest CV accuracy,
        'feature_names' : feature columns in the order used for modelling.

    Raises:
    -------
    ValueError
        If either input is not two-dimensional.
    """
    # Local import: sklearn.pipeline is not part of the notebook's import cell.
    from sklearn.pipeline import make_pipeline

    # Create the plot folder if it doesn't exist
    os.makedirs(plot_folder, exist_ok=True)

    grey_matter_data = np.asarray(grey_matter_data)
    white_matter_data = np.asarray(white_matter_data)
    if grey_matter_data.ndim != 2 or white_matter_data.ndim != 2:
        raise ValueError(
            "grey_matter_data and white_matter_data must both have shape [time, channels]"
        )

    # Get dimensions
    grey_time, grey_channels = grey_matter_data.shape
    white_time, white_channels = white_matter_data.shape

    print(f"Grey matter data shape: [{grey_time}, {grey_channels}]")
    print(f"White matter data shape: [{white_time}, {white_channels}]")

    # 1. Extract features from raw data
    print("Extracting features...")

    # Define frequency bands
    freq_bands = {
        'delta': (0.5, 4),
        'theta': (4, 8),
        'alpha': (8, 13),
        'beta': (13, 30),
        'low_gamma': (30, 60),
        'high_gamma': (60, 150)
    }

    # Numerator / denominator pairs for the band power ratios
    band_pairs = [
        ('theta', 'delta'),   # theta/delta ratio
        ('alpha', 'theta'),   # alpha/theta ratio
        ('beta', 'alpha'),    # beta/alpha ratio
        ('low_gamma', 'beta'),  # low gamma/beta ratio
        ('high_gamma', 'low_gamma')  # high gamma/low gamma ratio
    ]

    def extract_channel_features(signal):
        """Return a dict of time- and frequency-domain features for one channel."""
        features = {}

        # Time domain statistical features
        features['mean'] = np.mean(signal)
        features['std'] = np.std(signal)
        features['median'] = np.median(signal)
        features['iqr'] = np.percentile(signal, 75) - np.percentile(signal, 25)
        features['skew'] = stats.skew(signal)
        features['kurtosis'] = stats.kurtosis(signal)
        features['range'] = np.max(signal) - np.min(signal)
        features['rms'] = np.sqrt(np.mean(signal**2))
        features['zero_crossings'] = np.sum(np.diff(np.signbit(signal).astype(int)) != 0)

        # Frequency domain features
        f, pxx = welch(signal, fs=fs, nperseg=min(1024, len(signal)))

        # Band powers (mean PSD over the bins inside each band; 0 if no bin falls in it)
        for band_name, (low_freq, high_freq) in freq_bands.items():
            band_idx = np.logical_and(f >= low_freq, f <= high_freq)
            if np.any(band_idx):
                features[f'power_{band_name}'] = np.mean(pxx[band_idx])
            else:
                features[f'power_{band_name}'] = 0

        # Band power ratios
        for num_band, denom_band in band_pairs:
            features[f'ratio_{num_band}_{denom_band}'] = (
                features[f'power_{num_band}'] / (features[f'power_{denom_band}'] + 1e-10)  # avoid div by zero
            )

        # Total power
        total_power = np.sum(pxx)
        features['total_power'] = total_power

        # Spectral edge frequency (95%): first bin where cumulative power reaches
        # 95% of the total. Left as NaN (imputed later) if never reached, e.g.
        # when the PSD contains NaN.
        edge_idx = np.flatnonzero(np.cumsum(pxx) >= 0.95 * total_power)
        features['spectral_edge_freq'] = f[edge_idx[0]] if edge_idx.size else np.nan

        # Spectral entropy
        pxx_norm = pxx / total_power
        features['spectral_entropy'] = -np.sum(pxx_norm * np.log2(pxx_norm + 1e-10))

        return features

    def clean_features(df, label):
        """Replace NaN, then infinite, entries of ``df`` with column medians."""
        if df.isna().any().any():
            print(f"Found NaN values in {label} matter features, filling with column medians")
            df = df.fillna(df.median())

        inf_mask = np.isinf(df.values)
        if np.any(inf_mask):
            print(f"Found infinite values in {label} matter features, replacing with column medians")
            for col_idx in range(df.shape[1]):
                col_inf_mask = inf_mask[:, col_idx]
                if np.any(col_inf_mask):
                    # Median of the finite entries only, so inf cannot contaminate it
                    df.iloc[col_inf_mask, col_idx] = df.iloc[~col_inf_mask, col_idx].median()
        return df

    # Extract features for each channel (one row per channel)
    grey_df = pd.DataFrame(
        [extract_channel_features(grey_matter_data[:, ch]) for ch in range(grey_channels)]
    )
    white_df = pd.DataFrame(
        [extract_channel_features(white_matter_data[:, ch]) for ch in range(white_channels)]
    )

    print(f"Extracted {len(grey_df.columns)} features from each channel")

    # 2. Clean data - check for NaNs or infinite values
    grey_df = clean_features(grey_df, 'grey')
    white_df = clean_features(white_df, 'white')

    # 3. Make sure both dataframes have the same columns.
    # Keep the original column order: going through a set() made the feature
    # order (and hence seeded model results) vary between kernel sessions.
    common_features = [col for col in grey_df.columns if col in white_df.columns]

    # 4. Add labels and combine data (assign() avoids writing into a slice)
    grey_df = grey_df[common_features].assign(Matter='Grey')
    white_df = white_df[common_features].assign(Matter='White')
    combined_df = pd.concat([grey_df, white_df], ignore_index=True)

    # 5. Prepare data for PCA and classification
    X = combined_df.drop('Matter', axis=1)
    y = combined_df['Matter'].map({'Grey': 1, 'White': 0})

    feature_names = X.columns

    # 6. Scale the features (all samples; used only for the unsupervised PCA below)
    X_scaled = StandardScaler().fit_transform(X)

    # 7. Apply PCA
    variance_threshold = 0.95  # Capture 95% of variance
    pca = PCA()
    pca.fit(X_scaled)

    # Plot explained variance
    explained_variance = pca.explained_variance_ratio_
    cumulative_variance = np.cumsum(explained_variance)

    # Find number of components for threshold
    n_components = int(np.argmax(cumulative_variance >= variance_threshold)) + 1

    plt.figure(figsize=(10, 6))
    plt.bar(range(1, len(explained_variance) + 1), explained_variance, alpha=0.7, label='Individual')
    plt.step(range(1, len(cumulative_variance) + 1), cumulative_variance, where='mid', label='Cumulative')
    plt.axhline(y=variance_threshold, color='r', linestyle='--', label=f'{variance_threshold*100}% Variance')
    plt.axvline(x=n_components, color='r', linestyle='--')
    plt.xlabel('Number of Principal Components')
    plt.ylabel('Explained Variance Ratio')
    plt.title('PCA Explained Variance')
    plt.legend(loc='best')
    plt.grid(True, alpha=0.3)
    plt.savefig(os.path.join(plot_folder, 'pca_explained_variance.png'))
    plt.close()

    print(f"Number of components needed for {variance_threshold*100}% variance: {n_components}")

    # 8. Apply PCA with the determined number of components
    pca = PCA(n_components=min(n_components, len(feature_names)))
    X_pca = pca.fit_transform(X_scaled)
    n_kept = X_pca.shape[1]

    # 9. Visualize the PCA results (first 2 or 3 components)
    # 2D scatter plot needs two components; a single dominant component would
    # otherwise raise IndexError here.
    if n_kept >= 2:
        plt.figure(figsize=(10, 8))
        scatter = plt.scatter(X_pca[:, 0], X_pca[:, 1], c=y, cmap='coolwarm', alpha=0.7, s=50)
        plt.colorbar(scatter, label='Grey Matter (1) vs White Matter (0)')
        plt.xlabel(f'Principal Component 1 ({explained_variance[0]:.1%} variance)')
        plt.ylabel(f'Principal Component 2 ({explained_variance[1]:.1%} variance)')
        plt.title('PCA of Grey vs White Matter Features (2D)')
        plt.grid(True, alpha=0.3)
        plt.savefig(os.path.join(plot_folder, 'pca_2d_scatter.png'))
        plt.close()
    else:
        print("Only one principal component retained; skipping 2D PCA scatter plot")

    # If we have at least 3 components, create a 3D scatter plot
    if n_kept >= 3:
        # Kept from the original code; presumably needed to register the '3d'
        # projection on older matplotlib releases.
        from mpl_toolkits.mplot3d import Axes3D  # noqa: F401

        fig = plt.figure(figsize=(12, 10))
        ax = fig.add_subplot(111, projection='3d')

        scatter = ax.scatter(
            X_pca[:, 0], X_pca[:, 1], X_pca[:, 2],
            c=y, cmap='coolwarm', alpha=0.7, s=50
        )

        ax.set_xlabel(f'PC1 ({explained_variance[0]:.1%})')
        ax.set_ylabel(f'PC2 ({explained_variance[1]:.1%})')
        ax.set_zlabel(f'PC3 ({explained_variance[2]:.1%})')

        plt.colorbar(scatter, label='Grey Matter (1) vs White Matter (0)')
        plt.title('PCA of Grey vs White Matter Features (3D)')
        plt.savefig(os.path.join(plot_folder, 'pca_3d_scatter.png'))
        plt.close()

    # 10. Feature importance in PCA
    # Top 10 absolute loadings for each of the first (up to) two components
    plt.figure(figsize=(12, 10))
    for pc_idx in range(min(2, n_kept)):
        importance_col = f'PC{pc_idx + 1}_Importance'
        pc_importance = pd.DataFrame({
            'Feature': feature_names,
            importance_col: np.abs(pca.components_[pc_idx])
        }).sort_values(importance_col, ascending=False)

        plt.subplot(2, 1, pc_idx + 1)
        sns.barplot(x=importance_col, y='Feature', data=pc_importance.head(10))
        plt.title(f'Top 10 Features for Principal Component {pc_idx + 1}')

    plt.tight_layout()
    plt.savefig(os.path.join(plot_folder, 'pca_feature_importance.png'))
    plt.close()

    # 11. Classification Analysis
    # Work on the unscaled features and let each model scale inside a pipeline,
    # so the scaler is fitted on training data only (no test/fold leakage).
    X_values = X.to_numpy()
    y_values = y.to_numpy()

    # Split data into training and testing sets
    X_train, X_test, y_train, y_test = train_test_split(
        X_values, y_values, test_size=0.3, random_state=42, stratify=y_values
    )

    # Define classifiers to evaluate
    classifiers = {
        'Logistic Regression': LogisticRegression(max_iter=1000, random_state=42),
        'SVM (RBF kernel)': SVC(probability=True, random_state=42),
        'Random Forest': RandomForestClassifier(n_estimators=100, random_state=42),
        'MLP Neural Network': MLPClassifier(max_iter=1000, random_state=42),
        'K-Nearest Neighbors': KNeighborsClassifier(n_neighbors=5),
        'LDA': LDA()
    }

    # Evaluate classifiers
    results = {}
    cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

    roc_fig, roc_ax = plt.subplots(figsize=(12, 10))

    for name, clf in classifiers.items():
        model = make_pipeline(StandardScaler(), clf)

        # Cross-validated accuracy (cross_val_score clones the pipeline per fold)
        cv_scores = cross_val_score(model, X_values, y_values, cv=cv, scoring='accuracy')

        # Train on the training set
        model.fit(X_train, y_train)
        fitted_clf = model.steps[-1][1]

        # Predict on the test set
        y_pred = model.predict(X_test)
        accuracy = accuracy_score(y_test, y_pred)

        # If classifier supports predict_proba, calculate ROC
        if hasattr(fitted_clf, "predict_proba"):
            y_proba = model.predict_proba(X_test)[:, 1]
            fpr, tpr, _ = roc_curve(y_test, y_proba)
            roc_auc = auc(fpr, tpr)

            # Plot ROC curve
            roc_ax.plot(fpr, tpr, label=f'{name} (AUC = {roc_auc:.3f})')

        # Store results
        results[name] = {
            'cv_accuracy': cv_scores.mean(),
            'cv_std': cv_scores.std(),
            'test_accuracy': accuracy,
            'confusion_matrix': confusion_matrix(y_test, y_pred),
            'classification_report': classification_report(y_test, y_pred, output_dict=True),
            'importance': None
        }

        # Store feature importance if available
        if hasattr(fitted_clf, "feature_importances_"):
            results[name]['importance'] = fitted_clf.feature_importances_
        elif hasattr(fitted_clf, "coef_"):
            coef = fitted_clf.coef_
            results[name]['importance'] = np.abs(coef[0]) if coef.ndim > 1 else np.abs(coef)

    # Finalize ROC plot
    roc_ax.plot([0, 1], [0, 1], 'k--')  # random classifier line
    roc_ax.set_xlim([0.0, 1.0])
    roc_ax.set_ylim([0.0, 1.05])
    roc_ax.set_xlabel('False Positive Rate')
    roc_ax.set_ylabel('True Positive Rate')
    roc_ax.set_title('ROC Curves for Different Classifiers')
    roc_ax.legend(loc="lower right")
    roc_ax.grid(True, alpha=0.3)
    roc_fig.savefig(os.path.join(plot_folder, 'classifier_roc_curves.png'))
    plt.close(roc_fig)

    # 12. Summarize classifier performance
    performance_summary = []

    for name, result in results.items():
        # classification_report keys its per-class entries by str(label); grey is '1'
        grey_report = result['classification_report'].get('1', {})
        performance_summary.append({
            'Classifier': name,
            'CV Accuracy': f"{result['cv_accuracy']:.3f} ± {result['cv_std']:.3f}",
            'Test Accuracy': f"{result['test_accuracy']:.3f}",
            'Precision (Grey)': f"{grey_report.get('precision', 0):.3f}",
            'Recall (Grey)': f"{grey_report.get('recall', 0):.3f}",
            'F1 Score (Grey)': f"{grey_report.get('f1-score', 0):.3f}"
        })

    performance_df = pd.DataFrame(performance_summary)
    performance_df.to_csv(os.path.join(plot_folder, 'classifier_performance.csv'), index=False)

    # 13. Visualize classifier performance
    # Use the numeric CV means directly instead of parsing the formatted strings
    performance_plot_df = pd.DataFrame({
        'Classifier': list(results.keys()),
        'CV Accuracy': [result['cv_accuracy'] for result in results.values()]
    })

    plt.figure(figsize=(12, 6))
    ax = sns.barplot(x='Classifier', y='CV Accuracy', data=performance_plot_df)

    # Add value labels
    for patch in ax.patches:
        ax.annotate(
            f"{patch.get_height():.3f}", 
            (patch.get_x() + patch.get_width() / 2., patch.get_height()), 
            ha='center', va='bottom'
        )

    plt.xlabel('Classifier')
    plt.ylabel('Cross-Validation Accuracy')
    plt.title('Classifier Performance Comparison')
    plt.xticks(rotation=45, ha='right')
    plt.tight_layout()
    plt.savefig(os.path.join(plot_folder, 'classifier_performance.png'))
    plt.close()

    # 14. Plot feature importance for the best classifier (highest CV accuracy)
    best_name, best_result = max(results.items(), key=lambda item: item[1]['cv_accuracy'])

    if best_result['importance'] is not None:
        # Create a DataFrame with feature importance
        importance_df = pd.DataFrame({
            'Feature': feature_names,
            'Importance': best_result['importance']
        }).sort_values('Importance', ascending=False)

        # Plot top 15 features
        plt.figure(figsize=(12, 8))
        sns.barplot(x='Importance', y='Feature', data=importance_df.head(15))
        plt.title(f'Top 15 Features Importance for {best_name}')
        plt.tight_layout()
        plt.savefig(os.path.join(plot_folder, 'best_classifier_feature_importance.png'))
        plt.close()

        # Save all feature importance
        importance_df.to_csv(os.path.join(plot_folder, 'feature_importance.csv'), index=False)

        # 15. Visualize feature distributions for the top 6 features
        top_features = importance_df.head(6)['Feature'].values

        plt.figure(figsize=(15, 10))
        for panel, feature in enumerate(top_features, start=1):
            plt.subplot(2, 3, panel)
            sns.histplot(grey_df[feature], color='blue', label='Grey', kde=True, alpha=0.5)
            sns.histplot(white_df[feature], color='red', label='White', kde=True, alpha=0.5)
            plt.title(feature)
            plt.legend()

        plt.tight_layout()
        plt.savefig(os.path.join(plot_folder, 'top_features_distribution.png'))
        plt.close()

    # Return results
    return {
        'grey_features': grey_df,
        'white_features': white_df,
        'pca': pca,
        'pca_data': X_pca,
        'classifier_results': results,
        'performance_summary': performance_df,
        'best_classifier': best_name,
        'feature_names': feature_names
    }

In [None]:
# Feature extraction + PCA + classifier comparison on the raw grey/white channels.
# NOTE(review): outputs are written under 'result/65sz1/...' while the sampling
# rate comes from p66_raw -- confirm the folder name is intended.
classification_results = extract_and_classify_features(
    raw_grey,
    raw_white,
    plot_folder='result/65sz1/wg_classification',
    fs=p66_raw.samplingRate,
)