In [None]:
import os
import glob
import pickle
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import cross_val_score, StratifiedKFold
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.naive_bayes import GaussianNB
from sklearn.metrics import (accuracy_score, f1_score, precision_score, recall_score, 
                           balanced_accuracy_score, roc_auc_score, roc_curve, auc,
                           confusion_matrix, classification_report)
from scipy import stats
from scipy.fftpack import fft
from itertools import combinations
import warnings
warnings.filterwarnings('ignore')


class MultiPatientClassifier:
    """
    Channel-Level标准化的Grey/White Matter分类系统
    
    核心思想：
    1. 每个channel作为一个分类单位
    2. 对每个channel的时间窗口特征进行channel-specific标准化
    3. 使用平均概率来分类每个channel
    4. 验证时计算每个patient的channel-level accuracy
    """
    
    def __init__(self, processed_folder, output_folder='results'):
        """Create an empty classifier workspace.

        Records the input/output folder paths, initialises the (still empty)
        patient and per-channel feature stores, builds the dictionary of
        candidate classifiers and makes sure the output folder exists.

        Parameters
        ----------
        processed_folder : str
            Folder that ``load_all_patients`` scans for patient pickle files.
        output_folder : str
            Folder for results; created here if it does not exist yet.
        """
        self.processed_folder, self.output_folder = processed_folder, output_folder

        # Filled by load_all_patients(): patient id -> loaded payload.
        self.patients_data = dict()
        # Channel-centric stores: channel id -> feature windows + metadata,
        # before and after channel-specific normalization.
        self.channel_features_data = dict()
        self.normalized_channel_features_data = dict()

        # Shared seed for every estimator that accepts one.
        seed = 42
        model_zoo = [
            ('Logistic Regression', LogisticRegression(max_iter=1000, random_state=seed)),
            ('SVM', SVC(probability=True, random_state=seed)),
            ('Random Forest', RandomForestClassifier(n_estimators=100, random_state=seed)),
            ('MLP', MLPClassifier(max_iter=1000, random_state=seed, hidden_layer_sizes=(50,))),
            ('KNN', KNeighborsClassifier(n_neighbors=5)),
            ('LDA', LDA()),
            ('Naive Bayes', GaussianNB()),
        ]
        self.classifiers = dict(model_zoo)

        os.makedirs(output_folder, exist_ok=True)

        for banner in (
            "Initializing Channel-Level Normalized Classifier",
            f"Processed folder: {processed_folder}",
            f"Output folder: {output_folder}",
        ):
            print(banner)
    
    def load_all_patients(self):
        """Load all patients from the processed folder."""
        pkl_files = glob.glob(os.path.join(self.processed_folder, "P*_processed.pkl"))
        print(f"\nFound {len(pkl_files)} patient files")
        for pkl_file in pkl_files:
            try:
                with open(pkl_file, 'rb') as f:
                    data = pickle.load(f)
                pid = data['patient_id']
                self.patients_data[pid] = data
                duration = data['processing_summary']['total_duration_seconds'] / 60
                print(f"✓ {pid}: {len(data['recordings'])} recordings, {duration:.1f} minutes")
            except Exception as e:
                print(f"✗ Failed to load {os.path.basename(pkl_file)}: {e}")
        print(f"Successfully loaded {len(self.patients_data)} patients")
        return bool(self.patients_data)
    
    def extract_electrode_classification(self, matter_data):
        """从matter数据提取电极分类"""
        
        matter_columns = ['MatterType', 'matter', 'Matter', 'mattertype', 'tissue_type', 'type']
        matter_col = None
        
        for col in matter_columns:
            if col in matter_data.columns:
                matter_col = col
                break
        
        if matter_col is None:
            raise ValueError(f"Matter type column not found. Available: {matter_data.columns.tolist()}")
        
        matter_values = matter_data[matter_col].astype(str).str.lower()
        
        grey_mask = matter_values.isin(['G', 'g', 'Grey', 'grey', 'Gray', 'gray'])
        white_mask = matter_values.isin(['W', 'w', 'White', 'white'])
        
        if np.sum(grey_mask) == 0 or np.sum(white_mask) == 0:
            print(f"    G/W format not found, trying other methods...")
            matter_values_lower = matter_values.str.lower()
            
            if np.sum(grey_mask) == 0:
                grey_patterns = ['grey', 'gray', 'cortex', 'cortical']
                grey_mask = matter_values_lower.str.contains('|'.join(grey_patterns), na=False, case=False)
            
            if np.sum(white_mask) == 0:
                white_patterns = ['white']
                white_mask = matter_values_lower.str.contains('|'.join(white_patterns), na=False, case=False)
        
        grey_indices = matter_data.index[grey_mask].tolist()
        white_indices = matter_data.index[white_mask].tolist()
        
        classification_info = {
            'matter_column': matter_col,
            'total_electrodes': len(matter_data),
            'grey_electrodes': len(grey_indices),
            'white_electrodes': len(white_indices),
            'grey_indices': grey_indices,
            'white_indices': white_indices,
            'matter_distribution': matter_data[matter_col].value_counts().to_dict()
        }
        
        return grey_indices, white_indices, classification_info
    
    def extract_signal_features(self, signal, fs):
        """提取单个信号的特征"""
        
        if len(signal) < 10 or np.all(signal == 0):
            return None
        
        features = {}
        
        # 时域特征
        features['std'] = np.std(signal)
        features['mad'] = np.median(np.abs(signal - np.median(signal)))
        features['range'] = np.ptp(signal)
        features['iqr'] = np.percentile(signal, 75) - np.percentile(signal, 25)
        features['rms'] = np.sqrt(np.mean(signal**2))
        
        # 信号复杂度
        features['area'] = np.sum(np.abs(signal))
        
        # 频域特征
        try:
            n_fft = len(signal)
            windowed_signal = signal * np.hamming(n_fft)
            fft_vals = fft(windowed_signal)
            fft_mag = np.abs(fft_vals[:n_fft//2])
            freqs = np.fft.fftfreq(n_fft, 1/fs)[:n_fft//2]
            
            bands = {
                'delta': (0.5, 4),
                'theta': (4, 8),
                'alpha': (8, 13),
                'beta': (13, 30),
                'gamma': (30, 100),
                'high_gamma': (100, min(200, fs/2))
            }
            
            total_power = np.sum(fft_mag**2)
            for band_name, (low, high) in bands.items():
                band_mask = (freqs >= low) & (freqs <= high)
                if np.any(band_mask):
                    band_power = np.sum(fft_mag[band_mask]**2)
                    features[f'power_{band_name}'] = band_power
                    features[f'rel_power_{band_name}'] = band_power / total_power if total_power > 0 else 0
                else:
                    features[f'power_{band_name}'] = 0
                    features[f'rel_power_{band_name}'] = 0
            
            features['total_power'] = total_power
                
        except Exception as e:
            bands = ['delta', 'theta', 'alpha', 'beta', 'gamma', 'high_gamma']
            for band in bands:
                features[f'power_{band}'] = 0
                features[f'rel_power_{band}'] = 0
            features['total_power'] = 0
            features['peak_frequency'] = 0
            features['spectral_centroid'] = 0
            features['spectral_entropy'] = 0
        
        return features
    
    def extract_channel_centric_features(self, use_windowing=True, 
                                       window_size_ms=500, step_size_ms=250, 
                                       max_windows_per_channel=200):
        """
        提取以channel为中心的特征
        
        核心思想：
        - 每个channel作为一个分类单位
        - 每个channel有多个时间窗口的特征
        - 存储结构：{channel_id: {features: [], label: 0/1, patient_id: str}}
        """
        
        print(f"\n🔄 Extracting Channel-Centric Features...")
        
        channel_id = 0  # 全局channel ID
        
        for patient_id, patient_data in self.patients_data.items():
            print(f"\n处理 {patient_id}...")
            
            matter_data = patient_data['matter_data']
            recordings = patient_data['recordings']
            
            # 提取电极分类
            try:
                grey_indices, white_indices, classification_info = self.extract_electrode_classification(matter_data)
            except Exception as e:
                print(f"  ✗ Electrode classification failed: {e}")
                continue
            
            print(f"  Electrode Classification: {classification_info['grey_electrodes']} grey, {classification_info['white_electrodes']} white")
            
            if len(grey_indices) == 0 or len(white_indices) == 0:
                print(f"  ✗ Missing Grey or White Matter Electrode")
                continue
            
            # 合并所有recordings的数据
            all_grey_data = []
            all_white_data = []
            fs = recordings[0]['sampling_rate'] if recordings else 512
            
            for recording in recordings:
                neural_data = recording['neural_data_processed']
                grey_data = neural_data[:, grey_indices]
                white_data = neural_data[:, white_indices]
                all_grey_data.append(grey_data)
                all_white_data.append(white_data)
            
            combined_grey = np.vstack(all_grey_data) if all_grey_data else np.array([])
            combined_white = np.vstack(all_white_data) if all_white_data else np.array([])
            
            print(f"  Merging...: Grey {combined_grey.shape}, White {combined_white.shape}")
            
            # 处理Grey Matter Channels
            for ch_idx, electrode_idx in enumerate(grey_indices):
                channel_data = combined_grey[:, ch_idx]
                
                # 为这个channel提取多个时间窗口的特征
                channel_features = self._extract_features_for_single_channel(
                    channel_data, fs, use_windowing, window_size_ms, step_size_ms, max_windows_per_channel
                )
                
                if len(channel_features) > 0:
                    self.channel_features_data[channel_id] = {
                        'features': channel_features,  # List of feature dicts from different windows
                        'label': 1,  # Grey matter
                        'patient_id': patient_id,
                        'electrode_idx': electrode_idx,
                        'channel_idx': ch_idx,
                        'matter_type': 'grey'
                    }
                    channel_id += 1
            
            # 处理White Matter Channels
            for ch_idx, electrode_idx in enumerate(white_indices):
                channel_data = combined_white[:, ch_idx]
                
                channel_features = self._extract_features_for_single_channel(
                    channel_data, fs, use_windowing, window_size_ms, step_size_ms, max_windows_per_channel
                )
                
                if len(channel_features) > 0:
                    self.channel_features_data[channel_id] = {
                        'features': channel_features,
                        'label': 0,  # White matter
                        'patient_id': patient_id,
                        'electrode_idx': electrode_idx,
                        'channel_idx': ch_idx,
                        'matter_type': 'white'
                    }
                    channel_id += 1
        
        print(f"\n✅ Extracted features for {len(self.channel_features_data)} channels")
        
        # 统计信息
        grey_channels = sum(1 for ch in self.channel_features_data.values() if ch['label'] == 1)
        white_channels = sum(1 for ch in self.channel_features_data.values() if ch['label'] == 0)
        
        print(f"   Grey matter channels: {grey_channels}")
        print(f"   White matter channels: {white_channels}")
        
        return True
    
    def _extract_features_for_single_channel(self, channel_data, fs, use_windowing, 
                                           window_size_ms, step_size_ms, max_windows_per_channel):
        """为单个channel提取特征"""
        
        channel_features = []
        
        if use_windowing:
            # 时间窗口方法
            window_samples = int(window_size_ms * fs / 1000)
            step_samples = int(step_size_ms * fs / 1000)
            n_time = len(channel_data)
            
            window_count = 0
            for start in range(0, n_time - window_samples + 1, step_samples):
                if window_count >= max_windows_per_channel:
                    break
                    
                end = start + window_samples
                window_data = channel_data[start:end]
                
                features = self.extract_signal_features(window_data, fs)
                if features is not None:
                    channel_features.append(features)
                    window_count += 1
        else:
            # 整个信号作为一个特征
            features = self.extract_signal_features(channel_data, fs)
            if features is not None:
                channel_features.append(features)
        
        return channel_features
    
    def apply_channel_specific_normalization(self, normalization_method='robust'):
        """
        对每个channel的特征进行channel-specific标准化
        
        核心思想：
        1. 每个channel有多个时间窗口的特征
        2. 在每个channel内部，对这些时间窗口的特征进行标准化
        3. 这样消除了channel内部的系统性差异，保留了grey/white matter的差异
        """
        
        print(f"\n🔄 Applying Channel-Specific Normalization ({normalization_method})...")
        
        if not self.channel_features_data:
            print("  ❌ No channel features found. Run extract_channel_centric_features() first.")
            return False
        
        # 获取特征名称（从第一个channel的第一个窗口获取）
        first_channel = list(self.channel_features_data.values())[0]
        if len(first_channel['features']) == 0:
            print("  ❌ No features found in channels.")
            return False
        
        feature_names = list(first_channel['features'][0].keys())
        print(f"  Normalizing {len(feature_names)} features across {len(self.channel_features_data)} channels")
        
        normalization_stats = {}
        
        # 为每个channel进行标准化
        for channel_id, channel_data in self.channel_features_data.items():
            patient_id = channel_data['patient_id']
            matter_type = channel_data['matter_type']
            
            # 将这个channel的所有窗口特征转换为DataFrame
            features_list = channel_data['features']
            if len(features_list) == 0:
                continue
                
            features_df = pd.DataFrame(features_list)
            
            # 对这个channel的特征进行标准化
            normalized_features_df = self._normalize_features(features_df, normalization_method)
            
            # 计算标准化统计信息
            before_stats = self._get_feature_stats(features_df)
            after_stats = self._get_feature_stats(normalized_features_df)
            
            # 转换回list of dicts
            normalized_features_list = normalized_features_df.to_dict('records')
            
            # 存储标准化后的数据
            normalized_channel_data = channel_data.copy()
            normalized_channel_data['features'] = normalized_features_list
            normalized_channel_data['normalization_method'] = normalization_method
            normalized_channel_data['normalization_stats'] = {
                'before': before_stats,
                'after': after_stats
            }
            
            self.normalized_channel_features_data[channel_id] = normalized_channel_data
            
            # 收集统计信息
            key = f"{patient_id}_{matter_type}"
            if key not in normalization_stats:
                normalization_stats[key] = {
                    'channels': 0,
                    'before_std_mean': 0,
                    'after_std_mean': 0
                }
            
            normalization_stats[key]['channels'] += 1
            normalization_stats[key]['before_std_mean'] += before_stats['std_mean']
            normalization_stats[key]['after_std_mean'] += after_stats['std_mean']
        
        # 打印统计信息
        print(f"\n📊 Normalization Statistics by Patient and Matter Type:")
        print("-" * 70)
        print(f"{'Patient_Matter':<20} {'Channels':<10} {'Before_StdMean':<15} {'After_StdMean':<15}")
        print("-" * 70)
        
        for key, stats in normalization_stats.items():
            n_channels = stats['channels']
            before_mean = stats['before_std_mean'] / n_channels
            after_mean = stats['after_std_mean'] / n_channels
            print(f"{key:<20} {n_channels:<10} {before_mean:<15.3f} {after_mean:<15.3f}")
        
        print(f"\n✅ Channel-specific normalization completed!")
        print(f"   Normalized {len(self.normalized_channel_features_data)} channels")
        
        return True
    
    def _normalize_features(self, features_df, method='robust', outlier_clip=True, iqr_multiplier=1.5):
        """
        应用特定的标准化方法，在标准化前使用IQR方法过滤outliers
        
        Parameters:
        -----------
        features_df : pd.DataFrame
            特征数据框
        method : str
            标准化方法 ('robust', 'standard', 'minmax', 'quantile')
        outlier_clip : bool
            是否在标准化前clip outliers
        iqr_multiplier : float
            IQR倍数，用于定义outlier阈值 (1.5为标准值，3.0为极端值)
        """
        
        normalized_df = features_df.copy()
        outlier_stats = {}  # 记录outlier统计信息
        
        for column in features_df.columns:
            values = features_df[column].values
            original_values = values.copy()
            
            # Step 1: 使用IQR方法clip outliers
            if outlier_clip:
                q25 = np.percentile(values, 25)
                q75 = np.percentile(values, 75)
                iqr = q75 - q25
                
                if iqr > 0:  # 避免除零错误
                    lower_bound = q25 - iqr_multiplier * iqr
                    upper_bound = q75 + iqr_multiplier * iqr
                    
                    # 记录outlier统计
                    n_outliers = np.sum((values < lower_bound) | (values > upper_bound))
                    outlier_stats[column] = {
                        'n_outliers': n_outliers,
                        'outlier_ratio': n_outliers / len(values),
                        'lower_bound': lower_bound,
                        'upper_bound': upper_bound,
                        'original_range': [np.min(values), np.max(values)]
                    }
                    
                    # Clip outliers
                    values = np.clip(values, lower_bound, upper_bound)
                else:
                    # 如果IQR为0，说明值都相同，不需要clip
                    outlier_stats[column] = {
                        'n_outliers': 0,
                        'outlier_ratio': 0.0,
                        'lower_bound': values[0],
                        'upper_bound': values[0],
                        'original_range': [np.min(values), np.max(values)]
                    }
            
            # Step 2: 应用标准化
            if method == 'standard':
                mean_val = np.mean(values)
                std_val = np.std(values)
                if std_val > 0:
                    normalized_df[column] = (values - mean_val) / std_val
                else:
                    normalized_df[column] = values - mean_val
                    
            elif method == 'robust':
                median_val = np.median(values)
                mad_val = np.median(np.abs(values - median_val))
                if mad_val > 0:
                    normalized_df[column] = (values - median_val) / (1.4826 * mad_val)
                else:
                    normalized_df[column] = values - median_val
                    
            elif method == 'minmax':
                min_val = np.min(values)
                max_val = np.max(values)
                if max_val > min_val:
                    normalized_df[column] = (values - min_val) / (max_val - min_val)
                else:
                    normalized_df[column] = np.zeros_like(values)
                    
            elif method == 'quantile':
                q25 = np.percentile(values, 25)
                q75 = np.percentile(values, 75)
                iqr = q75 - q25
                median_val = np.median(values)
                if iqr > 0:
                    normalized_df[column] = (values - median_val) / iqr
                else:
                    normalized_df[column] = values - median_val
            
            else:
                raise ValueError(f"Unknown normalization method: {method}")
        
        # 可选：存储outlier统计信息以便后续分析
        if outlier_clip and hasattr(self, '_outlier_stats'):
            if not hasattr(self, '_outlier_stats'):
                self._outlier_stats = {}
            self._outlier_stats[f'normalize_{len(self._outlier_stats)}'] = outlier_stats
        
        return normalized_df
    
    def _get_feature_stats(self, features_df):
        """获取特征的统计信息"""
        means = features_df.mean()
        stds = features_df.std()
        
        return {
            'mean_min': means.min(),
            'mean_max': means.max(),
            'std_min': stds.min(),
            'std_max': stds.max(),
            'mean_mean': means.mean(),
            'std_mean': stds.mean()
        }
    
    def prepare_dataset_for_channel_classification(self, use_normalized=True):
        """
        准备用于channel-level分类的数据集
        
        核心修改：
        1. 每个channel的多个时间窗口样本会被分别训练
        2. 在验证时，对同一个channel的所有样本的预测概率取平均
        
        Returns:
        --------
        samples_data : list
            每个sample的信息，包含：
            - features: 特征向量
            - label: 标签
            - channel_id: 所属channel ID
            - patient_id: 所属patient ID
        """
        
        print(f"\n📊 Preparing Dataset for Channel Classification...")
        
        # 选择数据源
        if use_normalized and self.normalized_channel_features_data:
            data_source = self.normalized_channel_features_data
        else:
            data_source = self.channel_features_data
        
        if not data_source:
            print("  ❌ No channel features available.")
            return None
        
        samples_data = []
        
        for channel_id, channel_data in data_source.items():
            # 每个channel的每个时间窗口都作为一个独立的训练样本
            for window_idx, window_features in enumerate(channel_data['features']):
                sample_info = {
                    'features': np.array(list(window_features.values())),
                    'label': channel_data['label'],
                    'channel_id': channel_id,
                    'patient_id': channel_data['patient_id'],
                    'matter_type': channel_data['matter_type'],
                    'window_idx': window_idx
                }
                samples_data.append(sample_info)
        
        print(f"   Total samples: {len(samples_data)}")
        print(f"   From channels: {len(data_source)} channels")
        
        # 统计信息
        grey_samples = sum(1 for s in samples_data if s['label'] == 1)
        white_samples = sum(1 for s in samples_data if s['label'] == 0)
        
        print(f"   Grey matter samples: {grey_samples}")
        print(f"   White matter samples: {white_samples}")
        
        return samples_data
    
    def leave_one_patient_out_validation_channel_level(self, use_normalized=True):
        """
        Channel-level的Leave-one-patient-out交叉验证
        
        核心修改：
        1. 训练时使用所有的时间窗口样本
        2. 验证时计算每个patient的每个channel的平均预测概率
        3. 基于平均概率进行channel-level分类
        4. 计算每个patient的channel-level accuracy
        """
        
        print(f"\n🔄 Channel-Level Leave-One-Patient-Out Validation")
        
        # 准备数据
        samples_data = self.prepare_dataset_for_channel_classification(use_normalized)
        
        if samples_data is None:
            print("❌ Failed to prepare dataset")
            return None, None, None
        
        # 获取患者列表
        patients = list(set([sample['patient_id'] for sample in samples_data]))
        n_patients = len(patients)
        
        if n_patients < 3:
            raise ValueError(f"Need at least 3 patients, {n_patients} patients found")
        
        print(f"   {n_patients} patients, {len(samples_data)} samples total")
        
        # 存储结果
        cv_results = {name: [] for name in self.classifiers.keys()}
        all_predictions = {name: {'y_true': [], 'y_pred': [], 'y_proba': [], 'test_patients': [], 'channel_ids': []} 
                          for name in self.classifiers.keys()}
        
        # Leave-one-patient-out循环
        for fold, test_patient in enumerate(patients):
            train_patients = [p for p in patients if p != test_patient]
            
            print(f"\nFold {fold+1}/{n_patients}: 测试 {test_patient}")
            
            # 分离训练和测试样本
            train_samples = [s for s in samples_data if s['patient_id'] in train_patients]
            test_samples = [s for s in samples_data if s['patient_id'] == test_patient]
            
            # 准备训练数据 (sample-level)
            X_train = np.array([s['features'] for s in train_samples])
            y_train = np.array([s['label'] for s in train_samples])
            
            # 准备测试数据 (sample-level)  
            X_test = np.array([s['features'] for s in test_samples])
            y_test = np.array([s['label'] for s in test_samples])
            
            # 获取测试集的channel信息
            test_channel_ids = [s['channel_id'] for s in test_samples]
            test_channel_labels = [s['label'] for s in test_samples]
            
            print(f"   Training samples: {len(X_train)} ({np.sum(y_train)} grey, {np.sum(y_train==0)} white)")
            print(f"   Testing samples: {len(X_test)} ({np.sum(y_test)} grey, {np.sum(y_test==0)} white)")
            
            # 跨患者标准化（如果使用原始特征）
            if not use_normalized:
                scaler = StandardScaler()
                X_train = scaler.fit_transform(X_train)
                X_test = scaler.transform(X_test)
            
            # 训练和评估每个分类器
            for clf_name, clf in self.classifiers.items():
                try:
                    # 训练分类器 (sample-level)
                    clf.fit(X_train, y_train)
                    
                    # 预测测试集 (sample-level)
                    y_pred_samples = clf.predict(X_test)
                    
                    if hasattr(clf, "predict_proba"):
                        y_proba_samples = clf.predict_proba(X_test)[:, 1]
                    else:
                        y_proba_samples = np.zeros_like(y_pred_samples, dtype=float)
                    
                    # 将sample-level预测聚合为channel-level预测
                    channel_predictions = self._aggregate_predictions_to_channel_level(
                        test_samples, y_pred_samples, y_proba_samples
                    )
                    
                    # 提取channel-level的真实标签和预测
                    channel_y_true = [pred['true_label'] for pred in channel_predictions]
                    channel_y_pred = [pred['pred_label'] for pred in channel_predictions]
                    channel_y_proba = [pred['avg_proba'] for pred in channel_predictions]
                    channel_ids = [pred['channel_id'] for pred in channel_predictions]
                    
                    # 计算channel-level指标
                    fold_results = {
                        'fold': fold,
                        'test_patient': test_patient,
                        'accuracy': accuracy_score(channel_y_true, channel_y_pred),
                        'f1_score': f1_score(channel_y_true, channel_y_pred, zero_division=0),
                        'precision': precision_score(channel_y_true, channel_y_pred, zero_division=0),
                        'recall': recall_score(channel_y_true, channel_y_pred, zero_division=0),
                        'balanced_accuracy': balanced_accuracy_score(channel_y_true, channel_y_pred),
                        'n_test_channels': len(channel_predictions),
                        'n_test_samples': len(y_test)
                    }
                    
                    if len(np.unique(channel_y_true)) > 1:
                        fold_results['roc_auc'] = roc_auc_score(channel_y_true, channel_y_proba)
                    else:
                        fold_results['roc_auc'] = 0.5
                    
                    cv_results[clf_name].append(fold_results)
                    
                    # 存储channel-level预测结果
                    all_predictions[clf_name]['y_true'].extend(channel_y_true)
                    all_predictions[clf_name]['y_pred'].extend(channel_y_pred)
                    all_predictions[clf_name]['y_proba'].extend(channel_y_proba)
                    all_predictions[clf_name]['test_patients'].extend([test_patient] * len(channel_predictions))
                    all_predictions[clf_name]['channel_ids'].extend(channel_ids)
                    
                    print(f"    {clf_name}: Channels - F1={fold_results['f1_score']:.3f}, "
                          f"Acc={fold_results['accuracy']:.3f}, AUC={fold_results['roc_auc']:.3f} "
                          f"({fold_results['n_test_channels']} channels)")
                
                except Exception as e:
                    print(f"    {clf_name}: Error - {e}")
        
        return cv_results, all_predictions, samples_data
    
    def _aggregate_predictions_to_channel_level(self, test_samples, y_pred_samples, y_proba_samples):
        """
        将sample-level的预测聚合为channel-level的预测
        
        对每个channel的所有样本（时间窗口）的预测概率取平均，然后基于平均概率进行分类
        """
        
        # 按channel_id分组
        channel_groups = {}
        for i, sample in enumerate(test_samples):
            channel_id = sample['channel_id']
            if channel_id not in channel_groups:
                channel_groups[channel_id] = {
                    'sample_indices': [],
                    'true_label': sample['label'],  # 同一个channel的所有样本应该有相同的标签
                    'patient_id': sample['patient_id']
                }
            channel_groups[channel_id]['sample_indices'].append(i)
        
        # 为每个channel计算平均预测概率
        channel_predictions = []
        for channel_id, group_info in channel_groups.items():
            sample_indices = group_info['sample_indices']
            
            # 获取这个channel的所有样本的预测概率
            channel_probas = y_proba_samples[sample_indices]
            
            # 计算平均概率
            avg_proba = np.mean(channel_probas)
            
            # 基于平均概率进行分类 (threshold = 0.5)
            pred_label = 1 if avg_proba > 0.5 else 0
            
            channel_predictions.append({
                'channel_id': channel_id,
                'true_label': group_info['true_label'],
                'pred_label': pred_label,
                'avg_proba': avg_proba,
                'patient_id': group_info['patient_id'],
                'n_samples': len(sample_indices)
            })
        
        return channel_predictions
    
    def analyze_channel_level_results(self, cv_results, all_predictions):
        """
        Summarise channel-level cross-validation results for every classifier.

        For each classifier in ``self.classifiers`` with at least one fold in
        ``cv_results``, the per-fold metrics are reduced to mean/std, the same
        metrics are recomputed once over the predictions pooled across folds,
        and a short report is printed. The classifier with the highest pooled
        F1 is then announced.

        Parameters
        ----------
        cv_results : dict
            Classifier name -> list of per-fold dicts. Each fold dict is read
            for 'accuracy', 'f1_score', 'precision', 'recall',
            'balanced_accuracy' and 'roc_auc'.
        all_predictions : dict
            Classifier name -> dict holding the pooled 'y_true', 'y_pred' and
            'y_proba' sequences.

        Returns
        -------
        tuple
            ``(final_results, best_name, best_metrics)``. ``final_results`` maps
            classifier name -> merged metric dict (plus 'cv_folds',
            'predictions' and 'confusion_matrix'). ``best_name`` and
            ``best_metrics`` are ``None`` when no classifier had any fold.
        """

        print(f"\n{'='*60}")
        print(f"Channel-Level Analysis Results...")
        print(f"{'='*60}")

        fold_metric_names = ('accuracy', 'f1_score', 'precision', 'recall', 'balanced_accuracy', 'roc_auc')
        final_results = {}

        for clf_name in self.classifiers:
            folds = cv_results[clf_name]
            if len(folds) == 0:
                continue

            # Mean / std of every per-fold metric
            cv_metrics = {}
            for metric in fold_metric_names:
                per_fold = [fold[metric] for fold in folds]
                cv_metrics[f'{metric}_mean'] = np.mean(per_fold)
                cv_metrics[f'{metric}_std'] = np.std(per_fold)

            # Metrics over the predictions pooled across all folds
            pooled = all_predictions[clf_name]
            y_true_all = np.array(pooled['y_true'])
            y_pred_all = np.array(pooled['y_pred'])
            y_proba_all = np.array(pooled['y_proba'])

            overall_metrics = {
                'overall_accuracy': accuracy_score(y_true_all, y_pred_all),
                'overall_f1': f1_score(y_true_all, y_pred_all),
                'overall_precision': precision_score(y_true_all, y_pred_all),
                'overall_recall': recall_score(y_true_all, y_pred_all),
                'overall_balanced_acc': balanced_accuracy_score(y_true_all, y_pred_all)
            }

            # ROC AUC needs both classes; otherwise report the chance level
            both_classes_present = len(np.unique(y_true_all)) > 1
            overall_metrics['overall_roc_auc'] = (
                roc_auc_score(y_true_all, y_proba_all) if both_classes_present else 0.5
            )

            # Merge everything into one record per classifier
            record = {**cv_metrics, **overall_metrics}
            record['cv_folds'] = folds
            record['predictions'] = pooled
            record['confusion_matrix'] = confusion_matrix(y_true_all, y_pred_all)
            final_results[clf_name] = record

            print(f"\n{clf_name}:")
            print(f"  CV F1: {cv_metrics['f1_score_mean']:.3f} ± {cv_metrics['f1_score_std']:.3f}")
            print(f"  Overall F1: {overall_metrics['overall_f1']:.3f}")
            print(f"  Overall Balanced Acc: {overall_metrics['overall_balanced_acc']:.3f}")
            print(f"  Overall ROC AUC: {overall_metrics['overall_roc_auc']:.3f}")

        # Nothing to rank when no classifier produced folds
        if not final_results:
            return final_results, None, None

        # Rank by pooled F1; the first classifier wins on ties
        best_name = max(final_results, key=lambda name: final_results[name]['overall_f1'])
        best_metrics = final_results[best_name]

        print(f"\n{'='*60}")
        print(f"🏆 Best Channel-Level Classifier: {best_name}")
        print(f"   Overall F1 Score: {best_metrics['overall_f1']:.3f}")
        print(f"   CV F1: {best_metrics['f1_score_mean']:.3f} ± {best_metrics['f1_score_std']:.3f}")
        print(f"   Overall Balanced Acc: {best_metrics['overall_balanced_acc']:.3f}")
        print(f"   Overall ROC AUC: {best_metrics['overall_roc_auc']:.3f}")
        print(f"{'='*60}")

        return final_results, best_name, best_metrics
    
    def analyze_patient_level_performance(self, cv_results, all_predictions):
        """
        Print a per-patient table of channel-level metrics for every classifier.

        Each fold dict in ``cv_results[clf_name]`` becomes one table row (one
        held-out patient). A final "Total" row shows the channel-weighted
        accuracy pooled over all folds and the mean per-patient F1, followed by
        the standard deviation of the per-patient F1.

        Parameters
        ----------
        cv_results : dict
            Classifier name -> list of fold dicts. Each fold dict is read for
            'test_patient', 'n_test_channels', 'accuracy', 'f1_score',
            'precision' and 'recall'. Classifiers that are missing or have no
            folds are skipped.
        all_predictions : dict
            Unused; kept for interface compatibility.

        Returns
        -------
        bool
            Always ``True``.
        """

        print(f"\n📊 Patient-Level Channel Classification Performance")
        print("=" * 80)

        for clf_name in self.classifiers:
            folds = cv_results.get(clf_name, [])
            if len(folds) == 0:
                continue

            print(f"\n{clf_name}:")
            print("-" * 50)
            print(f"{'Patient':<12} {'Channels':<10} {'Accuracy':<10} {'F1':<8} {'Precision':<10} {'Recall':<8}")
            print("-" * 50)

            total_channels = 0
            total_correct = 0
            patient_f1_scores = []

            for fold_result in folds:
                test_patient = fold_result['test_patient']
                n_channels = fold_result['n_test_channels']
                accuracy = fold_result['accuracy']
                f1 = fold_result['f1_score']
                precision = fold_result['precision']
                recall = fold_result['recall']

                print(f"{test_patient:<12} {n_channels:<10} {accuracy:<10.3f} {f1:<8.3f} {precision:<10.3f} {recall:<8.3f}")

                total_channels += n_channels
                # accuracy is a ratio correct / n_channels, so the product can land
                # just below the true integer (e.g. (1/49)*49 == 0.9999999999999999).
                # Round instead of truncating so no correct channel is lost.
                total_correct += int(round(accuracy * n_channels))
                patient_f1_scores.append(f1)

            # Pooled statistics over all held-out patients
            overall_accuracy = total_correct / total_channels if total_channels > 0 else 0
            mean_patient_f1 = np.mean(patient_f1_scores) if patient_f1_scores else 0
            std_patient_f1 = np.std(patient_f1_scores) if patient_f1_scores else 0

            print("-" * 50)
            print(f"{'Total':<12} {total_channels:<10} {overall_accuracy:<10.3f} {mean_patient_f1:<8.3f}")
            print(f"Patient F1 Std: {std_patient_f1:.3f}")

        return True
    
    def create_channel_level_visualization(self, use_normalized=True, cv_results=None, all_predictions=None):
        """
        Create all channel-level figures and save them under ``self.output_folder``.

        The dataset is obtained from
        ``self.prepare_dataset_for_channel_classification(use_normalized)`` and
        then handed to the individual plotting helpers:

        1. normalization effect (only when ``use_normalized`` is set and raw
           channel features are available),
        2. per-patient channel distribution,
        3. feature importance,
        4. classifier performance (only when both ``cv_results`` and
           ``all_predictions`` are provided and non-empty).

        Parameters
        ----------
        use_normalized : bool
            Forwarded to the dataset preparation step.
        cv_results : dict, optional
            Cross-validation results forwarded to the performance plot.
        all_predictions : dict, optional
            Pooled predictions forwarded to the performance plot.

        Returns
        -------
        None
            Returns early, after printing a message, when there is no data.
        """

        print(f"\n📊 Creating Channel-Level Visualizations...")

        # Prepare the per-sample dataset
        samples_data = self.prepare_dataset_for_channel_classification(use_normalized)

        # An empty dataset is as unusable as a missing one: the helpers below
        # would try to fit a model and draw charts on zero samples.
        if samples_data is None or len(samples_data) == 0:
            print("❌ No data available for visualization")
            return

        # Every helper writes into the output folder, so make sure it exists
        if self.output_folder:
            os.makedirs(self.output_folder, exist_ok=True)

        # Map each patient to a colour; sorting makes the mapping reproducible
        # between runs (plain set iteration order is not).
        unique_patients = sorted({s['patient_id'] for s in samples_data}, key=str)
        colors = plt.cm.Set3(np.linspace(0, 1, len(unique_patients)))
        patient_colors = dict(zip(unique_patients, colors))

        # 1. Feature distribution before vs. after normalization
        if use_normalized and self.channel_features_data:
            self._create_normalization_effect_plot()

        # 2. Channel distribution across patients
        self._create_patient_channel_distribution_plot(samples_data, patient_colors)

        # 3. Feature importance
        self._create_feature_importance_plot(samples_data)

        # 4. Performance plots, only when CV results are available
        if cv_results and all_predictions:
            self._create_performance_visualization(cv_results, all_predictions)

        print(f"  💾 Visualizations saved to: {self.output_folder}")
    
    def _create_normalization_effect_plot(self, feature_name='rel_power_gamma'):
        """
        Plot one feature before and after normalization for the first six channels.

        For each of the first six entries of ``self.channel_features_data`` the
        histogram of ``feature_name`` is overlaid for the original and the
        normalized per-window features. Channels without a normalized entry, or
        without the requested feature, are skipped and their panel is hidden so
        the saved figure contains no empty axes. The figure is written to
        ``channel_normalization_effect.png`` in ``self.output_folder``.

        Parameters
        ----------
        feature_name : str
            Feature column to compare. Defaults to 'rel_power_gamma'.
        """

        print("  📈 Creating normalization effect plots...")

        # Compare the first six channels (a 2 x 3 grid)
        sample_channels = list(self.channel_features_data.keys())[:6]

        fig, axes = plt.subplots(2, 3, figsize=(15, 10))
        axes = axes.flatten()

        plotted = set()  # indices of the panels that actually received data
        for i, channel_id in enumerate(sample_channels):
            if channel_id not in self.normalized_channel_features_data:
                continue

            channel_info = self.channel_features_data[channel_id]
            orig_features = pd.DataFrame(channel_info['features'])
            norm_features = pd.DataFrame(self.normalized_channel_features_data[channel_id]['features'])

            # Both frames must carry the feature, otherwise there is nothing to compare
            if feature_name not in orig_features.columns or feature_name not in norm_features.columns:
                continue

            ax = axes[i]
            ax.hist(orig_features[feature_name], alpha=0.6, label='Original', bins=20, color='blue')
            ax.hist(norm_features[feature_name], alpha=0.6, label='Normalized', bins=20, color='red')

            patient_id = channel_info['patient_id']
            matter_type = channel_info['matter_type']

            ax.set_title(f'{patient_id}_{matter_type}\n{feature_name}', fontsize=10)
            ax.set_xlabel('Feature Value')
            ax.set_ylabel('Frequency')
            ax.legend()
            ax.grid(True, alpha=0.3)
            plotted.add(i)

        # Hide every panel that stayed empty (unused slots and skipped channels)
        for i, ax in enumerate(axes):
            if i not in plotted:
                ax.set_visible(False)

        fig.suptitle('Channel-Level Normalization Effect', fontsize=14)
        fig.tight_layout()
        fig.savefig(os.path.join(self.output_folder, 'channel_normalization_effect.png'),
                    dpi=300, bbox_inches='tight')
        plt.close(fig)
    
    def _create_patient_channel_distribution_plot(self, samples_data, patient_colors):
        """
        Plot how many distinct grey / white matter channels each patient contributes.

        The left panel is a grouped bar chart of unique channel counts per
        patient; the right panel is a pie chart of the pooled grey vs. white
        totals. The figure is saved as ``patient_channel_distribution.png`` in
        ``self.output_folder``.

        Parameters
        ----------
        samples_data : list of dict
            Samples read for 'patient_id', 'matter_type' ('grey' or 'white')
            and 'channel_id'.
        patient_colors : dict
            Accepted for interface compatibility; not used by this plot.
        """

        print("  📊 Creating patient channel distribution plot...")

        # Collect the unique channel ids per patient and matter type
        channels_by_patient = {}
        for sample in samples_data:
            per_type = channels_by_patient.setdefault(
                sample['patient_id'], {'grey': set(), 'white': set()}
            )
            per_type[sample['matter_type']].add(sample['channel_id'])

        # Patients keep their order of first appearance
        patients = list(channels_by_patient)
        grey_counts = [len(channels_by_patient[p]['grey']) for p in patients]
        white_counts = [len(channels_by_patient[p]['white']) for p in patients]

        fig, (bar_ax, pie_ax) = plt.subplots(1, 2, figsize=(15, 6))

        # 1. Grouped bars: channels per patient
        positions = np.arange(len(patients))
        bar_width = 0.35
        half = bar_width/2

        bar_ax.bar(positions - half, grey_counts, bar_width, label='Grey Matter', alpha=0.8, color='red')
        bar_ax.bar(positions + half, white_counts, bar_width, label='White Matter', alpha=0.8, color='blue')

        bar_ax.set_xlabel('Patient ID')
        bar_ax.set_ylabel('Number of Channels')
        bar_ax.set_title('Channel Distribution by Patient')
        bar_ax.set_xticks(positions)
        bar_ax.set_xticklabels(patients, rotation=45)
        bar_ax.legend()
        bar_ax.grid(True, alpha=0.3)

        # 2. Pie: pooled grey vs. white share
        total_grey = sum(grey_counts)
        total_white = sum(white_counts)

        pie_ax.pie([total_grey, total_white],
                   labels=['Grey Matter', 'White Matter'],
                   colors=['red', 'blue'],
                   autopct='%1.1f%%')
        pie_ax.set_title(f'Overall Channel Distribution\n(Total: {total_grey + total_white} channels)')

        plt.tight_layout()
        target = os.path.join(self.output_folder, 'patient_channel_distribution.png')
        plt.savefig(target, dpi=300, bbox_inches='tight')
        plt.close()
    
    def _create_patient_feature_visualization(self, samples_data, selected_features=None, max_patients_per_plot=8):
        """
        为每个feature创建单独的可视化图，每个图包含所有患者的该feature分布
        
        Parameters:
        -----------
        samples_data : list
            样本数据列表
        selected_features : list, optional
            要可视化的特征列表，如果为None则使用所有特征
        max_patients_per_plot : int
            每个图最多显示的患者数量
        """
        
        print("  📊 Creating per-feature patient visualization...")
        
        # 获取特征名称
        if not samples_data:
            print("    ❌ No samples data available")
            return
            
        # 从第一个sample获取特征名称
        first_channel = list(self.channel_features_data.values())[0]
        all_feature_names = list(first_channel['features'][0].keys())
        
        if selected_features is None:
            # 选择一些代表性特征进行可视化
            selected_features = [
                'std', 'rms', 'area',
                'rel_power_delta', 'rel_power_theta', 'rel_power_alpha', 
                'rel_power_beta', 'rel_power_gamma', 'rel_power_high_gamma',
                'range', 'mad'
            ]
            # 只保留实际存在的特征
            selected_features = [f for f in selected_features if f in all_feature_names]
        
        # 组织数据：按patient和matter_type分组
        patient_feature_data = {}
        feature_names = list(samples_data[0]['features'].keys()) if samples_data else []
        
        for sample in samples_data:
            patient_id = sample['patient_id']
            matter_type = sample['matter_type']
            
            if patient_id not in patient_feature_data:
                patient_feature_data[patient_id] = {'grey': [], 'white': []}
            
            # 将features转换为dict格式（如果是array）
            if isinstance(sample['features'], np.ndarray):
                feature_dict = {name: sample['features'][i] for i, name in enumerate(feature_names)}
            else:
                feature_dict = sample['features']
            
            patient_feature_data[patient_id][matter_type].append(feature_dict)
        
        # 获取患者列表并排序
        patients = sorted(list(patient_feature_data.keys()))
        
        # 如果患者太多，分批处理
        patient_batches = [patients[i:i+max_patients_per_plot] 
                          for i in range(0, len(patients), max_patients_per_plot)]
        
        # 为每个选定的特征创建可视化
        for feature_name in selected_features:
            print(f"    Creating visualization for feature: {feature_name}")
            
            for batch_idx, patient_batch in enumerate(patient_batches):
                # 计算subplot布局
                n_patients = len(patient_batch)
                n_cols = min(4, n_patients)  # 最多4列
                n_rows = (n_patients + n_cols - 1) // n_cols  # 向上取整
                
                # 创建figure
                fig, axes = plt.subplots(n_rows, n_cols, figsize=(5*n_cols, 4*n_rows))
                
                # 确保axes是2D数组
                if n_rows == 1 and n_cols == 1:
                    axes = np.array([[axes]])
                elif n_rows == 1:
                    axes = axes.reshape(1, -1)
                elif n_cols == 1:
                    axes = axes.reshape(-1, 1)
                
                # 为每个患者创建subplot
                for i, patient_id in enumerate(patient_batch):
                    row = i // n_cols
                    col = i % n_cols
                    ax = axes[row, col]
                    
                    # 提取该患者该特征的数据
                    grey_values = []
                    white_values = []
                    
                    for sample_dict in patient_feature_data[patient_id]['grey']:
                        if feature_name in sample_dict:
                            grey_values.append(sample_dict[feature_name])
                    
                    for sample_dict in patient_feature_data[patient_id]['white']:
                        if feature_name in sample_dict:
                            white_values.append(sample_dict[feature_name])
                    
                    # 绘制分布
                    if grey_values:
                        ax.hist(grey_values, bins=20, alpha=0.6, label='Grey', color='red', density=True)
                    if white_values:
                        ax.hist(white_values, bins=20, alpha=0.6, label='White', color='blue', density=True)
                    
                    # 设置标题和标签
                    ax.set_title(f'{patient_id}\n(G:{len(grey_values)}, W:{len(white_values)})', fontsize=10)
                    ax.set_xlabel(f'{feature_name}')
                    ax.set_ylabel('Density')
                    ax.legend(fontsize=8)
                    ax.grid(True, alpha=0.3)
                    
                    # 设置合适的x轴范围
                    all_values = grey_values + white_values
                    if all_values:
                        ax.set_xlim(np.percentile(all_values, [1, 99]))
                
                # 隐藏多余的subplots
                for i in range(len(patient_batch), n_rows * n_cols):
                    row = i // n_cols
                    col = i % n_cols
                    axes[row, col].set_visible(False)
                
                # 设置整体标题
                batch_suffix = f"_batch{batch_idx+1}" if len(patient_batches) > 1 else ""
                fig.suptitle(f'Feature Distribution: {feature_name}{batch_suffix}', fontsize=16, y=0.98)
                
                plt.tight_layout()
                plt.subplots_adjust(top=0.93)  # 为suptitle留出空间
                
                # 保存图片
                safe_feature_name = feature_name.replace('/', '_').replace('\\', '_')
                filename = f'feature_{safe_feature_name}{batch_suffix}.png'
                plt.savefig(os.path.join(self.output_folder, filename), 
                           dpi=300, bbox_inches='tight')
                plt.close()
        
        print(f"    ✅ Created visualizations for {len(selected_features)} features")
        return True
    
    def _create_feature_importance_plot(self, samples_data):
        """
        Plot the most important features according to a random forest.

        A ``RandomForestClassifier`` (100 trees, ``random_state=42``) is fitted
        on all samples and the (at most) 20 largest feature importances are
        drawn as a horizontal bar chart, saved as ``feature_importance.png`` in
        ``self.output_folder``.

        Parameters
        ----------
        samples_data : list of dict
            Samples read for 'features' (stacked into the design matrix) and
            'label' (the target).
        """

        print("  🎯 Creating feature importance plot...")

        # Design matrix and targets
        feature_matrix = np.array([sample['features'] for sample in samples_data])
        targets = np.array([sample['label'] for sample in samples_data])

        # Random forest as the importance estimator
        from sklearn.ensemble import RandomForestClassifier

        forest = RandomForestClassifier(n_estimators=100, random_state=42)
        forest.fit(feature_matrix, targets)
        importances = forest.feature_importances_

        # Names are taken from the first window of the first raw channel entry
        first_channel = list(self.channel_features_data.values())[0]
        feature_names = list(first_channel['features'][0].keys())

        # Column indices of the top features, most important first
        n_features = min(20, len(importances))
        top_indices = np.argsort(importances)[::-1][:n_features]
        bar_positions = range(n_features)

        plt.figure(figsize=(12, 8))
        plt.barh(bar_positions, importances[top_indices], alpha=0.8)
        plt.yticks(bar_positions, [feature_names[idx] for idx in top_indices])
        plt.xlabel('Feature Importance')
        plt.title('Top 20 Feature Importances (Random Forest)')
        plt.gca().invert_yaxis()  # most important feature at the top
        plt.grid(True, alpha=0.3)
        plt.tight_layout()
        target = os.path.join(self.output_folder, 'feature_importance.png')
        plt.savefig(target, dpi=300, bbox_inches='tight')
        plt.close()
    
    def _create_performance_visualization(self, cv_results, all_predictions):
        """
        Draw a 2 x 2 overview of the cross-validation performance.

        Panels: mean F1 per classifier with std error bars, mean accuracy per
        classifier, mean ROC AUC per classifier, and the per-patient F1 of the
        classifier with the highest mean F1. Only classifiers with at least one
        fold are shown. The figure is saved as ``performance_comparison.png`` in
        ``self.output_folder``.

        Parameters
        ----------
        cv_results : dict
            Classifier name -> list of fold dicts, read for 'f1_score',
            'accuracy', 'roc_auc' and 'test_patient'.
        all_predictions : dict
            Unused; kept for interface compatibility.
        """

        print("  📈 Creating performance visualization...")

        fig, ((f1_ax, acc_ax), (auc_ax, patient_ax)) = plt.subplots(2, 2, figsize=(15, 12))

        def fold_values(name, key):
            """Collect one metric over all folds of a classifier."""
            return [fold[key] for fold in cv_results[name]]

        # Classifiers that actually have folds, in registration order
        classifiers = [name for name in self.classifiers if len(cv_results[name]) > 0]
        f1_means = [np.mean(fold_values(name, 'f1_score')) for name in classifiers]
        f1_stds = [np.std(fold_values(name, 'f1_score')) for name in classifiers]
        accuracies = [np.mean(fold_values(name, 'accuracy')) for name in classifiers]
        aucs = [np.mean(fold_values(name, 'roc_auc')) for name in classifiers]

        x_pos = np.arange(len(classifiers))

        def label_classifier_axis(ax, ylabel, title):
            """Apply the labelling shared by the three per-classifier panels."""
            ax.set_xlabel('Classifiers')
            ax.set_ylabel(ylabel)
            ax.set_title(title)
            ax.set_xticks(x_pos)
            ax.set_xticklabels(classifiers, rotation=45)
            ax.grid(True, alpha=0.3)

        # F1 with error bars
        f1_ax.bar(x_pos, f1_means, yerr=f1_stds, capsize=5, alpha=0.8)
        label_classifier_axis(f1_ax, 'F1 Score', 'F1 Score by Classifier (with std)')

        # Accuracy
        acc_ax.bar(x_pos, accuracies, alpha=0.8, color='orange')
        label_classifier_axis(acc_ax, 'Accuracy', 'Accuracy by Classifier')

        # ROC AUC
        auc_ax.bar(x_pos, aucs, alpha=0.8, color='green')
        label_classifier_axis(auc_ax, 'ROC AUC', 'ROC AUC by Classifier')

        # Per-patient F1 of the classifier with the best mean F1
        if f1_means:
            best_clf_name = classifiers[np.argmax(f1_means)]
            best_folds = cv_results[best_clf_name]

            patient_f1s = [fold['f1_score'] for fold in best_folds]
            patient_names = [fold['test_patient'] for fold in best_folds]

            patient_ax.bar(range(len(patient_f1s)), patient_f1s, alpha=0.8, color='purple')
            patient_ax.set_xlabel('Patients')
            patient_ax.set_ylabel('F1 Score')
            patient_ax.set_title(f'Patient-wise F1 Score ({best_clf_name})')
            patient_ax.set_xticks(range(len(patient_names)))
            patient_ax.set_xticklabels(patient_names, rotation=45)
            patient_ax.grid(True, alpha=0.3)

        plt.tight_layout()
        target = os.path.join(self.output_folder, 'performance_comparison.png')
        plt.savefig(target, dpi=300, bbox_inches='tight')
        plt.close()
    
    def run_complete_channel_level_analysis(self, 
                                          use_windowing=True,
                                          window_size_ms=500,
                                          step_size_ms=250,
                                          max_windows_per_channel=200,
                                          normalization_method='robust'):
        """
        Run the whole channel-level pipeline from loading to reporting.

        Steps: load patients, extract channel-centric features, apply
        channel-specific normalization, leave-one-patient-out validation,
        result analysis, per-patient report, visualization, saving, and a
        printed summary.

        Parameters
        ----------
        use_windowing, window_size_ms, step_size_ms, max_windows_per_channel
            Forwarded unchanged to ``extract_channel_centric_features``.
        normalization_method : str
            Forwarded to ``apply_channel_specific_normalization`` and used in
            the names of the saved result files.

        Returns
        -------
        dict or None
            ``None`` when loading, feature extraction or normalization reports
            failure. Otherwise a dict with 'final_results', 'best_classifier',
            'best_metrics', 'samples_data', 'cv_results', 'all_predictions',
            'n_channels' and 'n_patients'.
        """

        print(f"🧠 Channel-Level Grey/White Matter Classification Analysis")
        print(f"=" * 80)

        # 1. Load patient data
        loaded = self.load_all_patients()
        if not loaded:
            print("❌ Loading patient data failed")
            return None

        # 2. Extract channel-centric features
        windowing_options = {
            'use_windowing': use_windowing,
            'window_size_ms': window_size_ms,
            'step_size_ms': step_size_ms,
            'max_windows_per_channel': max_windows_per_channel,
        }
        extracted = self.extract_channel_centric_features(**windowing_options)
        if not extracted:
            print("❌ Channel feature extraction failed")
            return None

        # 3. Channel-specific normalization
        normalized = self.apply_channel_specific_normalization(normalization_method)
        if not normalized:
            print("❌ Channel-specific normalization failed")
            return None

        # 4. Main analysis: channel-level classification
        validation_output = self.leave_one_patient_out_validation_channel_level(
            use_normalized=True
        )
        cv_results, all_predictions, samples_data = validation_output

        final_results, best_name, best_metrics = self.analyze_channel_level_results(cv_results, all_predictions)

        # 5. Per-patient performance report
        self.analyze_patient_level_performance(cv_results, all_predictions)

        # 6. Figures
        self.create_channel_level_visualization(use_normalized=True, cv_results=cv_results, all_predictions=all_predictions)

        # 7. Persist the results
        self._save_channel_level_results(final_results, best_name, normalization_method, cv_results)

        # 8. Final summary
        print(f"\n{'='*80}")
        print(f"Channel-Level Analysis Complete")
        print(f"{'='*80}")

        channel_entries = list(self.channel_features_data.values())
        n_channels = len(self.channel_features_data)
        n_grey = len([ch for ch in channel_entries if ch['label'] == 1])
        n_white = len([ch for ch in channel_entries if ch['label'] == 0])
        n_patients = len({ch['patient_id'] for ch in channel_entries})

        print(f"Number of Patients: {n_patients}")
        print(f"Total Channels: {n_channels}")
        print(f"Grey Matter Channels: {n_grey}")
        print(f"White Matter Channels: {n_white}")
        print(f"Normalization Method: {normalization_method}")

        if best_name and best_metrics:
            print(f"\n🏆 Best Channel-Level Classifier: {best_name}")
            print(f"   Overall F1 Score: {best_metrics['overall_f1']:.3f}")
            print(f"   CV F1: {best_metrics['f1_score_mean']:.3f} ± {best_metrics['f1_score_std']:.3f}")
            print(f"   Overall Balanced Acc: {best_metrics['overall_balanced_acc']:.3f}")
            print(f"   Overall ROC AUC: {best_metrics['overall_roc_auc']:.3f}")

        print(f"\n📁 Results saved to: {self.output_folder}")

        summary = {
            'final_results': final_results,
            'best_classifier': best_name,
            'best_metrics': best_metrics,
            'samples_data': samples_data,
            'cv_results': cv_results,
            'all_predictions': all_predictions,
            'n_channels': n_channels,
            'n_patients': n_patients
        }
        return summary
    
    def _save_channel_level_results(self, final_results, best_name, normalization_method, cv_results):
        """保存channel-level结果"""
        
        # 保存分类结果
        if final_results:
            summary_data = []
            for clf_name, metrics in final_results.items():
                summary_data.append({
                    'Classifier': clf_name,
                    'CV_F1_Mean': metrics['f1_score_mean'],
                    'CV_F1_Std': metrics['f1_score_std'],
                    'Overall_F1': metrics['overall_f1'],
                    'Overall_Balanced_Acc': metrics['overall_balanced_acc'],
                    'Overall_ROC_AUC': metrics['overall_roc_auc'],
                    'Normalization_Method': normalization_method,
                    'Analysis_Level': 'channel'
                })
            
            summary_df = pd.DataFrame(summary_data)
            summary_df.to_csv(
                os.path.join(self.output_folder, f'channel_level_classification_summary_{normalization_method}.csv'), 
                index=False
            )
        
        # 保存patient-level详细结果
        if cv_results:
            patient_results = []
            for clf_name, folds in cv_results.items():
                for fold in folds:
                    patient_results.append({
                        'Classifier': clf_name,
                        'Patient': fold['test_patient'],
                        'Fold': fold['fold'],
                        'N_Channels': fold['n_test_channels'],
                        'N_Samples': fold['n_test_samples'],
                        'Accuracy': fold['accuracy'],
                        'F1_Score': fold['f1_score'],
                        'Precision': fold['precision'],
                        'Recall': fold['recall'],
                        'Balanced_Accuracy': fold['balanced_accuracy'],
                        'ROC_AUC': fold['roc_auc']
                    })
            
            patient_df = pd.DataFrame(patient_results)
            patient_df.to_csv(
                os.path.join(self.output_folder, f'patient_level_channel_results_{normalization_method}.csv'),
                index=False
            )
        
        print(f"  💾 Channel-level results saved")


def compare_windowing_strategies(processed_folder, output_base='comparison_results'):
    """
    Run the analysis once per windowing strategy and compare the best classifiers.

    Each strategy gets its own ``MultiPatientClassifier`` writing to
    ``f"{output_base}_{strategy_name}"``. The best classifier and its pooled
    F1 / balanced accuracy are collected per strategy, written to
    ``<output_base>/strategy_comparison.csv`` and printed.

    Parameters
    ----------
    processed_folder : str
        Folder handed to ``MultiPatientClassifier`` as the data source.
    output_base : str
        Prefix of the per-strategy output folders and folder of the comparison
        CSV (created if missing).

    Returns
    -------
    dict
        Strategy name -> dict with 'best_classifier', 'best_f1',
        'best_balanced_acc', 'n_features' and 'total_samples'. Strategies whose
        run failed or produced no classifier results are omitted.
    """

    # Strategy names are derived from the window size below so that labels and
    # output folders can never disagree with the parameters actually used.
    strategies = [
        {'use_windowing': True, 'window_size_ms': 10000, 'step_size_ms': 5000},
        {'use_windowing': True, 'window_size_ms': 5000, 'step_size_ms': 2500},
        {'use_windowing': True, 'window_size_ms': 2500, 'step_size_ms': 1250}
    ]

    comparison_results = {}

    for strategy in strategies:
        strategy_name = f"window_{strategy['window_size_ms']}ms"
        output_folder = f"{output_base}_{strategy_name}"

        print(f"\nStrategy Testing: {strategy_name}")
        print("=" * 40)

        classifier = MultiPatientClassifier(processed_folder, output_folder)

        # NOTE(review): only run_complete_channel_level_analysis is visible in
        # this part of the file. run_complete_analysis is still preferred when
        # the class defines it; otherwise fall back to the channel-level entry
        # point, which accepts the same windowing keywords. Confirm which one
        # is intended.
        run_analysis = getattr(classifier, 'run_complete_analysis', None)
        if run_analysis is None:
            run_analysis = classifier.run_complete_channel_level_analysis
        results = run_analysis(**strategy)

        if not results:
            continue

        # best_metrics is None when no classifier produced any fold
        best_metrics = results.get('best_metrics')
        if not best_metrics:
            print(f"  No classifier results for {strategy_name}; skipping")
            continue

        # 'data_summary' is not part of the channel-level result dict, so read
        # it defensively and fall back to the sample list for the sample count.
        data_summary = results.get('data_summary') or {}
        total_samples = data_summary.get('total_samples')
        if total_samples is None and results.get('samples_data') is not None:
            total_samples = len(results['samples_data'])

        comparison_results[strategy_name] = {
            'best_classifier': results['best_classifier'],
            'best_f1': best_metrics['overall_f1'],
            'best_balanced_acc': best_metrics['overall_balanced_acc'],
            'n_features': data_summary.get('n_features'),
            'total_samples': total_samples
        }

    # Comparison report
    if comparison_results:
        # orient='index' keeps one dtype per column (DataFrame(...).T would
        # turn every column into object dtype and break the numeric idxmax).
        comparison_df = pd.DataFrame.from_dict(comparison_results, orient='index')

        # The per-strategy folders use output_base only as a prefix, so the
        # folder for the comparison file has to be created explicitly.
        os.makedirs(output_base, exist_ok=True)
        comparison_df.to_csv(os.path.join(output_base, 'strategy_comparison.csv'))

        print(f"\n策略对比结果:")
        print(comparison_df.to_string())

        # Best strategy by pooled F1 (skipped when no F1 is a valid number)
        best_f1 = pd.to_numeric(comparison_df['best_f1'], errors='coerce')
        if best_f1.notna().any():
            best_strategy = best_f1.idxmax()
            print(f"\n🏆 最佳策略: {best_strategy}")
            print(f"   F1分数: {best_f1.loc[best_strategy]:.3f}")

    return comparison_results

In [None]:
# Input data folder and output folder for this run
processed_folder = r"D:\BlcRepo\LabCode\SeizureProp\data\gwbaseline_test"
output_folder = r"D:\BlcRepo\LabCode\SeizureProp\result\multi_patient_results_alt"

print("多Patient Grey/White Matter分类系统")
print("=" * 60)

# Create the classifier instance
classifier = MultiPatientClassifier(
    processed_folder=processed_folder,
    output_folder=output_folder
)

# Run the complete channel-level analysis
results = classifier.run_complete_channel_level_analysis(
    use_windowing=True,
    window_size_ms=10000,
    step_size_ms=5000,
    max_windows_per_channel=200,
    normalization_method='robust'
)

# The analysis returns None when loading, feature extraction or normalization
# fails, so only build the per-feature plots when a result dict is available.
if results is not None:
    classifier._create_patient_feature_visualization(samples_data=results['samples_data'])
else:
    print("❌ Analysis failed; skipping per-feature patient visualization")

In [None]:
# Data folder for the windowing-strategy comparison; output_folder is rebound
# here but not passed to the call below.
processed_folder = r"D:\BlcRepo\LabCode\SeizureProp\data\gwbaseline_alt"
output_folder = r"D:\BlcRepo\LabCode\SeizureProp\result\multi_patient_results"

# Run every windowing strategy and keep the per-strategy summary
results_2 = compare_windowing_strategies(
    processed_folder,
    output_base=r'D:\BlcRepo\LabCode\SeizureProp\result\gw_comparison_results',
)

In [None]:
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
warnings.filterwarnings('ignore')

class PatientFeatureVisualizer:
    """
    可视化每个患者的特征分布，用于分析为什么grey/white matter分类性能不好
    """
    
    def __init__(self, processed_folder, output_folder='feature_visualization'):
        """
        Store the folder configuration and prepare empty result holders.

        Parameters:
        -----------
        processed_folder : str
            Folder containing the preprocessed per-patient pkl files.
        output_folder : str
            Folder that receives the generated visualizations; it is
            created here if it does not exist yet.
        """
        # Result holders, filled in by load_and_extract_features().
        self.feature_names = None
        self.combined_features = None
        self.patients_data = {}

        self.output_folder = output_folder
        self.processed_folder = processed_folder

        # Make sure the output directory exists before anything is saved.
        os.makedirs(output_folder, exist_ok=True)

        # Colour palette: 20 evenly spaced samples of the Set3 colormap.
        # Plotting code indexes it modulo its length.
        self.patient_colors = plt.cm.Set3(np.linspace(0, 1, 20))

        banner = (
            "Initializing Patient Feature Visualizer",
            f"Processed folder: {processed_folder}",
            f"Output folder: {output_folder}",
        )
        for line in banner:
            print(line)
    
    def load_and_extract_features(self):
        """加载所有患者数据并提取特征"""
        
        pkl_files = glob.glob(os.path.join(self.processed_folder, "P*_processed.pkl"))
        print(f"\nFound {len(pkl_files)} patient files")
        
        all_patient_features = []
        
        for pkl_file in pkl_files:
            try:
                with open(pkl_file, 'rb') as f:
                    data = pickle.load(f)
                
                pid = data['patient_id']
                self.patients_data[pid] = data
                
                # 提取特征（重用之前的代码逻辑）
                patient_features = self._extract_patient_features(pid, data)
                
                if patient_features is not None:
                    all_patient_features.append(patient_features)
                    print(f"✓ {pid}: {patient_features['n_samples']} samples")
                else:
                    print(f"✗ {pid}: Feature extraction failed")
                    
            except Exception as e:
                print(f"✗ Failed to load {os.path.basename(pkl_file)}: {e}")
        
        # 合并所有患者的特征
        if all_patient_features:
            self._combine_all_features(all_patient_features)
            print(f"\nSuccessfully processed {len(all_patient_features)} patients")
            print(f"Total samples: {len(self.combined_features)}")
            print(f"Total features: {len(self.feature_names)}")
        else:
            print("No valid patient features extracted")
        
        return len(all_patient_features) > 0
    
    def _extract_patient_features(self, patient_id, patient_data):
        """为单个患者提取特征（从MultiPatientClassifier适配）"""
        
        matter_data = patient_data['matter_data']
        recordings = patient_data['recordings']
        
        # 提取电极分类
        try:
            grey_indices, white_indices, classification_info = self._extract_electrode_classification(matter_data)
        except Exception as e:
            print(f"  ✗ {patient_id} electrode classification failed: {e}")
            return None
        
        if len(grey_indices) == 0 or len(white_indices) == 0:
            print(f"  ✗ {patient_id} missing grey or white matter electrodes")
            return None
        
        # 合并所有recordings的数据
        all_grey_data = []
        all_white_data = []
        fs = recordings[0]['sampling_rate'] if recordings else 512
        
        for recording in recordings:
            neural_data = recording['neural_data_processed']
            grey_data = neural_data[:, grey_indices]
            white_data = neural_data[:, white_indices]
            all_grey_data.append(grey_data)
            all_white_data.append(white_data)
        
        combined_grey = np.vstack(all_grey_data) if all_grey_data else np.array([])
        combined_white = np.vstack(all_white_data) if all_white_data else np.array([])
        
        # 创建窗口样本
        all_samples = []
        
        # Grey matter samples
        grey_samples = self._create_windowed_samples(
            combined_grey, grey_indices, 1, patient_id, fs=fs
        )
        
        # White matter samples  
        white_samples = self._create_windowed_samples(
            combined_white, white_indices, 0, patient_id, fs=fs
        )
        
        all_samples = grey_samples + white_samples
        
        if len(all_samples) == 0:
            return None
        
        # 转换为DataFrame
        features_df = pd.DataFrame(all_samples)
        
        # 分离特征和标签
        meta_columns = ['patient_id', 'electrode_idx', 'channel_idx', 'window_start', 'label']
        feature_columns = [col for col in features_df.columns if col not in meta_columns]
        
        # 数据清理
        features_df[feature_columns] = features_df[feature_columns].fillna(0)
        features_df[feature_columns] = features_df[feature_columns].replace([np.inf, -np.inf], 0)
        
        return {
            'patient_id': patient_id,
            'features': features_df[feature_columns],
            'labels': features_df['label'].values,
            'meta': features_df[meta_columns],
            'n_samples': len(features_df),
            'n_grey_samples': np.sum(features_df['label'] == 1),
            'n_white_samples': np.sum(features_df['label'] == 0)
        }
    
    def _extract_electrode_classification(self, matter_data):
        """提取电极分类（从MultiPatientClassifier适配）"""
        
        matter_columns = ['MatterType', 'matter', 'Matter', 'mattertype', 'tissue_type', 'type']
        matter_col = None
        
        for col in matter_columns:
            if col in matter_data.columns:
                matter_col = col
                break
        
        if matter_col is None:
            raise ValueError(f"Matter type column not found. Available: {matter_data.columns.tolist()}")
        
        matter_values = matter_data[matter_col].astype(str).str.lower()
        
        grey_mask = matter_values.isin(['G', 'g', 'Grey', 'grey', 'Gray', 'gray'])
        white_mask = matter_values.isin(['W', 'w', 'White', 'white'])
        
        # 如果G/W格式没找到，尝试包含匹配
        if np.sum(grey_mask) == 0 or np.sum(white_mask) == 0:
            matter_values_lower = matter_values.str.lower()
            
            if np.sum(grey_mask) == 0:
                grey_patterns = ['grey', 'gray', 'cortex', 'cortical']
                grey_mask = matter_values_lower.str.contains('|'.join(grey_patterns), na=False, case=False)
            
            if np.sum(white_mask) == 0:
                white_patterns = ['white']
                white_mask = matter_values_lower.str.contains('|'.join(white_patterns), na=False, case=False)
        
        grey_indices = matter_data.index[grey_mask].tolist()
        white_indices = matter_data.index[white_mask].tolist()
        
        classification_info = {
            'matter_column': matter_col,
            'total_electrodes': len(matter_data),
            'grey_electrodes': len(grey_indices),
            'white_electrodes': len(white_indices),
            'grey_indices': grey_indices,
            'white_indices': white_indices
        }
        
        return grey_indices, white_indices, classification_info
    
    def _create_windowed_samples(self, data, electrode_indices, label, patient_id, 
                                window_size_ms=50000, step_size_ms=25000, max_windows_per_channel=200, fs=512):
        """创建时间窗口样本"""
        
        samples = []
        window_samples = int(window_size_ms * fs / 1000)
        step_samples = int(step_size_ms * fs / 1000)
        
        n_time, n_channels = data.shape
        
        for ch_idx, electrode_idx in enumerate(electrode_indices):
            channel_data = data[:, ch_idx]
            channel_windows = []
            
            for start in range(0, n_time - window_samples + 1, step_samples):
                end = start + window_samples
                window_data = channel_data[start:end]
                
                features = self._extract_signal_features(window_data, fs)
                if features is not None:
                    features['patient_id'] = patient_id
                    features['electrode_idx'] = electrode_idx
                    features['channel_idx'] = ch_idx
                    features['window_start'] = start
                    features['label'] = label
                    
                    channel_windows.append(features)
                
                if len(channel_windows) >= max_windows_per_channel:
                    break
            
            # 随机采样限制窗口数
            if len(channel_windows) > max_windows_per_channel:
                indices = np.random.choice(len(channel_windows), max_windows_per_channel, replace=False)
                channel_windows = [channel_windows[i] for i in sorted(indices)]
            
            samples.extend(channel_windows)
        
        return samples
    
    def _extract_signal_features(self, signal, fs):
        """提取单个信号的特征（从MultiPatientClassifier适配）"""
        
        if len(signal) < 10 or np.all(signal == 0):
            return None
        
        features = {}
        
        # 时域特征
        features['mean'] = np.mean(signal)
        features['std'] = np.std(signal)
        features['var'] = np.var(signal)
        features['median'] = np.median(signal)
        features['mad'] = np.median(np.abs(signal - np.median(signal)))
        features['range'] = np.ptp(signal)
        features['iqr'] = np.percentile(signal, 75) - np.percentile(signal, 25)
        features['rms'] = np.sqrt(np.mean(signal**2))
        
        # 统计矩
        try:
            features['skewness'] = stats.skew(signal)
            features['kurtosis'] = stats.kurtosis(signal)
        except:
            features['skewness'] = 0
            features['kurtosis'] = 0
        
        # 信号复杂度
        features['zero_crossings'] = np.sum(np.diff(np.signbit(signal)))
        features['line_length'] = np.sum(np.abs(np.diff(signal)))
        features['area'] = np.sum(np.abs(signal))
        features['energy'] = np.sum(signal**2)
        
        # 频域特征（简化版本，避免错误）
        try:
            from scipy.fftpack import fft
            n_fft = len(signal)
            windowed_signal = signal * np.hamming(n_fft)
            fft_vals = fft(windowed_signal)
            fft_mag = np.abs(fft_vals[:n_fft//2])
            freqs = np.fft.fftfreq(n_fft, 1/fs)[:n_fft//2]
            
            # 频带功率
            bands = {
                'delta': (0.5, 4),
                'theta': (4, 8),
                'alpha': (8, 13),
                'beta': (13, 30),
                'gamma': (30, 100),
                'high_gamma': (100, min(200, fs/2))
            }
            
            total_power = np.sum(fft_mag**2)
            for band_name, (low, high) in bands.items():
                band_mask = (freqs >= low) & (freqs <= high)
                if np.any(band_mask):
                    band_power = np.sum(fft_mag[band_mask]**2)
                    features[f'power_{band_name}'] = band_power
                    features[f'rel_power_{band_name}'] = band_power / total_power if total_power > 0 else 0
                else:
                    features[f'power_{band_name}'] = 0
                    features[f'rel_power_{band_name}'] = 0
            
            features['total_power'] = total_power
            
        except Exception as e:
            # 如果频域分析失败，设置默认值
            bands = ['delta', 'theta', 'alpha', 'beta', 'gamma', 'high_gamma']
            for band in bands:
                features[f'power_{band}'] = 0
                features[f'rel_power_{band}'] = 0
            features['total_power'] = 0
        
        return features
    
    def _combine_all_features(self, all_patient_features):
        """合并所有患者的特征"""
        
        # 找到共同的特征列
        all_feature_columns = []
        for patient_data in all_patient_features:
            all_feature_columns.append(set(patient_data['features'].columns))
        
        common_features = set.intersection(*all_feature_columns)
        self.feature_names = list(common_features)
        
        # 合并数据
        combined_data = []
        for patient_data in all_patient_features:
            patient_id = patient_data['patient_id']
            features = patient_data['features'][self.feature_names]
            labels = patient_data['labels']
            
            # 添加患者ID到每一行
            for i in range(len(features)):
                row = {
                    'patient_id': patient_id,
                    'label': labels[i],
                    'matter_type': 'Grey' if labels[i] == 1 else 'White'
                }
                # 添加特征值
                for feature_name in self.feature_names:
                    row[feature_name] = features.iloc[i][feature_name]
                
                combined_data.append(row)
        
        self.combined_features = pd.DataFrame(combined_data)
        
        # 数据清理
        feature_cols = [col for col in self.combined_features.columns 
                       if col not in ['patient_id', 'label', 'matter_type']]
        self.combined_features[feature_cols] = self.combined_features[feature_cols].fillna(0)
        self.combined_features[feature_cols] = self.combined_features[feature_cols].replace([np.inf, -np.inf], 0)
    
    def create_individual_feature_plots(self, n_features_per_page=9):
        """为每个特征创建患者分布的散点图"""
        
        if self.combined_features is None:
            print("No features loaded. Run load_and_extract_features() first.")
            return
        
        print(f"\nCreating individual feature distribution plots...")
        
        # 获取患者列表和颜色映射
        patients = self.combined_features['patient_id'].unique()
        patient_color_map = {patient: self.patient_colors[i % len(self.patient_colors)] 
                           for i, patient in enumerate(patients)}
        
        # 分页处理特征
        n_pages = (len(self.feature_names) + n_features_per_page - 1) // n_features_per_page
        
        for page in range(n_pages):
            start_idx = page * n_features_per_page
            end_idx = min(start_idx + n_features_per_page, len(self.feature_names))
            page_features = self.feature_names[start_idx:end_idx]
            
            # 计算子图布局
            n_features_in_page = len(page_features)
            n_cols = 3
            n_rows = (n_features_in_page + n_cols - 1) // n_cols
            
            fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 5*n_rows))
            if n_rows == 1:
                axes = axes.reshape(1, -1)
            if n_features_in_page == 1:
                axes = np.array([[axes]])
            
            fig.suptitle(f'Feature Distribution by Patient - Page {page+1}/{n_pages}', 
                        fontsize=16, y=0.98)
            
            for i, feature_name in enumerate(page_features):
                row = i // n_cols
                col = i % n_cols
                ax = axes[row, col]
                
                # 为每个患者创建散点图
                for j, patient in enumerate(patients):
                    patient_data = self.combined_features[self.combined_features['patient_id'] == patient]
                    
                    # Grey matter points
                    grey_data = patient_data[patient_data['matter_type'] == 'Grey']
                    if len(grey_data) > 0:
                        y_grey = np.random.normal(1, 0.1, len(grey_data))  # 添加垂直抖动
                        ax.scatter(grey_data[feature_name], y_grey, 
                                 c=[patient_color_map[patient]], alpha=0.6, s=20,
                                 label=f'{patient}_Grey' if j == 0 else "")
                    
                    # White matter points
                    white_data = patient_data[patient_data['matter_type'] == 'White']
                    if len(white_data) > 0:
                        y_white = np.random.normal(0, 0.1, len(white_data))  # 添加垂直抖动
                        ax.scatter(white_data[feature_name], y_white, 
                                 c=[patient_color_map[patient]], alpha=0.6, s=20, 
                                 marker='s', label=f'{patient}_White' if j == 0 else "")
                
                ax.set_xlabel(feature_name)
                ax.set_ylabel('Matter Type')
                ax.set_yticks([0, 1])
                ax.set_yticklabels(['White', 'Grey'])
                ax.grid(True, alpha=0.3)
                ax.set_title(f'{feature_name}', fontsize=10)
                
                # 添加箱线图层覆盖显示分布
                grey_values = self.combined_features[self.combined_features['matter_type'] == 'Grey'][feature_name]
                white_values = self.combined_features[self.combined_features['matter_type'] == 'White'][feature_name]
                
                if len(grey_values) > 0 and len(white_values) > 0:
                    # 计算简单的分布统计
                    grey_median = np.median(grey_values)
                    white_median = np.median(white_values)
                    
                    ax.axvline(grey_median, ymax=1.25, color='red', linestyle='--', alpha=0.7, linewidth=2)
                    ax.axvline(white_median, ymax=0.25, color='blue', linestyle='--', alpha=0.7, linewidth=2)
            
            # 隐藏空的子图
            for i in range(n_features_in_page, n_rows * n_cols):
                row = i // n_cols
                col = i % n_cols
                axes[row, col].set_visible(False)
            
            # 添加患者颜色图例
            handles = []
            labels = []
            for patient in patients:
                handles.append(plt.Line2D([0], [0], marker='o', color='w', 
                                        markerfacecolor=patient_color_map[patient], markersize=8))
                labels.append(patient)
            
            fig.legend(handles, labels, loc='center', bbox_to_anchor=(0.5, 0.02), 
                      ncol=min(6, len(patients)), fontsize=10)
            
            plt.tight_layout()
            plt.subplots_adjust(bottom=0.15, top=0.92)
            plt.savefig(os.path.join(self.output_folder, f'feature_distributions_page_{page+1}.png'), 
                       dpi=300, bbox_inches='tight')
            plt.close()
            
            print(f"  ✓ Page {page+1}/{n_pages} saved")
    
    def create_feature_summary_statistics(self):
        """创建特征统计汇总表"""
        
        if self.combined_features is None:
            print("No features loaded. Run load_and_extract_features() first.")
            return
        
        print(f"\nCreating feature summary statistics...")
        
        summary_stats = []
        
        for feature_name in self.feature_names:
            # 整体统计
            overall_mean = self.combined_features[feature_name].mean()
            overall_std = self.combined_features[feature_name].std()
            
            # 按matter type分组统计
            grey_data = self.combined_features[self.combined_features['matter_type'] == 'Grey'][feature_name]
            white_data = self.combined_features[self.combined_features['matter_type'] == 'White'][feature_name]
            
            grey_mean = grey_data.mean()
            grey_std = grey_data.std()
            white_mean = white_data.mean()
            white_std = white_data.std()
            
            # 统计检验
            try:
                t_stat, p_value = stats.ttest_ind(grey_data, white_data)
            except:
                t_stat, p_value = 0, 1
            
            # 效应量 (Cohen's d)
            pooled_std = np.sqrt(((len(grey_data) - 1) * grey_std**2 + 
                                 (len(white_data) - 1) * white_std**2) / 
                                (len(grey_data) + len(white_data) - 2))
            cohens_d = (grey_mean - white_mean) / pooled_std if pooled_std > 0 else 0
            
            # 按患者分组的变异系数
            patient_means = []
            for patient in self.combined_features['patient_id'].unique():
                patient_data = self.combined_features[self.combined_features['patient_id'] == patient][feature_name]
                if len(patient_data) > 0:
                    patient_means.append(patient_data.mean())
            
            between_patient_cv = np.std(patient_means) / np.mean(patient_means) if len(patient_means) > 0 and np.mean(patient_means) != 0 else 0
            
            summary_stats.append({
                'feature_name': feature_name,
                'overall_mean': overall_mean,
                'overall_std': overall_std,
                'grey_mean': grey_mean,
                'grey_std': grey_std,
                'white_mean': white_mean,
                'white_std': white_std,
                'mean_difference': grey_mean - white_mean,
                'cohens_d': cohens_d,
                't_statistic': t_stat,
                'p_value': p_value,
                'significant': p_value < 0.05,
                'between_patient_cv': between_patient_cv
            })
        
        summary_df = pd.DataFrame(summary_stats)
        summary_df = summary_df.sort_values('p_value')  # 按p值排序
        
        # 保存统计结果
        summary_df.to_csv(os.path.join(self.output_folder, 'feature_statistics.csv'), index=False)
        
        # 创建统计可视化
        self._plot_feature_statistics(summary_df)
        
        print(f"  ✓ Feature statistics saved")
        return summary_df
    
    def _plot_feature_statistics(self, summary_df):
        """
        Draw a 2x2 overview figure of the per-feature statistics and save it
        as 'feature_statistics_plots.png' in the output folder.

        Panels: histogram of Cohen's d, histogram of p-values, histogram of
        the between-patient coefficient of variation, and a scatter of
        effect size against p-value with a log-scaled y axis.
        """
        fig, axes = plt.subplots(2, 2, figsize=(15, 10))
        hist_style = dict(bins=20, alpha=0.7, edgecolor='black')

        # Panel 1: effect size (Cohen's d) distribution
        ax_effect = axes[0, 0]
        ax_effect.hist(summary_df['cohens_d'], **hist_style)
        ax_effect.axvline(0, color='red', linestyle='--', alpha=0.7)
        ax_effect.set_xlabel("Cohen's d")
        ax_effect.set_ylabel('Frequency')
        ax_effect.set_title('Effect Size Distribution')
        ax_effect.grid(True, alpha=0.3)

        # Panel 2: p-value distribution with the 0.05 threshold
        ax_pval = axes[0, 1]
        ax_pval.hist(summary_df['p_value'], **hist_style)
        ax_pval.axvline(0.05, color='red', linestyle='--', alpha=0.7, label='p=0.05')
        ax_pval.set_xlabel('P-value')
        ax_pval.set_ylabel('Frequency')
        ax_pval.set_title('P-value Distribution')
        ax_pval.legend()
        ax_pval.grid(True, alpha=0.3)

        # Panel 3: between-patient coefficient of variation
        ax_cv = axes[1, 0]
        ax_cv.hist(summary_df['between_patient_cv'], **hist_style)
        ax_cv.set_xlabel('Between-Patient CV')
        ax_cv.set_ylabel('Frequency')
        ax_cv.set_title('Between-Patient Variability')
        ax_cv.grid(True, alpha=0.3)

        # Panel 4: effect size against p-value, split by significance
        ax_scatter = axes[1, 1]
        is_significant = summary_df['significant']
        groups = (
            (~is_significant, dict(alpha=0.6, label='Non-significant', color='gray')),
            (is_significant, dict(alpha=0.8, label='Significant', color='red')),
        )
        for selector, style in groups:
            ax_scatter.scatter(summary_df.loc[selector, 'cohens_d'],
                               summary_df.loc[selector, 'p_value'],
                               **style)
        ax_scatter.axhline(0.05, color='red', linestyle='--', alpha=0.7)
        ax_scatter.axvline(0, color='gray', linestyle='-', alpha=0.5)
        ax_scatter.set_xlabel("Cohen's d")
        ax_scatter.set_ylabel('P-value')
        ax_scatter.set_title('Effect Size vs Statistical Significance')
        ax_scatter.legend()
        ax_scatter.grid(True, alpha=0.3)
        ax_scatter.set_yscale('log')

        plt.tight_layout()
        target = os.path.join(self.output_folder, 'feature_statistics_plots.png')
        plt.savefig(target, dpi=300, bbox_inches='tight')
        plt.close()
    
    def create_patient_comparison_plots(self):
        """创建患者间比较图"""
        
        if self.combined_features is None:
            print("No features loaded. Run load_and_extract_features() first.")
            return
        
        print(f"\nCreating patient comparison plots...")
        
        # 计算每个患者的特征均值
        patient_means = []
        patients = self.combined_features['patient_id'].unique()
        
        for patient in patients:
            patient_data = self.combined_features[self.combined_features['patient_id'] == patient]
            
            # 分别计算grey和white matter的均值
            grey_means = patient_data[patient_data['matter_type'] == 'Grey'][self.feature_names].mean()
            white_means = patient_data[patient_data['matter_type'] == 'White'][self.feature_names].mean()
            
            patient_means.append({
                'patient_id': patient,
                'matter_type': 'Grey',
                **grey_means.to_dict()
            })
            
            patient_means.append({
                'patient_id': patient,
                'matter_type': 'White', 
                **white_means.to_dict()
            })
        
        patient_means_df = pd.DataFrame(patient_means)
        
        # 1. 患者相似性热图
        self._create_patient_similarity_heatmap(patient_means_df)
        
        # 2. PCA分析
        self._create_pca_analysis()
        
        # 3. 最具区分性的特征
        self._create_discriminative_features_plot()
        
        print(f"  ✓ Patient comparison plots saved")
    
    def _create_patient_similarity_heatmap(self, patient_means_df):
        """创建患者相似性热图"""
        
        # 只使用grey matter数据进行患者相似性分析
        grey_means = patient_means_df[patient_means_df['matter_type'] == 'Grey']
        
        # 计算患者间的相关性
        patients = grey_means['patient_id'].values
        feature_matrix = grey_means[self.feature_names].values
        
        # 标准化
        scaler = StandardScaler()
        feature_matrix_scaled = scaler.fit_transform(feature_matrix)
        
        # 计算相关性矩阵
        correlation_matrix = np.corrcoef(feature_matrix_scaled)
        
        # 创建热图
        plt.figure(figsize=(10, 8))
        mask = np.triu(np.ones_like(correlation_matrix, dtype=bool))
        sns.heatmap(correlation_matrix, mask=mask, annot=True, cmap='coolwarm', center=0,
                   xticklabels=patients, yticklabels=patients, fmt='.2f')
        plt.title('Patient Similarity (Grey Matter Features)')
        plt.tight_layout()
        plt.savefig(os.path.join(self.output_folder, 'patient_similarity_heatmap.png'), 
                   dpi=300, bbox_inches='tight')
        plt.close()
    
    def _create_pca_analysis(self):
        """创建PCA分析图"""
        
        # 准备数据
        feature_data = self.combined_features[self.feature_names]
        labels = self.combined_features['matter_type']
        patients = self.combined_features['patient_id']
        
        # 标准化
        scaler = StandardScaler()
        feature_data_scaled = scaler.fit_transform(feature_data)
        
        # PCA
        pca = PCA(n_components=10)
        pca_result = pca.fit_transform(feature_data_scaled)
        
        # 创建图表
        fig, axes = plt.subplots(1, 2, figsize=(15, 6))
        
        # 1. 按matter type着色
        unique_patients = patients.unique()
        patient_color_map = {patient: self.patient_colors[i % len(self.patient_colors)] 
                           for i, patient in enumerate(unique_patients)}
        
        for matter_type in ['Grey', 'White']:
            mask = labels == matter_type
            marker = 'o' if matter_type == 'Grey' else 's'
            axes[0].scatter(pca_result[mask, 0], pca_result[mask, 1], 
                          alpha=0.6, label=matter_type, s=30, marker=marker)
        
        axes[0].set_xlabel(f'PC1 ({pca.explained_variance_ratio_[0]:.2%} variance)')
        axes[0].set_ylabel(f'PC2 ({pca.explained_variance_ratio_[1]:.2%} variance)')
        axes[0].set_title('PCA: Grey vs White Matter')
        axes[0].legend()
        axes[0].grid(True, alpha=0.3)
        
        # 2. 按患者着色
        for patient in unique_patients:
            mask = patients == patient
            if np.any(mask):
                axes[1].scatter(pca_result[mask, 0], pca_result[mask, 1], 
                              c=[patient_color_map[patient]], alpha=0.6, label=patient, s=30)
        
        axes[1].set_xlabel(f'PC1 ({pca.explained_variance_ratio_[0]:.2%} variance)')
        axes[1].set_ylabel(f'PC2 ({pca.explained_variance_ratio_[1]:.2%} variance)')
        axes[1].set_title('PCA: By Patient')
        axes[1].legend(bbox_to_anchor=(1.05, 1), loc='upper left')
        axes[1].grid(True, alpha=0.3)
        
        plt.tight_layout()
        plt.savefig(os.path.join(self.output_folder, 'pca_analysis.png'), 
                   dpi=300, bbox_inches='tight')
        plt.close()
        
        # 保存主成分信息
        component_df = pd.DataFrame({
            'feature': self.feature_names,
            'PC1': pca.components_[0],
            'PC2': pca.components_[1]
        })
        component_df = component_df.reindex(component_df['PC1'].abs().sort_values(ascending=False).index)
        component_df.to_csv(os.path.join(self.output_folder, 'pca_components.csv'), index=False)
    
    def _create_discriminative_features_plot(self):
        """创建最具区分性特征的可视化"""
        
        # 读取之前计算的特征统计
        stats_file = os.path.join(self.output_folder, 'feature_statistics.csv')
        if os.path.exists(stats_file):
            stats_df = pd.read_csv(stats_file)
        else:
            stats_df = self.create_feature_summary_statistics()
        
        # 选择最具区分性的特征（按Cohen's d排序）
        top_features = stats_df.nlargest(12, 'cohens_d')['feature_name'].tolist()
        
        # 创建箱线图
        fig, axes = plt.subplots(3, 4, figsize=(16, 12))
        axes = axes.flatten()
        
        for i, feature in enumerate(top_features):
            ax = axes[i]
            
            # 准备数据
            grey_data = self.combined_features[self.combined_features['matter_type'] == 'Grey'][feature]
            white_data = self.combined_features[self.combined_features['matter_type'] == 'White'][feature]
            
            # 箱线图
            box_data = [white_data, grey_data]
            bp = ax.boxplot(box_data, labels=['White', 'Grey'], patch_artist=True)
            bp['boxes'][0].set_facecolor('lightblue')
            bp['boxes'][1].set_facecolor('lightcoral')
            
            # 添加统计信息
            stat_info = stats_df[stats_df['feature_name'] == feature].iloc[0]
            ax.set_title(f'{feature}\nCohen\'s d: {stat_info["cohens_d"]:.3f}\np: {stat_info["p_value"]:.3e}',
                        fontsize=10)
            ax.grid(True, alpha=0.3)
        
        # 隐藏空的子图
        for i in range(len(top_features), len(axes)):
            axes[i].set_visible(False)
        
        plt.suptitle('Most Discriminative Features (Top 12 by Cohen\'s d)', fontsize=14)
        plt.tight_layout()
        plt.savefig(os.path.join(self.output_folder, 'discriminative_features.png'), 
                   dpi=300, bbox_inches='tight')
        plt.close()
    
    def create_patient_specific_analysis(self):
        """Quantify grey/white separability for every patient and plot a summary.

        For each patient in ``self.combined_features`` the Grey and White rows are
        compared feature by feature with an independent t-test and an absolute
        Cohen's d (pooled standard deviation). A patient needs more than 5 samples
        of each matter type; otherwise every feature scores 0 for that patient.

        Side effects:
            Writes ``patient_analysis.csv`` and ``patient_specific_analysis.png``
            into ``self.output_folder``.

        Returns:
            pandas.DataFrame | None: one row per patient with the columns
            ``patient_id``, ``n_grey_samples``, ``n_white_samples``,
            ``mean_separability``, ``max_separability``, ``significant_features``
            and ``separability_std``; ``None`` when no features are loaded.
        """

        if self.combined_features is None:
            print("No features loaded. Run load_and_extract_features() first.")
            return None

        print(f"\nCreating patient-specific analysis...")

        columns = ['patient_id', 'n_grey_samples', 'n_white_samples', 'mean_separability',
                   'max_separability', 'significant_features', 'separability_std']
        patients = self.combined_features['patient_id'].unique()

        # Per-patient separability statistics.
        patient_analysis = []

        for patient in patients:
            patient_data = self.combined_features[self.combined_features['patient_id'] == patient]

            grey_data = patient_data[patient_data['matter_type'] == 'Grey'][self.feature_names]
            white_data = patient_data[patient_data['matter_type'] == 'White'][self.feature_names]
            n_grey, n_white = len(grey_data), len(white_data)

            # The sample-size requirement does not depend on the feature, so it is
            # evaluated once per patient instead of once per feature.
            enough_samples = n_grey > 5 and n_white > 5

            separability_scores = []
            significant_features = 0

            for feature in self.feature_names:
                if not enough_samples:
                    separability_scores.append(0)
                    continue

                try:
                    grey_values = grey_data[feature]
                    white_values = white_data[feature]

                    # Independent two-sample t-test.
                    _, p_val = stats.ttest_ind(grey_values, white_values)

                    # Absolute Cohen's d using the pooled standard deviation.
                    pooled_std = np.sqrt(((n_grey - 1) * grey_values.std()**2 +
                                          (n_white - 1) * white_values.std()**2) /
                                         (n_grey + n_white - 2))
                    cohens_d = abs((grey_values.mean() - white_values.mean()) / pooled_std) if pooled_std > 0 else 0
                except Exception as exc:
                    # Best effort: a feature that cannot be tested scores 0, but the
                    # failure is reported instead of being silently swallowed (and
                    # KeyboardInterrupt/SystemExit are no longer caught).
                    print(f"  Warning: separability test failed for {patient}/{feature}: {exc}")
                    separability_scores.append(0)
                    continue

                separability_scores.append(cohens_d)
                if p_val < 0.05:
                    significant_features += 1

            # np.max raises on an empty sequence (no feature names); report zeros instead.
            has_scores = len(separability_scores) > 0
            patient_analysis.append({
                'patient_id': patient,
                'n_grey_samples': n_grey,
                'n_white_samples': n_white,
                'mean_separability': np.mean(separability_scores) if has_scores else 0.0,
                'max_separability': np.max(separability_scores) if has_scores else 0.0,
                'significant_features': significant_features,
                'separability_std': np.std(separability_scores) if has_scores else 0.0
            })

        # Passing the column list keeps the frame well-formed even with zero patients.
        patient_analysis_df = pd.DataFrame(patient_analysis, columns=columns)
        patient_analysis_df.to_csv(os.path.join(self.output_folder, 'patient_analysis.csv'), index=False)

        # Visualize the per-patient results.
        fig, axes = plt.subplots(2, 2, figsize=(12, 10))

        # 1 + 2. Mean separability and number of significant features per patient.
        bar_specs = [
            (axes[0, 0], 'mean_separability', 'Mean Separability (Cohen\'s d)',
             'Average Feature Separability by Patient'),
            (axes[0, 1], 'significant_features', 'Number of Significant Features',
             'Significant Features by Patient (p<0.05)'),
        ]
        for ax, column, ylabel, title in bar_specs:
            ax.bar(patient_analysis_df['patient_id'], patient_analysis_df[column])
            ax.set_xlabel('Patient ID')
            ax.set_ylabel(ylabel)
            ax.set_title(title)
            ax.tick_params(axis='x', rotation=45)
            ax.grid(True, alpha=0.3)

        # 3. Sample counts (grouped bars).
        x = np.arange(len(patient_analysis_df))
        width = 0.35
        axes[1, 0].bar(x - width/2, patient_analysis_df['n_grey_samples'], width, label='Grey', alpha=0.8)
        axes[1, 0].bar(x + width/2, patient_analysis_df['n_white_samples'], width, label='White', alpha=0.8)
        axes[1, 0].set_xlabel('Patient ID')
        axes[1, 0].set_ylabel('Sample Count')
        axes[1, 0].set_title('Sample Distribution by Patient')
        axes[1, 0].set_xticks(x)
        axes[1, 0].set_xticklabels(patient_analysis_df['patient_id'], rotation=45)
        axes[1, 0].legend()
        axes[1, 0].grid(True, alpha=0.3)

        # 4. Separability versus total sample count.
        total_samples = patient_analysis_df['n_grey_samples'] + patient_analysis_df['n_white_samples']
        axes[1, 1].scatter(total_samples, patient_analysis_df['mean_separability'], s=60, alpha=0.7)
        for i, patient in enumerate(patient_analysis_df['patient_id']):
            axes[1, 1].annotate(patient, (total_samples.iloc[i], patient_analysis_df['mean_separability'].iloc[i]),
                              xytext=(5, 5), textcoords='offset points', fontsize=8)
        axes[1, 1].set_xlabel('Total Samples')
        axes[1, 1].set_ylabel('Mean Separability')
        axes[1, 1].set_title('Separability vs Sample Size')
        axes[1, 1].grid(True, alpha=0.3)

        plt.tight_layout()
        plt.savefig(os.path.join(self.output_folder, 'patient_specific_analysis.png'),
                   dpi=300, bbox_inches='tight')
        plt.close()

        print(f"  ✓ Patient-specific analysis saved")
        return patient_analysis_df
    
    def run_complete_visualization(self):
        """Run every visualization step in order and summarise the run.

        The steps are: load/extract features, per-feature distribution plots,
        feature summary statistics, patient comparison plots, patient-specific
        analysis and finally the text summary report.

        Returns:
            dict | None: ``None`` when feature loading fails; otherwise a dict
            with the keys ``feature_stats``, ``patient_analysis``, ``n_patients``,
            ``n_features`` and ``total_samples``.
        """
        rule = "=" * 60

        print("Starting Complete Patient Feature Visualization")
        print(rule)

        # Step 1: nothing else can run without the feature table.
        loaded = self.load_and_extract_features()
        if not loaded:
            print("❌ Failed to load and extract features")
            return None

        # Steps 2-6, in the order the outputs are listed below.
        self.create_individual_feature_plots()
        stats_df = self.create_feature_summary_statistics()
        self.create_patient_comparison_plots()
        patient_analysis = self.create_patient_specific_analysis()
        self._create_summary_report(stats_df, patient_analysis)

        closing_lines = (
            "\n" + rule,
            "Visualization Analysis Complete",
            rule,
            f"📁 Results saved to: {self.output_folder}",
            "   - feature_distributions_page_*.png: Individual feature distributions",
            "   - feature_statistics.csv: Feature statistics summary",
            "   - feature_statistics_plots.png: Statistical analysis plots",
            "   - patient_similarity_heatmap.png: Patient similarity analysis",
            "   - pca_analysis.png: PCA analysis",
            "   - discriminative_features.png: Most discriminative features",
            "   - patient_specific_analysis.png: Patient-specific analysis",
            "   - summary_report.txt: Text summary report",
        )
        for line in closing_lines:
            print(line)

        features = self.combined_features
        summary = {
            'feature_stats': stats_df,
            'patient_analysis': patient_analysis,
            'n_patients': len(features['patient_id'].unique()),
            'n_features': len(self.feature_names),
            'total_samples': len(features),
        }
        return summary
    
    def _create_summary_report(self, stats_df, patient_analysis):
        """Write the plain-text summary report to ``summary_report.txt``.

        Args:
            stats_df: per-feature statistics; the columns ``feature_name``,
                ``cohens_d``, ``p_value``, ``significant`` and
                ``between_patient_cv`` are read.
            patient_analysis: per-patient table; the columns ``patient_id``,
                ``mean_separability`` and ``significant_features`` are read.
                ``None`` or an empty table is tolerated: the patient section
                then only states that no analysis is available.

        Side effects:
            Creates/overwrites ``summary_report.txt`` in ``self.output_folder``.
        """

        report_path = os.path.join(self.output_folder, 'summary_report.txt')
        combined = self.combined_features
        has_patients = patient_analysis is not None and len(patient_analysis) > 0

        # The report contains non-ASCII symbols ("±", "⚠️"). The platform default
        # encoding (e.g. a Windows locale code page) may not be able to encode
        # them, so the encoding is fixed to UTF-8.
        with open(report_path, 'w', encoding='utf-8') as f:
            f.write("PATIENT FEATURE VISUALIZATION ANALYSIS REPORT\n")
            f.write("=" * 60 + "\n\n")

            # Basic information.
            f.write("BASIC INFORMATION:\n")
            f.write("-" * 20 + "\n")
            f.write(f"Number of patients: {len(combined['patient_id'].unique())}\n")
            f.write(f"Number of features: {len(self.feature_names)}\n")
            f.write(f"Total samples: {len(combined)}\n")

            grey_samples = len(combined[combined['matter_type'] == 'Grey'])
            white_samples = len(combined[combined['matter_type'] == 'White'])
            f.write(f"Grey matter samples: {grey_samples}\n")
            f.write(f"White matter samples: {white_samples}\n\n")

            # Feature analysis.
            f.write("FEATURE ANALYSIS:\n")
            f.write("-" * 20 + "\n")
            significant_features = len(stats_df[stats_df['significant'] == True])
            f.write(f"Statistically significant features (p<0.05): {significant_features}/{len(stats_df)}\n")

            abs_effect = stats_df['cohens_d'].abs()
            high_effect_features = len(stats_df[abs_effect > 0.5])
            f.write(f"Features with large effect size (|Cohen's d| > 0.5): {high_effect_features}/{len(stats_df)}\n")

            f.write(f"\nTop 10 most discriminative features (by Cohen's d):\n")
            # Rank by effect magnitude: a strongly negative d is as discriminative
            # as a strongly positive one (the effect-size count above also uses |d|).
            top_10 = stats_df.assign(_abs_d=abs_effect).nlargest(10, '_abs_d')
            for i, (_, row) in enumerate(top_10.iterrows(), 1):
                f.write(f"  {i:2d}. {row['feature_name']:20s} (d={row['cohens_d']:6.3f}, p={row['p_value']:8.3e})\n")

            # Patient analysis.
            f.write(f"\nPATIENT ANALYSIS:\n")
            f.write("-" * 20 + "\n")
            if has_patients:
                separability = patient_analysis['mean_separability']
                best_id = patient_analysis.loc[separability.idxmax(), 'patient_id']
                worst_id = patient_analysis.loc[separability.idxmin(), 'patient_id']
                f.write(f"Average separability across patients: {separability.mean():.3f} ± {separability.std():.3f}\n")
                f.write(f"Patient with highest separability: {best_id} ({separability.max():.3f})\n")
                f.write(f"Patient with lowest separability: {worst_id} ({separability.min():.3f})\n")

                per_patient_significant = patient_analysis['significant_features']
                f.write(f"\nAverage significant features per patient: {per_patient_significant.mean():.1f} ± {per_patient_significant.std():.1f}\n")
            else:
                f.write("No patient-level analysis available\n")

            # Heuristic diagnostics.
            f.write(f"\nPOTENTIAL ISSUES IDENTIFIED:\n")
            f.write("-" * 30 + "\n")

            if significant_features < len(stats_df) * 0.1:
                f.write("⚠️  Very few features show statistical significance between grey/white matter\n")

            if has_patients:
                if separability.std() > separability.mean() * 0.5:
                    f.write("⚠️  High variability in separability across patients\n")

                low_sep_patients = len(patient_analysis[separability < 0.2])
                if low_sep_patients > len(patient_analysis) * 0.3:
                    f.write(f"⚠️  {low_sep_patients} patients show very low feature separability (<0.2)\n")

            high_cv_features = len(stats_df[stats_df['between_patient_cv'] > 1.0])
            if high_cv_features > len(stats_df) * 0.2:
                f.write(f"⚠️  {high_cv_features} features show high between-patient variability (CV>1.0)\n")

            f.write(f"\nRECOMMENDATIONS:\n")
            f.write("-" * 20 + "\n")
            f.write("1. Consider patient-specific normalization or feature selection\n")
            f.write("2. Investigate data preprocessing consistency across patients\n")
            f.write("3. Consider hierarchical models that account for patient variability\n")
            f.write("4. Focus on the most discriminative features for model training\n")
            f.write("5. Consider removing patients with very low separability\n")

# 使用示例函数
def analyze_patient_features(processed_folder, output_folder='feature_analysis_results'):
    """
    Convenience wrapper: build a visualizer and run the complete analysis.

    Parameters:
    -----------
    processed_folder : str
        Path of the folder that contains the patient data
    output_folder : str
        Path of the folder that receives the outputs

    Returns:
    --------
    dict: whatever ``run_complete_visualization()`` returns (the analysis results)
    """
    return PatientFeatureVisualizer(processed_folder, output_folder).run_complete_visualization()

In [None]:
# 患者特征可视化分析 - 使用示例
# 方法1: 使用便捷函数
def quick_analysis(processed_folder=r"D:\BlcRepo\LabCode\SeizureProp\data\gwbaseline",
                   output_folder=r"D:\BlcRepo\LabCode\SeizureProp\result\feature_visualization"):
    """Quick-start example: run the whole analysis and print the key findings.

    Args:
        processed_folder: folder with the patient data; defaults to the
            location the example originally hard-coded.
        output_folder: folder that receives the results; same default rule.

    Returns:
        None. Everything is reported on stdout.
    """

    print("🔍 Starting Patient Feature Visualization Analysis...")

    # Run the complete analysis.
    results = analyze_patient_features(processed_folder, output_folder)

    if not results:
        print("❌ Analysis failed")
        return

    print(f"\n✅ Analysis completed successfully!")
    print(f"📊 Key findings:")
    print(f"   - {results['n_patients']} patients analyzed")
    print(f"   - {results['n_features']} features extracted")
    print(f"   - {results['total_samples']} total samples")

    # A few key statistics.
    stats_df = results['feature_stats']
    significant_count = len(stats_df[stats_df['significant'] == True])
    print(f"   - {significant_count}/{len(stats_df)} features are statistically significant")

    # Most discriminative features, ranked by effect magnitude so that strongly
    # negative Cohen's d values are not ranked below weak positive ones.
    ranked = stats_df.assign(_abs_d=stats_df['cohens_d'].abs()).nlargest(3, '_abs_d')
    top_features = ranked['feature_name'].tolist()
    print(f"   - Top discriminative features: {', '.join(top_features)}...")

# 方法2: 详细的步骤控制
def detailed_analysis(processed_folder=r"D:\BlcRepo\LabCode\SeizureProp\data\gwbaseline_alt",
                      output_folder=r"D:\BlcRepo\LabCode\SeizureProp\result\detailed_feature_analysis",
                      n_features_per_page=12):
    """Step-by-step example that runs each analysis stage explicitly.

    Args:
        processed_folder: folder with the patient data; defaults to the
            location the example originally hard-coded.
        output_folder: folder that receives the results; same default rule.
        n_features_per_page: number of features per distribution-plot page.

    Returns:
        dict | None: ``None`` when loading fails; otherwise a dict with the keys
        ``visualizer``, ``stats`` and ``patient_analysis``.
    """

    # Create the visualizer.
    visualizer = PatientFeatureVisualizer(processed_folder, output_folder)

    # Step 1: load the data and extract features.
    print("📂 Loading patient data and extracting features...")
    if not visualizer.load_and_extract_features():
        print("❌ Failed to load data")
        return

    print(f"✅ Loaded {len(visualizer.combined_features['patient_id'].unique())} patients")
    print(f"📊 Total samples: {len(visualizer.combined_features)}")
    print(f"🔢 Features: {len(visualizer.feature_names)}")

    # Step 2: per-feature distribution plots.
    print("\n📈 Creating individual feature distribution plots...")
    visualizer.create_individual_feature_plots(n_features_per_page=n_features_per_page)

    # Step 3: statistics.
    print("\n📊 Creating feature statistics...")
    stats_df = visualizer.create_feature_summary_statistics()

    # Report a few key findings.
    significant_features = stats_df[stats_df['significant'] == True]
    print(f"   🎯 {len(significant_features)} features are statistically significant")

    if len(significant_features) > 0:
        best_feature = significant_features.loc[significant_features['cohens_d'].abs().idxmax()]
        print(f"   🏆 Best discriminative feature: {best_feature['feature_name']}")
        print(f"      - Cohen's d: {best_feature['cohens_d']:.3f}")
        print(f"      - p-value: {best_feature['p_value']:.3e}")

    # Step 4: patient comparison.
    print("\n👥 Creating patient comparison analysis...")
    visualizer.create_patient_comparison_plots()

    # Step 5: patient-specific analysis.
    print("\n🔍 Creating patient-specific analysis...")
    patient_analysis = visualizer.create_patient_specific_analysis()

    # idxmax/idxmin raise on an empty table, so only report when there are rows.
    if patient_analysis is not None and len(patient_analysis) > 0:
        best_patient = patient_analysis.loc[patient_analysis['mean_separability'].idxmax()]
        worst_patient = patient_analysis.loc[patient_analysis['mean_separability'].idxmin()]

        print(f"   📈 Best separability: {best_patient['patient_id']} (d={best_patient['mean_separability']:.3f})")
        print(f"   📉 Worst separability: {worst_patient['patient_id']} (d={worst_patient['mean_separability']:.3f})")

        # Flag potential problems.
        low_sep_patients = patient_analysis[patient_analysis['mean_separability'] < 0.2]
        if len(low_sep_patients) > 0:
            print(f"   ⚠️  {len(low_sep_patients)} patients have very low separability:")
            for _, patient in low_sep_patients.iterrows():
                print(f"      - {patient['patient_id']}: {patient['mean_separability']:.3f}")
    else:
        print("   ⚠️  No patient-level analysis available")

    print(f"\n📁 All results saved to: {output_folder}")

    return {
        'visualizer': visualizer,
        'stats': stats_df,
        'patient_analysis': patient_analysis
    }

# 方法3: 针对性分析特定特征
def analyze_specific_features(feature_list=None,
                              processed_folder=r"D:\BlcRepo\LabCode\SeizureProp\data\gwbaseline_alt",
                              output_folder=r"D:\BlcRepo\LabCode\SeizureProp\result\specific_features"):
    """Plot per-patient scatter panels for a hand-picked set of features.

    Each panel shows one feature: Grey samples are drawn around y=1 (circles)
    and White samples around y=0 (squares), with random vertical jitter and one
    colour per patient. At most 8 features are drawn (fixed 2 x 4 grid).

    Args:
        feature_list: feature names to plot. When ``None`` the 8 features with
            the largest absolute Cohen's d are used.
        processed_folder: folder with the patient data; defaults to the
            location the example originally hard-coded.
        output_folder: folder that receives ``specific_features_analysis.png``.

    Returns:
        None. Returns early (also ``None``) when loading fails.
    """
    max_panels = 8  # the figure is a fixed 2 x 4 grid

    visualizer = PatientFeatureVisualizer(processed_folder, output_folder)

    # Load the data.
    if not visualizer.load_and_extract_features():
        return

    combined = visualizer.combined_features

    # Without an explicit list, fall back to the most discriminative features,
    # ranked by effect magnitude (sign of Cohen's d is irrelevant here).
    if feature_list is None:
        stats_df = visualizer.create_feature_summary_statistics()
        ranked = stats_df.assign(_abs_d=stats_df['cohens_d'].abs()).nlargest(max_panels, '_abs_d')
        feature_list = ranked['feature_name'].tolist()

    print(f"🎯 Analyzing specific features: {feature_list}")

    # An unknown name would otherwise raise KeyError halfway through plotting.
    unknown = [name for name in feature_list if name not in combined.columns]
    if unknown:
        print(f"⚠️  Skipping features not present in the data: {unknown}")
    plotted = [name for name in feature_list if name in combined.columns][:max_panels]

    # Build the figure.
    fig, axes = plt.subplots(2, 4, figsize=(20, 10))
    axes = axes.flatten()

    patients = combined['patient_id'].unique()
    patient_colors = plt.cm.Set3(np.linspace(0, 1, len(patients)))
    patient_color_map = {patient: patient_colors[i] for i, patient in enumerate(patients)}

    # The per-patient Grey/White split does not depend on the feature, so it is
    # computed once instead of once per panel.
    patient_groups = []
    for patient in patients:
        patient_data = combined[combined['patient_id'] == patient]
        patient_groups.append((
            patient,
            patient_data[patient_data['matter_type'] == 'Grey'],
            patient_data[patient_data['matter_type'] == 'White'],
        ))

    for ax, feature in zip(axes, plotted):
        # One scatter per patient and matter type.
        for patient, grey_data, white_data in patient_groups:
            if len(grey_data) > 0:
                y_grey = np.random.normal(1, 0.05, len(grey_data))
                ax.scatter(grey_data[feature], y_grey, c=[patient_color_map[patient]],
                          alpha=0.6, s=15, marker='o')

            if len(white_data) > 0:
                y_white = np.random.normal(0, 0.05, len(white_data))
                ax.scatter(white_data[feature], y_white, c=[patient_color_map[patient]],
                          alpha=0.6, s=15, marker='s')

        ax.set_xlabel(feature)
        ax.set_ylabel('Matter Type')
        ax.set_yticks([0, 1])
        ax.set_yticklabels(['White', 'Grey'])
        ax.grid(True, alpha=0.3)
        ax.set_title(feature, fontsize=10)

    # Hide the panels that received no feature.
    for ax in axes[len(plotted):]:
        ax.set_visible(False)

    # Legend: one entry per patient.
    handles = []
    labels = []
    for patient in patients:
        handles.append(plt.Line2D([0], [0], marker='o', color='w',
                                markerfacecolor=patient_color_map[patient], markersize=8))
        labels.append(patient)

    plt.figlegend(handles, labels, loc='center', bbox_to_anchor=(0.5, 0.02),
                 ncol=min(6, len(patients)), fontsize=10)

    plt.suptitle('Patient-Specific Feature Analysis', fontsize=16)
    plt.tight_layout()
    plt.subplots_adjust(bottom=0.15)
    # Make sure the target folder exists before saving.
    os.makedirs(output_folder, exist_ok=True)
    plt.savefig(os.path.join(output_folder, 'specific_features_analysis.png'),
               dpi=300, bbox_inches='tight')
    plt.show()


def troubleshoot_classification_issues(processed_folder=r"D:\BlcRepo\LabCode\SeizureProp\data\gwbaseline",
                                       output_folder=r"D:\BlcRepo\LabCode\SeizureProp\result\troubleshooting"):
    """Print a diagnostic report about likely causes of poor classification.

    Checks, in order: Grey/White class balance (per patient and overall),
    feature quality (effect size, between-patient CV, significance), patient
    heterogeneity (per-patient separability and distance between the patients'
    standardized mean Grey feature vectors), and then prints matching
    recommendations plus a few code snippets.

    Args:
        processed_folder: folder with the patient data; defaults to the
            location the example originally hard-coded.
        output_folder: folder passed to the visualizer for its outputs.

    Returns:
        None. Returns early (also ``None``) when loading fails.
    """

    visualizer = PatientFeatureVisualizer(processed_folder, output_folder)

    if not visualizer.load_and_extract_features():
        return

    print("🔧 TROUBLESHOOTING CLASSIFICATION ISSUES")
    print("=" * 50)

    # 1. Class balance.
    patients = visualizer.combined_features['patient_id'].unique()
    print(f"\n1. DATA BALANCE CHECK:")
    print("-" * 25)

    total_grey = 0
    total_white = 0
    imbalanced_patients = []

    for patient in patients:
        patient_data = visualizer.combined_features[visualizer.combined_features['patient_id'] == patient]
        n_grey = len(patient_data[patient_data['matter_type'] == 'Grey'])
        n_white = len(patient_data[patient_data['matter_type'] == 'White'])

        total_grey += n_grey
        total_white += n_white

        ratio = min(n_grey, n_white) / max(n_grey, n_white) if max(n_grey, n_white) > 0 else 0

        print(f"   {patient}: Grey={n_grey:4d}, White={n_white:4d}, Ratio={ratio:.2f}")

        if ratio < 0.3:  # severely imbalanced
            imbalanced_patients.append(patient)

    print(f"\n   Overall: Grey={total_grey}, White={total_white}")
    # Same zero guard as the per-patient ratio: with no Grey/White samples at all
    # the unguarded division raised ZeroDivisionError.
    largest_class = max(total_grey, total_white)
    overall_ratio = min(total_grey, total_white) / largest_class if largest_class > 0 else 0
    print(f"   Overall Ratio: {overall_ratio:.3f}")

    if imbalanced_patients:
        print(f"   ⚠️  Severely imbalanced patients: {imbalanced_patients}")

    # 2. Feature quality.
    print(f"\n2. FEATURE QUALITY CHECK:")
    print("-" * 30)

    stats_df = visualizer.create_feature_summary_statistics()

    # Features with (almost) no discriminative power.
    weak_features = stats_df[stats_df['cohens_d'].abs() < 0.1]
    print(f"   Weak features (|Cohen's d| < 0.1): {len(weak_features)}/{len(stats_df)}")

    # Features that vary strongly between patients.
    high_var_features = stats_df[stats_df['between_patient_cv'] > 1.5]
    print(f"   High variability features (CV > 1.5): {len(high_var_features)}/{len(stats_df)}")

    # Statistically significant features.
    significant_features = stats_df[stats_df['significant'] == True]
    print(f"   Statistically significant features: {len(significant_features)}/{len(stats_df)}")

    if len(significant_features) < len(stats_df) * 0.1:
        print("   ⚠️  Very few significant features - possible data quality issue!")

    # 3. Patient heterogeneity.
    print(f"\n3. PATIENT HETEROGENEITY CHECK:")
    print("-" * 35)

    patient_analysis = visualizer.create_patient_specific_analysis()

    low_sep_patients = patient_analysis[patient_analysis['mean_separability'] < 0.2]
    print(f"   Patients with low separability (<0.2): {len(low_sep_patients)}/{len(patient_analysis)}")

    if len(low_sep_patients) > 0:
        print("   Low separability patients:")
        for _, patient in low_sep_patients.iterrows():
            print(f"      {patient['patient_id']}: {patient['mean_separability']:.3f}")

    # Inter-patient similarity, based on each patient's mean Grey feature vector.
    feature_matrix = []
    patient_list = []

    for patient in patients:
        patient_data = visualizer.combined_features[visualizer.combined_features['patient_id'] == patient]
        grey_data = patient_data[patient_data['matter_type'] == 'Grey'][visualizer.feature_names]

        if len(grey_data) > 0:
            patient_features = grey_data.mean().values
            feature_matrix.append(patient_features)
            patient_list.append(patient)

    if len(feature_matrix) > 1:
        from sklearn.preprocessing import StandardScaler
        from scipy.spatial.distance import pdist

        # Standardize, then take all pairwise Euclidean distances.
        scaler = StandardScaler()
        feature_matrix_scaled = scaler.fit_transform(feature_matrix)
        distances = pdist(feature_matrix_scaled, metric='euclidean')

        avg_distance = np.mean(distances)
        print(f"   Average inter-patient distance: {avg_distance:.3f}")

        if avg_distance > 5.0:
            print("   ⚠️  High inter-patient variability detected!")

    # 4. Recommendations derived from the checks above.
    print(f"\n4. RECOMMENDATIONS:")
    print("-" * 20)

    recommendations = []

    if overall_ratio < 0.7:
        recommendations.append("Use balanced sampling or class weights in training")

    if len(imbalanced_patients) > len(patients) * 0.3:
        recommendations.append("Consider removing severely imbalanced patients")

    if len(weak_features) > len(stats_df) * 0.5:
        recommendations.append("Perform feature selection to remove weak features")

    if len(high_var_features) > len(stats_df) * 0.3:
        recommendations.append("Consider patient-specific normalization")

    if len(low_sep_patients) > len(patients) * 0.3:
        recommendations.append("Investigate data preprocessing consistency")
        recommendations.append("Consider hierarchical/mixed-effects models")

    if len(significant_features) < len(stats_df) * 0.1:
        recommendations.append("Review feature extraction methodology")
        recommendations.append("Consider different time windows or signal processing")

    for i, rec in enumerate(recommendations, 1):
        print(f"   {i}. {rec}")

    if not recommendations:
        print("   ✅ No major issues detected in the data")

    # 5. Example code for the suggested improvements (printed verbatim).
    print(f"\n5. IMPLEMENTATION SUGGESTIONS:")
    print("-" * 35)

    print("   # 特征选择示例:")
    print("   from sklearn.feature_selection import SelectKBest, f_classif")
    print("   selector = SelectKBest(f_classif, k=20)")
    print("   X_selected = selector.fit_transform(X, y)")

    print("\n   # 类别平衡示例:")
    print("   from sklearn.utils.class_weight import compute_class_weight")
    print("   class_weights = compute_class_weight('balanced', classes=np.unique(y), y=y)")

    print("\n   # 患者标准化示例:")
    print("   from sklearn.preprocessing import StandardScaler")
    print("   for patient in patients:")
    print("       scaler = StandardScaler()")
    print("       patient_data = scaler.fit_transform(patient_data)")

In [None]:
# Notebook cell: run the step-by-step analysis first, then the classification
# troubleshooting diagnostics. Both are called without arguments, so each uses
# the data/output locations built into the function itself.
detailed_analysis()
troubleshoot_classification_issues()

In [None]:
from scipy.signal import welch
from scipy.integrate import trapezoid
class ChannelClassifier:
    """
    Channel-Level Grey/White Matter分类器
    
    核心特点：
    1. 每个channel作为独立样本（无moving window）
    2. Patient-level标准化
    3. Training-only outlier处理
    4. 精简特征集
    5. 双重验证（Channel-level + Patient-level）
    """
    
    def __init__(self, processed_folder, output_folder='results'):
        """Remember the data/output folders, build the classifier pool and
        create the output folder.

        Args:
            processed_folder: folder searched for the per-patient pickle files.
            output_folder: folder for results; created if it does not exist.
        """
        self.processed_folder = processed_folder
        self.output_folder = output_folder

        # Filled later: patient_id -> loaded patient dict, and
        # {channel_id: {features, label, patient_id, ...}}.
        self.patients_data = {}
        self.channels_data = {}

        # Candidate classifiers. The first three share the seed and use balanced
        # class weights; the remaining ones are built without that option.
        seed = 42
        balanced = {'random_state': seed, 'class_weight': 'balanced'}
        self.classifiers = {
            'Logistic Regression': LogisticRegression(max_iter=1000, **balanced),
            'SVM': SVC(probability=True, **balanced),
            'Random Forest': RandomForestClassifier(n_estimators=100, **balanced),
            'MLP': MLPClassifier(max_iter=1000, random_state=seed, hidden_layer_sizes=(50,)),
            'KNN': KNeighborsClassifier(n_neighbors=5),
            'LDA': LDA(),
            'Naive Bayes': GaussianNB(),
        }

        os.makedirs(output_folder, exist_ok=True)

        for message in (
            "🧠 Channel Classifier initialized",
            f"   Data: {processed_folder}",
            f"   Output: {output_folder}",
        ):
            print(message)
    
    def load_patients(self):
        """Load every ``P*_processed.pkl`` file found in ``self.processed_folder``.

        Each successfully unpickled file is stored in ``self.patients_data``
        under its ``patient_id``. Files are processed in sorted name order so
        the patient order is the same on every run and platform. A file that
        cannot be read, or has no ``patient_id``, is reported and skipped. A
        missing or malformed ``processing_summary`` only affects the printed
        duration; the patient is still loaded.

        Returns:
            bool: ``True`` when at least one patient is available afterwards.
        """
        # glob returns files in arbitrary order; sort for reproducibility.
        pkl_files = sorted(glob.glob(os.path.join(self.processed_folder, "P*_processed.pkl")))
        print(f"\n📂 Loading {len(pkl_files)} patients...")

        for pkl_file in pkl_files:
            try:
                # NOTE(security): unpickling can execute arbitrary code; only
                # point this loader at folders with trusted files.
                with open(pkl_file, 'rb') as f:
                    data = pickle.load(f)
                pid = data['patient_id']
            except Exception as e:
                print(f"  ✗ {os.path.basename(pkl_file)}: {e}")
                continue

            self.patients_data[pid] = data

            # The duration is display-only, so a missing summary must not make a
            # patient that was stored look like a failed load.
            try:
                duration = data['processing_summary']['total_duration_seconds'] / 60
            except (KeyError, TypeError):
                print(f"  ✓ {pid}: duration unavailable")
            else:
                print(f"  ✓ {pid}: {duration:.1f} min")

        print(f"✅ Loaded {len(self.patients_data)} patients")
        return len(self.patients_data) > 0
    
    def _extract_electrode_types(self, matter_data):
        """Return the index labels of grey- and white-matter electrodes.

        The matter column is the first of a fixed list of candidate names that
        exists in ``matter_data``. Its values are compared case-insensitively:
        first against exact labels (g/grey/gray, w/white) and, only when a class
        has no exact match at all, against a substring pattern.

        Args:
            matter_data: DataFrame with one row per electrode.

        Returns:
            tuple[list, list]: ``(grey_indices, white_indices)`` taken from
            ``matter_data.index``.

        Raises:
            ValueError: if none of the candidate column names is present.
        """
        candidate_columns = ('MatterType', 'matter', 'Matter', 'mattertype', 'tissue_type', 'type')
        matter_col = next((name for name in candidate_columns if name in matter_data.columns), None)
        if matter_col is None:
            raise ValueError(f"Matter column not found in {matter_data.columns.tolist()}")

        labels = matter_data[matter_col].astype(str).str.lower()

        def select(exact_labels, fallback_pattern):
            # Exact labels win; the pattern is consulted only if nothing matched.
            mask = labels.isin(exact_labels)
            if not mask.any():
                mask = labels.str.contains(fallback_pattern, case=False, na=False)
            return matter_data.index[mask].tolist()

        grey_indices = select(['g', 'grey', 'gray'], 'grey|gray|cortex')
        white_indices = select(['w', 'white'], 'white')
        return grey_indices, white_indices
    
    def _extract_features(self, signal, fs, avg_power_spectrum=None, freqs=None):
        """Compute the compact feature set for one channel signal.

        Time-domain features: ``rms``, ``range``, ``line_length``, ``energy``.
        Frequency-domain features (Welch PSD): for each band (delta, theta,
        alpha, beta, gamma, high_gamma) the summed PSD ``power_<band>`` and
        ``rel_power_<band>``, followed by ``total_power``. ``rel_power_<band>``
        is the mean log10 PSD of the signal in the band minus the mean log10 of
        ``avg_power_spectrum`` in the same band; it is 0 when no reference
        spectrum is supplied or the reference has no bin in that band.

        Args:
            signal: 1-D sample array of one channel.
            fs: sampling rate passed to ``welch``.
            avg_power_spectrum: optional reference PSD.
            freqs: frequency axis belonging to ``avg_power_spectrum``.

        Returns:
            dict | None: feature name -> value, or ``None`` when the signal has
            fewer than 100 samples or is all zeros. If the spectral analysis
            raises, a warning is printed and every spectral feature is set to 0.
        """
        if len(signal) < 100 or np.all(signal == 0):
            return None

        squared = signal**2
        features = {
            'rms': np.sqrt(np.mean(squared)),
            'range': np.ptp(signal),
            'line_length': np.sum(np.abs(np.diff(signal))),
            'energy': np.sum(squared),
        }

        band_names = ('delta', 'theta', 'alpha', 'beta', 'gamma', 'high_gamma')
        has_reference = avg_power_spectrum is not None and freqs is not None

        # Spectral features.
        try:
            signal_freqs, psd = welch(signal, fs=fs, nperseg=min(1024, len(signal)), scaling='density')
            total_power = np.sum(psd)

            # (low, high) edges in Hz; both edges are inclusive in the masks below.
            band_edges = dict(zip(band_names, (
                (0.5, 4),
                (4, 8),
                (8, 13),
                (13, 30),
                (30, 100),
                (100, min(200, fs/2)),
            )))

            for name in band_names:
                low, high = band_edges[name]
                power_key = f'power_{name}'
                rel_key = f'rel_power_{name}'

                in_band = (signal_freqs >= low) & (signal_freqs <= high)
                if not np.any(in_band):
                    features[power_key] = 0
                    features[rel_key] = 0
                    continue

                features[power_key] = np.sum(psd[in_band])

                # Vertical offset from the reference spectrum on a log scale.
                # Band means are compared, so the two spectra may have a
                # different number of bins inside the band.
                offset = 0
                if has_reference:
                    ref_in_band = (freqs >= low) & (freqs <= high)
                    if np.any(ref_in_band):
                        own_level = np.mean(np.log10(psd[in_band] + 1e-12))
                        ref_level = np.mean(np.log10(avg_power_spectrum[ref_in_band] + 1e-12))
                        offset = own_level - ref_level
                features[rel_key] = offset

            features['total_power'] = total_power

        except Exception as e:
            print(f"Warning: Error in frequency analysis: {e}")
            for name in band_names:
                features[f'power_{name}'] = 0
                features[f'rel_power_{name}'] = 0
            features['total_power'] = 0

        return features
    
    def extract_all_features(self):
        """Extract one feature dictionary per labelled channel of every patient.

        For every entry of ``self.patients_data`` the recordings are stacked
        with ``np.vstack``, a patient-level average Welch power spectrum is
        computed over all channels that are not entirely zero, and
        ``self._extract_features`` is called for each grey- and white-matter
        electrode index returned by ``self._extract_electrode_types``.

        Results are written to ``self.channels_data``, which is rebuilt from
        scratch on every call, keyed by a running integer id.  Each entry
        stores ``features``, ``label`` (1 = grey, 0 = white), ``patient_id``,
        ``electrode_idx``, ``matter_type`` and ``duration_sec``.

        Patients lacking either electrode type, lacking data, or whose
        processing raises are skipped after a message is printed.

        Returns:
            bool: True when at least one channel was extracted.
        """
        # Imported locally (like ``traceback`` below) so the method does not
        # depend on a module-level ``welch`` import being present.
        from scipy.signal import welch

        print(f"\n🔄 Extracting channel features...")

        # channel ids restart at 0 on every call, so start from an empty
        # store; otherwise a shorter second run would keep stale entries.
        self.channels_data = {}
        channel_id = 0

        for patient_id, patient_data in self.patients_data.items():
            print(f"  Processing {patient_id}...")

            try:
                matter_data = patient_data['matter_data']
                recordings = patient_data['recordings']

                # Electrode indices split by tissue type
                grey_indices, white_indices = self._extract_electrode_types(matter_data)

                if len(grey_indices) == 0 or len(white_indices) == 0:
                    print(f"    ✗ Missing grey/white electrodes")
                    continue

                print(f"    {len(grey_indices)} grey, {len(white_indices)} white electrodes")

                # Concatenate the data of all recordings; 512 is the fallback
                # sampling rate when there are no recordings at all.
                all_data = []
                fs = recordings[0]['sampling_rate'] if recordings else 512

                for recording in recordings:
                    data_segment = recording['neural_data_processed']
                    if isinstance(data_segment, list):
                        all_data.extend(data_segment)
                    else:
                        all_data.append(data_segment)

                if not all_data:
                    print(f"    ✗ No data found")
                    continue

                combined_data = np.vstack(all_data)
                n_channels = combined_data.shape[1]
                duration_min = len(combined_data) / fs / 60
                print(f"    Combined: {combined_data.shape} ({duration_min:.1f} min)")

                # Patient-level average power spectrum across channels
                print(f"    Computing patient-level average power spectrum...")
                all_power_spectra = []
                freqs = None

                for ch in range(n_channels):
                    sig = combined_data[:, ch]
                    # Skip empty or all-zero channels
                    if len(sig) == 0 or not np.any(sig != 0):
                        continue
                    try:
                        freqs_ch, psd_ch = welch(sig, fs=fs, nperseg=min(1024, len(sig)), scaling='density')
                    except Exception as e:
                        print(f"    Warning: Error computing PSD for channel {ch}: {e}")
                        continue
                    all_power_spectra.append(psd_ch)
                    if freqs is None:
                        # Reuse the frequency vector of the first successful
                        # PSD instead of running welch a second time.
                        freqs = freqs_ch

                if all_power_spectra:
                    avg_power_spectrum = np.mean(np.vstack(all_power_spectra), axis=0)
                else:
                    avg_power_spectrum = None
                    freqs = None
                    print(f"    Warning: Could not compute reference power spectrum")

                # Grey channels first, then white, so channel ids keep the
                # same ordering as before.
                channel_groups = (
                    (grey_indices, 1, 'grey'),    # Grey = 1
                    (white_indices, 0, 'white'),  # White = 0
                )
                for indices, label, matter_type in channel_groups:
                    for electrode_idx in indices:
                        if electrode_idx >= n_channels:
                            continue

                        signal = combined_data[:, electrode_idx]
                        features = self._extract_features(signal, fs, avg_power_spectrum, freqs)
                        if features is None:
                            continue

                        self.channels_data[channel_id] = {
                            'features': features,
                            'label': label,
                            'patient_id': patient_id,
                            'electrode_idx': electrode_idx,
                            'matter_type': matter_type,
                            'duration_sec': len(signal) / fs
                        }
                        channel_id += 1

            except Exception as e:
                # Best effort: report the failure and move on to the next patient
                print(f"    ✗ Error processing {patient_id}: {e}")
                import traceback
                traceback.print_exc()
                continue

        n_grey = sum(1 for ch in self.channels_data.values() if ch['label'] == 1)
        n_white = sum(1 for ch in self.channels_data.values() if ch['label'] == 0)

        print(f"\n✅ Extracted {len(self.channels_data)} channels")
        print(f"   Grey: {n_grey}, White: {n_white}")

        if self.channels_data:
            first_ch = next(iter(self.channels_data.values()))
            print(f"   Features: {len(first_ch['features'])}")
            print(f"   Feature names: {list(first_ch['features'].keys())}")

        return len(self.channels_data) > 0
    
    def normalize_by_patient(self, method='robust'):
        """Normalise every feature within each patient's own set of channels.

        Channels in ``self.channels_data`` are grouped by ``patient_id``; each
        group is turned into a DataFrame (columns follow the feature order of
        the first stored channel), passed to ``self._normalize_dataframe`` and
        written back in place.  Every processed channel also gets the flag
        ``'normalized': True``.

        Args:
            method: Normalisation scheme forwarded to
                ``self._normalize_dataframe``.

        Returns:
            bool: False when there is no channel data, True otherwise.
        """
        print(f"\n🔄 Patient-level normalization ({method})...")

        if not self.channels_data:
            return False

        # patient id -> channel ids, in insertion order
        ids_by_patient = {}
        for channel_key, record in self.channels_data.items():
            ids_by_patient.setdefault(record['patient_id'], []).append(channel_key)

        # Column order is dictated by the first stored channel
        first_record = list(self.channels_data.values())[0]
        feature_names = list(first_record['features'].keys())
        print(f"  Normalizing {len(feature_names)} features across {len(ids_by_patient)} patients")

        for pid, channel_keys in ids_by_patient.items():
            print(f"    {pid}: {len(channel_keys)} channels")

            # One row per channel of this patient
            rows = [
                [self.channels_data[key]['features'][fname] for fname in feature_names]
                for key in channel_keys
            ]
            raw_df = pd.DataFrame(rows, columns=feature_names)
            scaled_df = self._normalize_dataframe(raw_df, method)

            # Write the normalised row back to the matching channel
            for position, key in enumerate(channel_keys):
                target = self.channels_data[key]
                target['features'] = scaled_df.iloc[position].to_dict()
                target['normalized'] = True

        print(f"✅ Patient-level normalization completed")
        return True
    
    def _normalize_dataframe(self, df, method):
        """Return a copy of ``df`` with every column normalised independently.

        Supported ``method`` values:

        * ``'robust'``   - subtract the median, divide by ``1.4826 * MAD``;
        * ``'standard'`` - subtract the mean, divide by the standard deviation;
        * ``'minmax'``   - rescale to ``[0, 1]``.

        When the scale of a column is not positive, ``'robust'`` and
        ``'standard'`` only centre it and ``'minmax'`` yields zeros.  Any
        other ``method`` returns an unmodified copy.
        """

        def _robust(values):
            centred = values - np.median(values)
            mad = np.median(np.abs(centred))
            return centred / (1.4826 * mad) if mad > 0 else centred

        def _standard(values):
            centred = values - np.mean(values)
            spread = np.std(values)
            return centred / spread if spread > 0 else centred

        def _minmax(values):
            low, high = np.min(values), np.max(values)
            if high > low:
                return (values - low) / (high - low)
            return np.zeros_like(values)

        transform = {
            'robust': _robust,
            'standard': _standard,
            'minmax': _minmax,
        }.get(method)

        result = df.copy()
        if transform is None:
            # Unknown method: hand back the untouched copy
            return result

        for column in df.columns:
            result[column] = transform(df[column].values)

        return result
    
    def _prepare_data(self):
        """Assemble the feature matrix, label vector and per-row metadata.

        Columns follow the feature order of the first stored channel, the
        same convention the sibling methods use when they derive feature
        names.  Every row is looked up by feature name, so a channel whose
        dictionary has a different key order cannot misalign the matrix.

        Returns:
            tuple: ``(X, y, info)`` where ``X`` has shape
            ``(n_channels, n_features)``, ``y`` holds the channel labels and
            ``info`` is a list of dicts with ``channel_id``, ``patient_id``,
            ``matter_type`` and ``electrode_idx``.  ``(None, None, None)``
            when there is no channel data.

        Raises:
            ValueError: If a channel's feature names differ from those of the
                first channel.
        """
        if not self.channels_data:
            return None, None, None

        reference = next(iter(self.channels_data.values()))['features']
        feature_names = list(reference.keys())

        X = []
        y = []
        info = []

        for ch_id, ch_data in self.channels_data.items():
            features = ch_data['features']
            # Key views compare as sets: order may differ, content may not
            if features.keys() != reference.keys():
                raise ValueError(
                    f"Channel {ch_id} has a different feature set than the first channel"
                )

            X.append([features[name] for name in feature_names])
            y.append(ch_data['label'])
            info.append({
                'channel_id': ch_id,
                'patient_id': ch_data['patient_id'],
                'matter_type': ch_data['matter_type'],
                'electrode_idx': ch_data['electrode_idx']
            })

        return np.array(X), np.array(y), info
    
    def _clip_outliers_training_only(self, X_train, X_test, iqr_factor=1.5):
        """Clip both sets to IQR fences computed from the training set only.

        For each feature column the fences are
        ``q25 - iqr_factor * IQR`` and ``q75 + iqr_factor * IQR`` with the
        quartiles taken from ``X_train``.  Columns whose IQR is not positive
        are left unchanged.  The inputs are not modified.

        Returns:
            tuple: ``(X_train_clipped, X_test_clipped)`` as new arrays.
        """
        clipped_train = X_train.copy()
        clipped_test = X_test.copy()

        for col in range(X_train.shape[1]):
            column = X_train[:, col]
            first_quartile, third_quartile = np.percentile(column, [25, 75])
            spread = third_quartile - first_quartile

            # Degenerate spread: nothing sensible to clip against
            if not spread > 0:
                continue

            low_fence = first_quartile - iqr_factor * spread
            high_fence = third_quartile + iqr_factor * spread

            clipped_train[:, col] = np.clip(column, low_fence, high_fence)
            clipped_test[:, col] = np.clip(X_test[:, col], low_fence, high_fence)

        return clipped_train, clipped_test
    
    def channel_level_validation(self, test_size=0.2, outlier_clip=True, iqr_factor=1.5):
        """Evaluate every classifier on one stratified random channel split.

        Channels from all patients are pooled and split with
        ``train_test_split`` (``random_state=42``, stratified by label).
        Optionally the features are clipped with fences derived from the
        training part only.  Each classifier in ``self.classifiers`` is fitted
        in place and scored on the held-out part; a classifier that raises is
        reported and left out of the result.

        Args:
            test_size: Test fraction forwarded to ``train_test_split``.
            outlier_clip: Whether to apply IQR clipping.
            iqr_factor: Fence multiplier used when clipping.

        Returns:
            dict | None: Classifier name -> metrics dict (accuracy, f1,
            precision, recall, balanced_acc, roc_auc, confusion_matrix,
            y_true, y_pred, y_proba), or None when no data is available.
        """
        print(f"\n🔍 Channel-Level Validation (random split)")

        features, labels, _ = self._prepare_data()
        if features is None:
            return None

        # Stratified random split over pooled channels
        X_train, X_test, y_train, y_test = train_test_split(
            features, labels, test_size=test_size, random_state=42, stratify=labels
        )

        print(f"  Train: {len(X_train)} ({np.sum(y_train)} grey)")
        print(f"  Test: {len(X_test)} ({np.sum(y_test)} grey)")

        if outlier_clip:
            X_train, X_test = self._clip_outliers_training_only(X_train, X_test, iqr_factor)
            print(f"  Applied outlier clipping (IQR × {iqr_factor})")

        results = {}
        for clf_name, clf in self.classifiers.items():
            try:
                clf.fit(X_train, y_train)
                y_pred = clf.predict(X_test)

                # Positive-class probability; zeros when the model has none
                if hasattr(clf, 'predict_proba'):
                    y_proba = clf.predict_proba(X_test)[:, 1]
                else:
                    y_proba = np.zeros_like(y_pred)

                metrics = {
                    'accuracy': accuracy_score(y_test, y_pred),
                    'f1': f1_score(y_test, y_pred, zero_division=0),
                    'precision': precision_score(y_test, y_pred, zero_division=0),
                    'recall': recall_score(y_test, y_pred, zero_division=0),
                    'balanced_acc': balanced_accuracy_score(y_test, y_pred),
                    # AUC is undefined with a single class in the test set
                    'roc_auc': roc_auc_score(y_test, y_proba) if len(np.unique(y_test)) > 1 else 0.5,
                    'confusion_matrix': confusion_matrix(y_test, y_pred),
                    'y_true': y_test,
                    'y_pred': y_pred,
                    'y_proba': y_proba
                }
                results[clf_name] = metrics

                print(f"  {clf_name}: F1={metrics['f1']:.3f}, Acc={metrics['accuracy']:.3f}")

            except Exception as e:
                print(f"  {clf_name}: Error - {e}")

        return results
    
    def patient_level_validation(self, outlier_clip=True, iqr_factor=1.5):
        """Leave-one-patient-out (LOPO) validation of every classifier.

        Each patient in turn becomes the test set while all remaining
        patients form the training set.  Patients are visited in sorted order
        so fold numbers and the order of the pooled predictions are
        reproducible between runs.  Optionally features are clipped with
        fences derived from the training part of each fold.  A classifier that
        raises in a fold is reported and contributes nothing for that fold.

        Args:
            outlier_clip: Whether to apply IQR clipping in every fold.
            iqr_factor: Fence multiplier used when clipping.

        Returns:
            tuple | None: ``(cv_results, all_predictions)`` where
            ``cv_results`` maps classifier name -> list of per-fold metric
            dicts and ``all_predictions`` maps classifier name -> pooled
            ``y_true`` / ``y_pred`` / ``y_proba`` lists.  None when no data is
            available or fewer than 3 patients are present.
        """
        print(f"\n🔍 Patient-Level Validation (LOPO)")

        X, y, info = self._prepare_data()
        if X is None:
            return None

        # Loop-invariant: patient id of every row, computed once
        patient_ids = [inf['patient_id'] for inf in info]
        # Sorted for a deterministic fold order (set order is not stable)
        patients = sorted(set(patient_ids))
        print(f"  {len(patients)} patients for LOPO")

        if len(patients) < 3:
            print("  ❌ Need at least 3 patients")
            return None

        # Per-fold metrics and pooled predictions for each classifier
        cv_results = {name: [] for name in self.classifiers.keys()}
        all_predictions = {name: {'y_true': [], 'y_pred': [], 'y_proba': []} for name in self.classifiers.keys()}

        for fold, test_patient in enumerate(patients):
            print(f"\n  Fold {fold+1}/{len(patients)}: Test {test_patient}")

            # Hold out every channel of the current patient
            train_mask = np.array([pid != test_patient for pid in patient_ids])
            test_mask = ~train_mask

            X_train, y_train = X[train_mask], y[train_mask]
            X_test, y_test = X[test_mask], y[test_mask]

            print(f"    Train: {len(X_train)} ({np.sum(y_train)} grey)")
            print(f"    Test: {len(X_test)} ({np.sum(y_test)} grey)")

            if outlier_clip:
                X_train, X_test = self._clip_outliers_training_only(X_train, X_test, iqr_factor)

            for clf_name, clf in self.classifiers.items():
                try:
                    clf.fit(X_train, y_train)
                    y_pred = clf.predict(X_test)
                    # Positive-class probability; zeros when the model has none
                    y_proba = clf.predict_proba(X_test)[:, 1] if hasattr(clf, 'predict_proba') else np.zeros_like(y_pred)

                    fold_result = {
                        'fold': fold,
                        'test_patient': test_patient,
                        'accuracy': accuracy_score(y_test, y_pred),
                        'f1': f1_score(y_test, y_pred, zero_division=0),
                        'precision': precision_score(y_test, y_pred, zero_division=0),
                        'recall': recall_score(y_test, y_pred, zero_division=0),
                        'balanced_acc': balanced_accuracy_score(y_test, y_pred),
                        # AUC is undefined with a single class in the fold
                        'roc_auc': roc_auc_score(y_test, y_proba) if len(np.unique(y_test)) > 1 else 0.5,
                        'n_test': len(y_test)
                    }

                    cv_results[clf_name].append(fold_result)
                    all_predictions[clf_name]['y_true'].extend(y_test)
                    all_predictions[clf_name]['y_pred'].extend(y_pred)
                    all_predictions[clf_name]['y_proba'].extend(y_proba)

                    print(f"      {clf_name}: F1={fold_result['f1']:.3f}")

                except Exception as e:
                    print(f"      {clf_name}: Error - {e}")

        return cv_results, all_predictions
    
    def analyze_results(self, channel_results, patient_results):
        """Summarise channel-level and patient-level validation results.

        Prints a per-classifier report and picks the best classifier of each
        validation scheme (by F1 for the channel split, by pooled overall F1
        for the patient folds).  Ties go to the first classifier, and a best
        classifier is reported even when every F1 equals 0.

        Args:
            channel_results: Classifier name -> metrics dict with at least
                ``f1``, ``accuracy`` and ``roc_auc``; falsy to skip.
            patient_results: ``(cv_results, all_predictions)`` tuple; falsy
                to skip.

        Returns:
            dict: ``{'channel': {...}, 'patient': {...}}``.  Each filled part
            holds ``results``, ``best_clf`` and ``best_f1``; a skipped part
            is an empty dict.
        """
        print(f"\n{'='*60}")
        print(f"📊 Results Analysis")
        print(f"{'='*60}")

        analysis = {'channel': {}, 'patient': {}}

        # Channel-level results
        if channel_results:
            print(f"\n🔍 Channel-Level Results:")

            for clf_name, metrics in channel_results.items():
                f1 = metrics['f1']
                acc = metrics['accuracy']
                # Named roc_auc so the imported sklearn ``auc`` is not shadowed
                roc_auc = metrics['roc_auc']
                print(f"  {clf_name}: F1={f1:.3f}, Acc={acc:.3f}, AUC={roc_auc:.3f}")

            # max() returns the first maximal entry, so ties keep their order
            best_channel_clf = max(channel_results, key=lambda name: channel_results[name]['f1'])
            best_channel_f1 = channel_results[best_channel_clf]['f1']

            print(f"  🏆 Best: {best_channel_clf} (F1={best_channel_f1:.3f})")

            analysis['channel'] = {
                'results': channel_results,
                'best_clf': best_channel_clf,
                'best_f1': best_channel_f1
            }

        # Patient-level results
        if patient_results:
            cv_results, all_preds = patient_results
            print(f"\n🔍 Patient-Level Results:")

            patient_summary = {}

            for clf_name in self.classifiers.keys():
                folds = cv_results.get(clf_name, [])
                if not folds:
                    # Classifier failed in every fold (or was never run)
                    continue

                # Per-fold statistics
                f1_values = [fold['f1'] for fold in folds]
                f1_mean = np.mean(f1_values)
                f1_std = np.std(f1_values)

                # Metrics over the predictions pooled across folds
                y_true = np.array(all_preds[clf_name]['y_true'])
                y_pred = np.array(all_preds[clf_name]['y_pred'])
                overall_f1 = f1_score(y_true, y_pred, zero_division=0)
                overall_acc = accuracy_score(y_true, y_pred)

                patient_summary[clf_name] = {
                    'cv_f1_mean': f1_mean,
                    'cv_f1_std': f1_std,
                    'overall_f1': overall_f1,
                    'overall_acc': overall_acc
                }

                print(f"  {clf_name}: CV F1={f1_mean:.3f}±{f1_std:.3f}, Overall F1={overall_f1:.3f}")

            if patient_summary:
                best_patient_clf = max(patient_summary, key=lambda name: patient_summary[name]['overall_f1'])
                best_patient_f1 = patient_summary[best_patient_clf]['overall_f1']
            else:
                best_patient_clf = None
                best_patient_f1 = 0

            print(f"  🏆 Best: {best_patient_clf} (F1={best_patient_f1:.3f})")

            analysis['patient'] = {
                'results': patient_summary,
                'best_clf': best_patient_clf,
                'best_f1': best_patient_f1
            }

        return analysis
    
    def create_visualizations(self, analysis_results):
        """Create all standard figures in ``self.output_folder``.

        Runs, in order: dataset distribution, validation comparison, feature
        importance and per-patient feature distributions.  The output folder
        is created first when it does not exist, so the plotting helpers can
        save their files.

        Args:
            analysis_results: Analysis dict forwarded to
                ``self._plot_validation_comparison``.
        """
        print(f"\n📊 Creating visualizations...")

        # Saving a figure into a missing directory fails, so create it up front
        os.makedirs(self.output_folder, exist_ok=True)

        # 1. Dataset distribution
        self._plot_dataset_distribution()

        # 2. Validation comparison
        self._plot_validation_comparison(analysis_results)

        # 3. Feature importance
        self._plot_feature_importance()

        # 4. Per-patient feature distributions
        self._plot_patient_feature_distributions()

        print(f"  💾 Saved to {self.output_folder}")
    
    def _plot_patient_feature_distributions(self, selected_features=None, max_patients_per_plot=12):
        """
        Plot, for every selected feature, per-patient grey/white histograms.

        One PNG per feature (and per patient batch when there are more
        patients than ``max_patients_per_plot``) is written to
        ``self.output_folder`` as ``feature_<name>[_batch<k>].png``.  Each
        subplot shows one patient with density histograms of that feature for
        grey and white channels.

        Parameters:
        -----------
        selected_features : list, optional
            Features to plot.  When None, a built-in shortlist of
            representative features is used, restricted to the names present
            in the first channel's feature dictionary.
        max_patients_per_plot : int
            Maximum number of patients (subplots) per figure.
        """

        print(f"  📊 Creating patient feature distributions...")

        if not self.channels_data:
            print(f"    ❌ No channel data available")
            return

        # Feature names available in the data
        first_channel = next(iter(self.channels_data.values()))
        all_feature_names = list(first_channel['features'].keys())

        if selected_features is None:
            # Representative default shortlist
            selected_features = [
                'mean', 'std', 'rms', 'line_length', 'energy',
                'rel_power_delta', 'rel_power_theta', 'rel_power_alpha',
                'rel_power_beta', 'rel_power_gamma', 'rel_power_high_gamma',
                'skewness', 'kurtosis'
            ]
            # Keep only the features that actually exist
            selected_features = [f for f in selected_features if f in all_feature_names]

        # Group the feature dictionaries by patient and matter type
        patient_feature_data = {}
        for ch_data in self.channels_data.values():
            per_patient = patient_feature_data.setdefault(
                ch_data['patient_id'], {'grey': [], 'white': []}
            )
            per_patient[ch_data['matter_type']].append(ch_data['features'])

        patients = sorted(patient_feature_data.keys())

        # Split the patients into batches of at most max_patients_per_plot
        patient_batches = [patients[i:i+max_patients_per_plot]
                          for i in range(0, len(patients), max_patients_per_plot)]

        for feature_name in selected_features:
            print(f"    Creating visualization for feature: {feature_name}")

            for batch_idx, patient_batch in enumerate(patient_batches):
                # Grid layout: at most 4 columns, rows rounded up
                n_patients = len(patient_batch)
                n_cols = min(4, n_patients)
                n_rows = (n_patients + n_cols - 1) // n_cols

                # squeeze=False always yields a 2-D axes array, whatever the grid
                fig, axes = plt.subplots(n_rows, n_cols, figsize=(5*n_cols, 4*n_rows),
                                         squeeze=False)

                for i, patient_id in enumerate(patient_batch):
                    ax = axes[divmod(i, n_cols)]

                    # Values of this feature for the patient's channels
                    grey_values = [
                        ch_features[feature_name]
                        for ch_features in patient_feature_data[patient_id]['grey']
                        if feature_name in ch_features
                    ]
                    white_values = [
                        ch_features[feature_name]
                        for ch_features in patient_feature_data[patient_id]['white']
                        if feature_name in ch_features
                    ]

                    if grey_values:
                        ax.hist(grey_values, bins=15, alpha=0.6, label='Grey', color='red', density=True)
                    if white_values:
                        ax.hist(white_values, bins=15, alpha=0.6, label='White', color='blue', density=True)

                    ax.set_title(f'{patient_id}\n(G:{len(grey_values)}, W:{len(white_values)})', fontsize=10)
                    ax.set_xlabel(f'{feature_name}', fontsize=8)
                    ax.set_ylabel('Density', fontsize=8)
                    # A legend without any labelled artist draws nothing
                    if grey_values or white_values:
                        ax.legend(fontsize=8)
                    ax.grid(True, alpha=0.3)

                    # Limit the x-axis to the 1st-99th percentile so a few
                    # extreme values do not flatten the histogram
                    all_values = grey_values + white_values
                    if all_values:
                        p01, p99 = np.percentile(all_values, [1, 99])
                        if p99 > p01:
                            ax.set_xlim(p01, p99)

                    ax.tick_params(axis='both', which='major', labelsize=8)

                # Hide the unused cells of the grid
                for i in range(n_patients, n_rows * n_cols):
                    axes[divmod(i, n_cols)].set_visible(False)

                batch_suffix = f"_batch{batch_idx+1}" if len(patient_batches) > 1 else ""
                fig.suptitle(f'Feature Distribution: {feature_name}{batch_suffix}', fontsize=16, y=0.98)

                fig.tight_layout()
                fig.subplots_adjust(top=0.93)  # leave room for the suptitle

                # Path separators in a feature name would break the file name
                safe_feature_name = feature_name.replace('/', '_').replace('\\', '_')
                filename = f'feature_{safe_feature_name}{batch_suffix}.png'
                fig.savefig(os.path.join(self.output_folder, filename),
                            dpi=300, bbox_inches='tight')
                # Close this exact figure rather than whichever is current
                plt.close(fig)

        print(f"    ✅ Created visualizations for {len(selected_features)} features")
    
    def create_detailed_feature_visualization(self, features_to_plot=None):
        """
        Write per-feature distribution plots into a dedicated sub-folder.

        The figures are produced by ``self._plot_patient_feature_distributions``
        and end up in ``<output_folder>/feature_distributions``.  This method
        can be called on its own.

        Parameters:
        -----------
        features_to_plot : list, optional
            Features to draw.  When None, every feature of the first stored
            channel is used.
        """
        print(f"\n📊 Creating detailed feature visualizations...")

        if not self.channels_data:
            print(f"❌ No channel data available")
            return

        if features_to_plot is None:
            # Default to the complete feature list
            sample_channel = next(iter(self.channels_data.values()))
            features_to_plot = list(sample_channel['features'].keys())

        base_folder = self.output_folder
        feature_viz_folder = os.path.join(base_folder, 'feature_distributions')
        os.makedirs(feature_viz_folder, exist_ok=True)

        # The plotting helper saves into self.output_folder, so point it at
        # the sub-folder for the duration of the call only.
        self.output_folder = feature_viz_folder
        try:
            self._plot_patient_feature_distributions(
                selected_features=features_to_plot,
                max_patients_per_plot=12
            )
        finally:
            self.output_folder = base_folder

        print(f"✅ Feature visualizations saved to: {feature_viz_folder}")
    
    def create_feature_summary_statistics(self):
        """Write per-patient, per-matter-type feature statistics to a CSV.

        For every patient, matter type (grey, then white) and feature of the
        first stored channel, the channel count, mean, standard deviation,
        median, minimum, maximum and the 25th/75th percentiles are computed.
        The table is saved as ``feature_summary_statistics.csv`` inside
        ``self.output_folder``.  Nothing is returned.
        """
        print(f"\n📊 Creating feature summary statistics...")

        if not self.channels_data:
            print(f"❌ No channel data available")
            return

        # Feature names come from the first stored channel
        sample_channel = next(iter(self.channels_data.values()))
        feature_names = list(sample_channel['features'].keys())

        # patient id -> {'grey': [feature dicts], 'white': [feature dicts]}
        grouped = {}
        for record in self.channels_data.values():
            buckets = grouped.setdefault(record['patient_id'], {'grey': [], 'white': []})
            buckets[record['matter_type']].append(record['features'])

        rows = []
        for patient_id, buckets in grouped.items():
            for matter_type in ('grey', 'white'):
                channel_dicts = buckets[matter_type]
                if not channel_dicts:
                    continue

                for feature_name in feature_names:
                    # Values of this feature over the group's channels
                    values = [d[feature_name] for d in channel_dicts if feature_name in d]
                    if not values:
                        continue

                    rows.append({
                        'Patient': patient_id,
                        'Matter_Type': matter_type,
                        'Feature': feature_name,
                        'N_Channels': len(values),
                        'Mean': np.mean(values),
                        'Std': np.std(values),
                        'Median': np.median(values),
                        'Min': np.min(values),
                        'Max': np.max(values),
                        'Q25': np.percentile(values, 25),
                        'Q75': np.percentile(values, 75)
                    })

        if not rows:
            print(f"  ❌ No statistics to save")
            return

        csv_path = os.path.join(self.output_folder, 'feature_summary_statistics.csv')
        pd.DataFrame(rows).to_csv(csv_path, index=False)

        print(f"  ✅ Feature summary statistics saved to feature_summary_statistics.csv")
            
    def _plot_dataset_distribution(self):
        """Plot how grey and white channels are distributed over patients.

        The left panel is a grouped bar chart of channel counts per patient,
        the right panel a pie chart of the overall grey/white share.  The
        figure is saved as ``dataset_distribution.png`` in
        ``self.output_folder``.  Nothing is drawn when there is no channel
        data, because a pie chart of two zero counts is meaningless.
        """
        if not self.channels_data:
            print(f"    ❌ No channel data available")
            return

        # Channel counts per patient and matter type
        patient_stats = {}
        for ch_data in self.channels_data.values():
            counts = patient_stats.setdefault(ch_data['patient_id'], {'grey': 0, 'white': 0})
            counts[ch_data['matter_type']] += 1

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

        # Grouped bars: one grey and one white bar per patient
        patients = list(patient_stats.keys())
        grey_counts = [patient_stats[p]['grey'] for p in patients]
        white_counts = [patient_stats[p]['white'] for p in patients]

        x = np.arange(len(patients))
        width = 0.35

        ax1.bar(x - width/2, grey_counts, width, label='Grey', color='red', alpha=0.7)
        ax1.bar(x + width/2, white_counts, width, label='White', color='blue', alpha=0.7)
        ax1.set_xlabel('Patients')
        ax1.set_ylabel('Channel Count')
        ax1.set_title('Channel Distribution by Patient')
        ax1.set_xticks(x)
        ax1.set_xticklabels(patients, rotation=45)
        ax1.legend()
        ax1.grid(True, alpha=0.3)

        # Overall pie chart
        total_grey = sum(grey_counts)
        total_white = sum(white_counts)

        ax2.pie([total_grey, total_white],
               labels=['Grey Matter', 'White Matter'],
               colors=['red', 'blue'],
               autopct='%1.1f%%')
        ax2.set_title(f'Overall Distribution\n({total_grey + total_white} channels)')

        fig.tight_layout()
        fig.savefig(os.path.join(self.output_folder, 'dataset_distribution.png'), dpi=300, bbox_inches='tight')
        # Close this exact figure rather than whichever is current
        plt.close(fig)
    
    def _plot_validation_comparison(self, analysis_results):
        """Compare channel-level and patient-level F1 scores per classifier.

        The left panel shows both F1 scores side by side, the right panel
        their difference (channel minus patient; green when positive, red
        otherwise).  The figure is saved as ``validation_comparison.png`` in
        ``self.output_folder``.  Nothing is drawn unless both result sets are
        present in ``analysis_results``.
        """
        channel_res = analysis_results.get('channel', {}).get('results', {})
        patient_res = analysis_results.get('patient', {}).get('results', {})

        # Both validation schemes are required for a comparison
        if not (channel_res and patient_res):
            return

        # A classifier missing from a result set is scored 0
        names = list(self.classifiers.keys())
        channel_scores = [channel_res.get(name, {}).get('f1', 0) for name in names]
        patient_scores = [patient_res.get(name, {}).get('overall_f1', 0) for name in names]

        positions = np.arange(len(names))
        bar_width = 0.35

        fig, (ax_scores, ax_delta) = plt.subplots(1, 2, figsize=(12, 5))

        # Left: side-by-side F1 bars
        ax_scores.bar(positions - bar_width/2, channel_scores, bar_width,
                      label='Channel-Level', alpha=0.7, color='green')
        ax_scores.bar(positions + bar_width/2, patient_scores, bar_width,
                      label='Patient-Level', alpha=0.7, color='orange')
        ax_scores.set_xlabel('Classifiers')
        ax_scores.set_ylabel('F1 Score')
        ax_scores.set_title('F1 Score Comparison')
        ax_scores.set_xticks(positions)
        ax_scores.set_xticklabels(names, rotation=45)
        ax_scores.legend()
        ax_scores.grid(True, alpha=0.3)

        # Right: signed difference between the two schemes
        deltas = np.array(channel_scores) - np.array(patient_scores)
        delta_colors = []
        for delta in deltas:
            delta_colors.append('green' if delta > 0 else 'red')

        ax_delta.bar(positions, deltas, alpha=0.7, color=delta_colors)
        ax_delta.axhline(y=0, color='black', linestyle='-', alpha=0.5)
        ax_delta.set_xlabel('Classifiers')
        ax_delta.set_ylabel('F1 Difference (Channel - Patient)')
        ax_delta.set_title('F1 Score Difference')
        ax_delta.set_xticks(positions)
        ax_delta.set_xticklabels(names, rotation=45)
        ax_delta.grid(True, alpha=0.3)

        plt.tight_layout()
        target = os.path.join(self.output_folder, 'validation_comparison.png')
        plt.savefig(target, dpi=300, bbox_inches='tight')
        plt.close()
    
    def _plot_feature_importance(self):
        """Plot random-forest feature importances as a horizontal bar chart.

        A class-balanced ``RandomForestClassifier`` (100 trees,
        ``random_state=42``) is fitted on all prepared channels; its
        importances are drawn in descending order and saved as
        ``feature_importance.png`` in ``self.output_folder``.  Nothing is
        drawn when no data is available.
        """
        features, labels, _ = self._prepare_data()
        if features is None:
            return

        # Importance estimated by a forest fitted on the full data set
        forest = RandomForestClassifier(n_estimators=100, random_state=42, class_weight='balanced')
        forest.fit(features, labels)
        scores = forest.feature_importances_

        first_channel = list(self.channels_data.values())[0]
        feature_names = list(first_channel['features'].keys())

        # Most important feature first
        ranking = np.argsort(scores)[::-1]
        ranked_names = [feature_names[i] for i in ranking]
        bar_positions = range(len(scores))

        fig, ax = plt.subplots(figsize=(10, 8))
        ax.barh(bar_positions, scores[ranking], alpha=0.7)
        ax.set_yticks(bar_positions)
        ax.set_yticklabels(ranked_names)
        ax.set_xlabel('Feature Importance')
        ax.set_title('Feature Importance (Random Forest)')
        # Put the top-ranked feature at the top of the chart
        ax.invert_yaxis()
        ax.grid(True, alpha=0.3)
        fig.tight_layout()
        fig.savefig(os.path.join(self.output_folder, 'feature_importance.png'), dpi=300, bbox_inches='tight')
        plt.close(fig)
    
    def save_results(self, analysis_results, norm_method):
        """Write the validation metrics and a dataset summary to ``self.output_folder``.

        Files written (``<norm>`` stands for ``norm_method``):
          * ``channel_results_<norm>.csv`` - one row per classifier with the
            channel-level metrics; only written when such results exist.
          * ``patient_results_<norm>.csv`` - one row per classifier with the
            patient-level CV metrics; only written when such results exist.
          * ``summary_<norm>.csv`` - one row per entry of ``self.classifiers``;
            metrics that are missing for a classifier are written as 0.
          * ``dataset_info.txt`` - channel/patient counts and the
            normalization name.  Its name does not contain ``norm_method``,
            so it is overwritten on every call.

        The output folder is created if it does not exist yet.

        Args:
            analysis_results: dict with optional ``'channel'`` and
                ``'patient'`` entries, each holding a ``'results'`` mapping of
                classifier name -> metrics dict.
            norm_method: normalization name; used in the CSV file names and
                recorded in ``dataset_info.txt``.
        """
        print("\n💾 Saving results...")

        # Saving must not fail just because nothing created the folder before
        # this method was called.  An empty folder name means "current
        # directory" for os.path.join, and os.makedirs('') would raise.
        if self.output_folder:
            os.makedirs(self.output_folder, exist_ok=True)

        def _target(file_name):
            """Return the full path of an output file."""
            return os.path.join(self.output_folder, file_name)

        # Channel-level results
        channel_res = analysis_results.get('channel', {}).get('results', {})
        if channel_res:
            channel_rows = [
                {
                    'Classifier': clf_name,
                    'Validation': 'Channel_Level',
                    'F1': metrics['f1'],
                    'Accuracy': metrics['accuracy'],
                    'Balanced_Acc': metrics['balanced_acc'],
                    'ROC_AUC': metrics['roc_auc'],
                    'Precision': metrics['precision'],
                    'Recall': metrics['recall'],
                }
                for clf_name, metrics in channel_res.items()
            ]
            pd.DataFrame(channel_rows).to_csv(
                _target(f'channel_results_{norm_method}.csv'), index=False
            )

        # Patient-level results
        patient_res = analysis_results.get('patient', {}).get('results', {})
        if patient_res:
            patient_rows = [
                {
                    'Classifier': clf_name,
                    'Validation': 'Patient_Level',
                    'CV_F1_Mean': metrics['cv_f1_mean'],
                    'CV_F1_Std': metrics['cv_f1_std'],
                    'Overall_F1': metrics['overall_f1'],
                    'Overall_Acc': metrics['overall_acc'],
                }
                for clf_name, metrics in patient_res.items()
            ]
            pd.DataFrame(patient_rows).to_csv(
                _target(f'patient_results_{norm_method}.csv'), index=False
            )

        # Side-by-side summary: one row per configured classifier, with zeros
        # standing in for any metric that was not produced.
        summary_rows = []
        for clf_name in self.classifiers:
            row = {'Classifier': clf_name}

            if channel_res and clf_name in channel_res:
                row['Channel_F1'] = channel_res[clf_name]['f1']
                row['Channel_Acc'] = channel_res[clf_name]['accuracy']
            else:
                row['Channel_F1'] = 0
                row['Channel_Acc'] = 0

            if patient_res and clf_name in patient_res:
                row['Patient_F1'] = patient_res[clf_name]['overall_f1']
                row['Patient_CV_F1'] = patient_res[clf_name]['cv_f1_mean']
                row['Patient_CV_Std'] = patient_res[clf_name]['cv_f1_std']
            else:
                row['Patient_F1'] = 0
                row['Patient_CV_F1'] = 0
                row['Patient_CV_Std'] = 0

            summary_rows.append(row)

        pd.DataFrame(summary_rows).to_csv(
            _target(f'summary_{norm_method}.csv'), index=False
        )

        # Dataset description (label 1 is counted as grey, label 0 as white)
        channels = list(self.channels_data.values())
        dataset_info = {
            'total_channels': len(channels),
            'grey_channels': sum(1 for ch in channels if ch['label'] == 1),
            'white_channels': sum(1 for ch in channels if ch['label'] == 0),
            'n_patients': len({ch['patient_id'] for ch in channels}),
            'normalization': norm_method,
        }

        # Explicit encoding keeps the file identical across platforms.
        with open(_target('dataset_info.txt'), 'w', encoding='utf-8') as f:
            f.writelines(f"{key}: {value}\n" for key, value in dataset_info.items())

        print(f"  ✅ Results saved to {self.output_folder}")
    
    def run_analysis(self, normalization='robust', outlier_clip=True, iqr_factor=1.5,
                     create_feature_viz=True, features_to_visualize=None):
        """Run the whole pipeline: load, extract, normalize, validate, plot, save.

        Args:
            normalization: normalization name handed to ``normalize_by_patient``
                and ``save_results``.
            outlier_clip: forwarded to both validation routines.
            iqr_factor: forwarded to both validation routines.
            create_feature_viz: when true, the detailed feature plots and the
                feature summary statistics are produced as well.
            features_to_visualize: forwarded to
                ``create_detailed_feature_visualization``.

        Returns:
            The dict produced by ``analyze_results``, or ``None`` when loading,
            feature extraction or normalization reports failure.
        """
        banner = '=' * 60
        print("🧠 Channel-Level Grey/White Matter Classification")
        print(banner)

        # Steps 1-3 (load, extract features, normalize): each returns a falsy
        # value on failure, which aborts the run.
        preparation = (
            ('load_patients', (), "❌ Failed to load patients"),
            ('extract_all_features', (), "❌ Failed to extract features"),
            ('normalize_by_patient', (normalization,), "❌ Failed to normalize"),
        )
        for method_name, args, failure_message in preparation:
            if not getattr(self, method_name)(*args):
                print(failure_message)
                return None

        # Steps 4-5: channel-level, then patient-level validation
        validation_options = {'outlier_clip': outlier_clip, 'iqr_factor': iqr_factor}
        channel_results = self.channel_level_validation(**validation_options)
        patient_results = self.patient_level_validation(**validation_options)

        # Step 6: combine both result sets
        analysis_results = self.analyze_results(channel_results, patient_results)

        # Step 7: standard figures
        self.create_visualizations(analysis_results)

        # Step 8: optional per-feature figures and statistics
        if create_feature_viz:
            self.create_detailed_feature_visualization(features_to_visualize)
            self.create_feature_summary_statistics()

        # Step 9: persist everything
        self.save_results(analysis_results, normalization)

        # Step 10: console summary
        print("\n" + banner)
        print("✅ Analysis Complete")
        print(banner)

        channel_infos = list(self.channels_data.values())
        n_channels = len(self.channels_data)
        n_grey = sum(1 for info in channel_infos if info['label'] == 1)
        n_white = sum(1 for info in channel_infos if info['label'] == 0)
        n_patients = len({info['patient_id'] for info in channel_infos})
        clip_flag = 'Yes' if outlier_clip else 'No'
        viz_flag = 'Yes' if create_feature_viz else 'No'

        print("📊 Dataset:")
        print(f"   Patients: {n_patients}")
        print(f"   Channels: {n_channels} ({n_grey} grey, {n_white} white)")
        print(f"   Normalization: {normalization} (patient-level)")
        print(f"   Outlier clipping: {clip_flag} (IQR × {iqr_factor})")
        print(f"   Feature visualization: {viz_flag}")

        # Best classifier per validation scheme, when one was determined
        for level_key, level_name in (('channel', 'Channel-Level'), ('patient', 'Patient-Level')):
            if analysis_results.get(level_key, {}).get('best_clf'):
                level_best = analysis_results[level_key]
                print(f"\n🏆 Best {level_name}: {level_best['best_clf']}")
                print(f"   F1 Score: {level_best['best_f1']:.3f}")

        print(f"\n📁 Results: {self.output_folder}")
        if create_feature_viz:
            print(f"📁 Feature visualizations: {self.output_folder}/feature_distributions/")

        return analysis_results
    
def create_feature_visualizations_only(
        processed_folder=r"D:\BlcRepo\LabCode\SeizureProp\data\gwbaseline_test",
        output_folder=r"D:\BlcRepo\LabCode\SeizureProp\result\gwclassification_feature_viz_only_test",
        normalization='robust'):
    """Example: build only the feature visualizations, without any classification.

    Loads the patients, extracts and normalizes the features, then creates the
    detailed feature plots and the feature summary statistics.  The
    preparation steps are checked the same way ``run_analysis`` checks them:
    if one reports failure, a message is printed and nothing is plotted.

    Args:
        processed_folder: folder with the preprocessed patient data.
        output_folder: folder that receives the figures.
        normalization: normalization name passed to ``normalize_by_patient``.
    """
    classifier = ChannelClassifier(
        processed_folder=processed_folder,
        output_folder=output_folder,
    )

    # Load the data, extract the features and normalize them.  Each step
    # signals failure with a falsy return value; plotting on top of a failed
    # step would either crash or silently produce empty figures.
    if not classifier.load_patients():
        print("❌ Failed to load patients")
        return
    if not classifier.extract_all_features():
        print("❌ Failed to extract features")
        return
    if not classifier.normalize_by_patient(normalization):
        print("❌ Failed to normalize")
        return

    # Feature visualizations only
    classifier.create_detailed_feature_visualization()

    # Statistical summary
    classifier.create_feature_summary_statistics()

    print("✅ Feature visualizations created!")

In [None]:
# Run the visualization-only example (load, extract, normalize, plot).
create_feature_visualizations_only()

In [None]:
# Classifier for the channel-level multi-patient experiment.
classifier = ChannelClassifier(
    processed_folder=r"D:\BlcRepo\LabCode\SeizureProp\data\gwbaseline_alt",
    output_folder=r"D:\BlcRepo\LabCode\SeizureProp\result\multi_patient_results_channel_level",
)

# Run the full analysis.
#   normalization: 'robust', 'standard' or 'minmax'
#   outlier_clip:  whether outliers are clipped
#   iqr_factor:    IQR multiplier used for the clipping
results = classifier.run_analysis(
    normalization='standard',
    outlier_clip=True,
    iqr_factor=1.5,
)

if results:
    print("\n🎯 Quick Summary:")

    # Best channel-level classifier, if one was determined
    if results.get('channel', {}).get('best_clf'):
        channel_best = results['channel']
        print(f"Channel-Level Best: {channel_best['best_clf']} (F1={channel_best['best_f1']:.3f})")

    # Best patient-level classifier, if one was determined
    if results.get('patient', {}).get('best_clf'):
        patient_best = results['patient']
        print(f"Patient-Level Best: {patient_best['best_clf']} (F1={patient_best['best_f1']:.3f})")

In [None]:
def plot_psd_analysis_from_classifier(classifier, figsize=(20, 12)):
    """
    Plot per-patient power spectral densities (PSD) of grey- and white-matter channels.

    One subplot is drawn per patient, at most three per row.  The Welch PSD
    of every channel is drawn as a faint line (blue = grey matter, red =
    white matter), the group mean is drawn on top as a thick line, and the
    frequency bands are shaded in the background.

    Args:
        classifier: ChannelClassifier-like object exposing ``channels_data``
            (channel id -> dict with 'patient_id', 'label' where 1 = grey and
            0 = white, and 'electrode_idx') and ``patients_data`` (patient id
            -> dict with a 'recordings' list whose items carry
            'neural_data_processed' and 'sampling_rate').
        figsize: figure size handed to matplotlib.

    Returns:
        The matplotlib figure holding all patient panels.
    """
    # Imported here so the function does not rely on another notebook cell
    # having brought welch into scope.
    from scipy.signal import welch

    # Frequency bands: name -> (low Hz, high Hz, background colour)
    freq_bands = {
        'delta': (0.5, 4, '#90EE90'),       # light green
        'theta': (4, 8, '#FFA500'),         # orange
        'alpha': (8, 13, '#FFB6C1'),        # light pink
        'beta': (13, 30, '#87CEEB'),        # sky blue
        'low_gamma': (30, 60, '#DDA0DD'),   # plum
        'high_gamma': (60, 100, '#F0E68C'),  # khaki
        'ripple': (100, 200, '#D3D3D3')     # light grey
    }

    def _stack_recordings(recordings):
        """Concatenate all recording segments; return (data or None, sampling rate)."""
        # The sampling rate of the first recording is used for the whole
        # patient; 512 Hz is the fallback when there are no recordings.
        fs = recordings[0]['sampling_rate'] if recordings else 512
        segments = []
        for recording in recordings:
            segment = recording['neural_data_processed']
            if isinstance(segment, list):
                segments.extend(segment)
            else:
                segments.append(segment)
        # Segments are stacked along axis 0; electrodes are indexed on axis 1.
        return (np.vstack(segments) if segments else None), fs

    def _plot_channel_psds(ax, group, data, fs, color):
        """Draw one faint PSD line per usable channel; return (freqs, list of PSDs)."""
        freqs, psds = None, []
        for ch_info in group:
            electrode_idx = ch_info['electrode_idx']
            if electrode_idx >= data.shape[1]:
                continue
            signal = data[:, electrode_idx]
            # Empty or all-zero signals carry no spectrum worth plotting.
            if len(signal) == 0 or np.all(signal == 0):
                continue
            f, psd = welch(signal, fs=fs, nperseg=min(1024, len(signal)),
                           scaling='density')
            ax.plot(f, psd, color=color, alpha=0.1, linewidth=0.5)
            psds.append(psd)
            if freqs is None:
                freqs = f
        return freqs, psds

    channels_data = classifier.channels_data
    patients_data = classifier.patients_data

    # Every patient id comes from a channel, so each patient has >= 1 channel.
    patient_ids = sorted({ch['patient_id'] for ch in channels_data.values()})
    n_patients = len(patient_ids)

    # At most three columns; max(1, ...) keeps the grid valid (one hidden
    # panel) when there are no channels at all.
    n_cols = max(1, min(3, n_patients))
    n_rows = max(1, (n_patients + n_cols - 1) // n_cols)

    # squeeze=False always returns a 2-D array of Axes, so axes[row, col] is
    # valid for every layout, including a single panel or a single row.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
    fig.suptitle('Power Spectral Density Analysis by Patient', fontsize=16, y=0.95)

    for idx, patient_id in enumerate(patient_ids):
        row, col = divmod(idx, n_cols)
        ax = axes[row, col]

        patient_channels = [ch for ch in channels_data.values()
                            if ch['patient_id'] == patient_id]
        grey_channels = [ch for ch in patient_channels if ch['label'] == 1]
        white_channels = [ch for ch in patient_channels if ch['label'] == 0]

        if patient_id in patients_data:
            combined_data, fs = _stack_recordings(patients_data[patient_id]['recordings'])

            if combined_data is not None:
                # Individual channels first (faint), group means afterwards (bold)
                grey_freqs, grey_psds = _plot_channel_psds(
                    ax, grey_channels, combined_data, fs, 'blue')
                white_freqs, white_psds = _plot_channel_psds(
                    ax, white_channels, combined_data, fs, 'red')

                if grey_psds:
                    ax.plot(grey_freqs, np.mean(np.vstack(grey_psds), axis=0),
                            color='blue', linewidth=3,
                            label=f'Grey Matter (n={len(grey_psds)})', alpha=0.8)

                if white_psds:
                    ax.plot(white_freqs, np.mean(np.vstack(white_psds), axis=0),
                            color='red', linewidth=3,
                            label=f'White Matter (n={len(white_psds)})', alpha=0.8)

        # Shaded frequency bands with a name tag near the top of the panel
        _, y_max = ax.get_ylim()
        for band_name, (low_freq, high_freq, color) in freq_bands.items():
            if high_freq <= 200:  # only bands up to 200 Hz
                ax.axvspan(low_freq, high_freq, alpha=0.2, color=color)
                mid_freq = (low_freq + high_freq) / 2
                ax.text(mid_freq, y_max * 0.9, band_name,
                        ha='center', va='center', fontsize=8,
                        bbox=dict(boxstyle='round,pad=0.2', facecolor=color, alpha=0.7))

        # Axes
        ax.set_yscale('log')
        ax.set_xlabel('Frequency (Hz)')
        ax.set_ylabel('Power Spectral Density (log scale)')
        ax.set_xlim(0, 150)  # display range limited to 150 Hz
        ax.grid(True, alpha=0.3)
        # A legend only makes sense when a mean curve was actually drawn.
        if ax.get_legend_handles_labels()[0]:
            ax.legend(loc='upper right')

        # Channel counts in the title
        n_grey = len(grey_channels)
        n_white = len(white_channels)
        ax.set_title(f'{patient_id}\n(G:{n_grey}, W:{n_white})')

    # Hide the unused panels of the grid
    for idx in range(n_patients, n_rows * n_cols):
        row, col = divmod(idx, n_cols)
        axes[row, col].axis('off')

    plt.tight_layout()
    return fig

def plot_combined_patient_psd_from_classifier(classifier, figsize=(15, 8)):
    """
    Plot the mean PSD of all grey- vs. all white-matter channels, pooled over patients.

    The Welch PSD of every usable channel of every patient is collected; the
    mean of each tissue class is drawn as a line with a +/- one standard
    deviation band.  Spectra whose frequency grid differs from the first one
    collected (for example a patient recorded at another sampling rate) are
    interpolated onto that grid before averaging; frequencies a spectrum does
    not cover are ignored for that spectrum.

    Args:
        classifier: ChannelClassifier-like object exposing ``channels_data``
            (channel id -> dict with 'patient_id', 'label' where 1 = grey and
            0 = white, and 'electrode_idx') and ``patients_data`` (patient id
            -> dict with a 'recordings' list whose items carry
            'neural_data_processed' and 'sampling_rate').
        figsize: figure size handed to matplotlib.

    Returns:
        The matplotlib figure.
    """
    # Imported here so the function does not rely on another notebook cell
    # having brought welch into scope.
    from scipy.signal import welch

    # Frequency bands: name -> (low Hz, high Hz, background colour)
    freq_bands = {
        'delta': (0.5, 4, '#90EE90'),
        'theta': (4, 8, '#FFA500'),
        'alpha': (8, 13, '#FFB6C1'),
        'beta': (13, 30, '#87CEEB'),
        'low_gamma': (30, 60, '#DDA0DD'),
        'high_gamma': (60, 100, '#F0E68C'),
        'ripple': (100, 200, '#D3D3D3')
    }

    def _stack_recordings(recordings):
        """Concatenate all recording segments; return (data or None, sampling rate)."""
        fs = recordings[0]['sampling_rate'] if recordings else 512
        segments = []
        for recording in recordings:
            segment = recording['neural_data_processed']
            if isinstance(segment, list):
                segments.extend(segment)
            else:
                segments.append(segment)
        return (np.vstack(segments) if segments else None), fs

    def _channel_spectra(group, data, fs):
        """Return a (freqs, psd) pair for every usable channel of ``group``."""
        spectra = []
        for ch_info in group:
            electrode_idx = ch_info['electrode_idx']
            if electrode_idx >= data.shape[1]:
                continue
            signal = data[:, electrode_idx]
            if len(signal) == 0 or np.all(signal == 0):
                continue
            spectra.append(welch(signal, fs=fs, nperseg=min(1024, len(signal)),
                                 scaling='density'))
        return spectra

    def _stack_on_grid(spectra, grid):
        """Stack PSDs row-wise, interpolating those defined on another grid."""
        rows = []
        for f, psd in spectra:
            if f.shape != grid.shape or not np.allclose(f, grid):
                # NaN outside the covered range, so that range is skipped by
                # the nan-aware statistics below instead of being extrapolated.
                psd = np.interp(grid, f, psd, left=np.nan, right=np.nan)
            rows.append(psd)
        return np.vstack(rows)

    channels_data = classifier.channels_data
    patients_data = classifier.patients_data

    fig, ax = plt.subplots(1, 1, figsize=figsize)

    grey_spectra = []
    white_spectra = []
    freqs = None  # reference grid: the first spectrum that is computed

    # Sorted so the reference grid does not depend on set iteration order.
    patient_ids = sorted({ch['patient_id'] for ch in channels_data.values()})

    for patient_id in patient_ids:
        if patient_id not in patients_data:
            continue
        combined_data, fs = _stack_recordings(patients_data[patient_id]['recordings'])
        if combined_data is None:
            continue

        patient_channels = [ch for ch in channels_data.values()
                            if ch['patient_id'] == patient_id]
        patient_grey = _channel_spectra(
            [ch for ch in patient_channels if ch['label'] == 1], combined_data, fs)
        patient_white = _channel_spectra(
            [ch for ch in patient_channels if ch['label'] == 0], combined_data, fs)

        if freqs is None and (patient_grey or patient_white):
            freqs = (patient_grey or patient_white)[0][0]

        grey_spectra.extend(patient_grey)
        white_spectra.extend(patient_white)

    # Mean +/- std of each tissue class.
    # NOTE(review): mean - std can be <= 0, which a log axis cannot show
    # faithfully; the band is kept as in the original figure.
    for spectra, color, name in ((grey_spectra, 'blue', 'Grey Matter'),
                                 (white_spectra, 'red', 'White Matter')):
        if not spectra:
            continue
        stacked = _stack_on_grid(spectra, freqs)
        mean_psd = np.nanmean(stacked, axis=0)
        std_psd = np.nanstd(stacked, axis=0)

        ax.plot(freqs, mean_psd, color=color, linewidth=3,
                label=f'{name} (n={len(stacked)})')
        ax.fill_between(freqs, mean_psd - std_psd, mean_psd + std_psd,
                        color=color, alpha=0.2)

    # Shaded frequency bands with a name tag near the top of the panel
    _, y_max = ax.get_ylim()
    for band_name, (low_freq, high_freq, color) in freq_bands.items():
        if high_freq <= 150:
            ax.axvspan(low_freq, high_freq, alpha=0.2, color=color)
            mid_freq = (low_freq + high_freq) / 2
            ax.text(mid_freq, y_max * 0.9, band_name,
                    ha='center', va='center', fontsize=10,
                    bbox=dict(boxstyle='round,pad=0.2', facecolor=color, alpha=0.7))

    ax.set_yscale('log')
    ax.set_xlabel('Frequency (Hz)')
    ax.set_ylabel('Power Spectral Density (log scale)')
    ax.set_xlim(0, 150)
    ax.grid(True, alpha=0.3)
    # A legend only makes sense when at least one mean curve was drawn.
    if ax.get_legend_handles_labels()[0]:
        ax.legend()
    ax.set_title('Average Power Spectral Density - All Patients Combined')

    plt.tight_layout()
    return fig

def plot_psd_comparison_grid_from_classifier(classifier, figsize=(20, 15)):
    """
    Plot a grid of mean PSDs: one panel per patient plus one pooled panel.

    Each patient panel shows the mean Welch PSD of that patient's grey- and
    white-matter channels.  The last panel shows the means over all channels
    of all patients.  Every channel's PSD is computed once and reused for the
    pooled panel.  Spectra whose frequency grid differs from the first one
    collected are interpolated onto that grid before the pooled average.

    Args:
        classifier: ChannelClassifier-like object exposing ``channels_data``
            (channel id -> dict with 'patient_id', 'label' where 1 = grey and
            0 = white, and 'electrode_idx') and ``patients_data`` (patient id
            -> dict with a 'recordings' list whose items carry
            'neural_data_processed' and 'sampling_rate').
        figsize: figure size handed to matplotlib.

    Returns:
        The matplotlib figure.
    """
    # Imported here so the function does not rely on another notebook cell
    # having brought welch into scope.
    from scipy.signal import welch

    def _stack_recordings(recordings):
        """Concatenate all recording segments; return (data or None, sampling rate)."""
        fs = recordings[0]['sampling_rate'] if recordings else 512
        segments = []
        for recording in recordings:
            segment = recording['neural_data_processed']
            if isinstance(segment, list):
                segments.extend(segment)
            else:
                segments.append(segment)
        return (np.vstack(segments) if segments else None), fs

    def _channel_spectra(group, data, fs):
        """Return a (freqs, psd) pair for every usable channel of ``group``."""
        spectra = []
        for ch_info in group:
            electrode_idx = ch_info['electrode_idx']
            if electrode_idx >= data.shape[1]:
                continue
            signal = data[:, electrode_idx]
            if len(signal) == 0 or np.all(signal == 0):
                continue
            spectra.append(welch(signal, fs=fs, nperseg=min(1024, len(signal)),
                                 scaling='density'))
        return spectra

    def _stack_on_grid(spectra, grid):
        """Stack PSDs row-wise, interpolating those defined on another grid."""
        rows = []
        for f, psd in spectra:
            if f.shape != grid.shape or not np.allclose(f, grid):
                # NaN outside the covered range so it is skipped by nanmean.
                psd = np.interp(grid, f, psd, left=np.nan, right=np.nan)
            rows.append(psd)
        return np.vstack(rows)

    channels_data = classifier.channels_data
    patients_data = classifier.patients_data

    patient_ids = sorted({ch['patient_id'] for ch in channels_data.values()})
    n_patients = len(patient_ids)

    # Grid layout: one panel per patient plus the pooled panel
    n_total_plots = n_patients + 1
    n_cols = min(3, n_total_plots)
    n_rows = (n_total_plots + n_cols - 1) // n_cols

    fig = plt.figure(figsize=figsize)

    # (tissue label, colour, legend name on patient panels)
    tissues = ((1, 'blue', 'Grey'), (0, 'red', 'White'))
    pooled = {1: [], 0: []}   # tissue label -> (freqs, psd) pairs of all patients
    reference_freqs = None    # grid of the first spectrum that is computed

    # Patient panels
    for idx, patient_id in enumerate(patient_ids):
        ax = plt.subplot(n_rows, n_cols, idx + 1)

        if patient_id in patients_data:
            combined_data, fs = _stack_recordings(patients_data[patient_id]['recordings'])

            if combined_data is not None:
                for tissue_label, color, tissue_name in tissues:
                    group = [ch for ch in channels_data.values()
                             if ch['patient_id'] == patient_id
                             and ch['label'] == tissue_label]
                    spectra = _channel_spectra(group, combined_data, fs)
                    if not spectra:
                        continue
                    pooled[tissue_label].extend(spectra)
                    if reference_freqs is None:
                        reference_freqs = spectra[0][0]

                    # All channels of one patient share data length and
                    # sampling rate, hence the same frequency grid.
                    mean_psd = np.mean(np.vstack([psd for _, psd in spectra]), axis=0)
                    ax.plot(spectra[0][0], mean_psd, color=color, linewidth=2,
                            label=f'{tissue_name} (n={len(spectra)})')

        ax.set_yscale('log')
        ax.set_xlim(0, 150)
        ax.grid(True, alpha=0.3)
        # A legend only makes sense when a mean curve was actually drawn.
        if ax.get_legend_handles_labels()[0]:
            ax.legend(fontsize=8)
        ax.set_title(f'{patient_id}')

        if idx >= (n_rows - 1) * n_cols:  # last row
            ax.set_xlabel('Frequency (Hz)')
        if idx % n_cols == 0:  # first column
            ax.set_ylabel('PSD (log)')

    # Pooled panel, built from the spectra collected above
    ax_combined = plt.subplot(n_rows, n_cols, n_patients + 1)

    for tissue_label, color, tissue_name in tissues:
        spectra = pooled[tissue_label]
        if not spectra:
            continue
        mean_psd = np.nanmean(_stack_on_grid(spectra, reference_freqs), axis=0)
        ax_combined.plot(reference_freqs, mean_psd, color=color, linewidth=3,
                         label=f'{tissue_name} Matter (n={len(spectra)})')

    ax_combined.set_yscale('log')
    ax_combined.set_xlim(0, 150)
    ax_combined.grid(True, alpha=0.3)
    if ax_combined.get_legend_handles_labels()[0]:
        ax_combined.legend()
    ax_combined.set_title('All Patients Combined')
    ax_combined.set_xlabel('Frequency (Hz)')
    if n_patients % n_cols == 0:  # pooled panel sits in the first column
        ax_combined.set_ylabel('PSD (log)')

    plt.suptitle('Power Spectral Density Analysis', fontsize=16)
    plt.tight_layout()

    return fig

In [None]:
# Options shared by the three PSD figure exports below.
_psd_save_options = {'dpi': 300, 'bbox_inches': 'tight'}

# Variant 1: detailed PSD panel for every individual patient
fig1 = plot_psd_analysis_from_classifier(classifier, figsize=(20, 12))
fig1.savefig(
    r"D:\BlcRepo\LabCode\SeizureProp\result\multi_patient_results_channel_level\psd_analysis_individual_patients.png",
    **_psd_save_options,
)

# Variant 2: grey vs. white PSD pooled over all patients
fig2 = plot_combined_patient_psd_from_classifier(classifier, figsize=(15, 8))
fig2.savefig(
    r"D:\BlcRepo\LabCode\SeizureProp\result\multi_patient_results_channel_level\psd_analysis_combined_patients.png",
    **_psd_save_options,
)

# Variant 3: complete grid layout (recommended)
fig3 = plot_psd_comparison_grid_from_classifier(classifier, figsize=(20, 15))
fig3.savefig(
    r"D:\BlcRepo\LabCode\SeizureProp\result\multi_patient_results_channel_level\psd_analysis_comparison_grid.png",
    **_psd_save_options,
)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.patches import Rectangle

def plot_classification_results_comprehensive(channel_results, patient_results,
                                              save_path=None, figsize=(16, 12)):
    """
    Build a 2x2 overview figure of the classification results.

    Panels: F1 comparison (top left), channel-level accuracy (top right),
    channel-level ROC AUC (bottom left) and per-patient F1 of the best
    classifier (bottom right).  The 'seaborn-v0_8' matplotlib style is
    activated globally as a side effect.

    Args:
        channel_results: channel-level results, classifier name -> metrics dict.
        patient_results: patient-level results; either the CV results alone
            or a ``(cv_results, all_predictions)`` 2-tuple.
        save_path: optional path; when given, the figure is also saved there.
        figsize: figure size.

    Returns:
        fig: the matplotlib figure.
    """
    # A 2-tuple carries the predictions next to the CV results.
    cv_results, all_predictions = patient_results, None
    if isinstance(patient_results, tuple) and len(patient_results) == 2:
        cv_results, all_predictions = patient_results

    plt.style.use('seaborn-v0_8')
    fig = plt.figure(figsize=figsize)

    # One colour per metric family
    palette = {
        'channel': '#3498db',   # blue
        'accuracy': '#f39c12',  # orange
        'roc': '#2ecc71',       # green
        'patient': '#9b59b6'    # purple
    }

    classifier_names = list(channel_results.keys()) if channel_results else []

    # Top left: F1 score comparison
    f1_axis = plt.subplot(2, 2, 1)
    plot_f1_comparison(f1_axis, channel_results, cv_results, all_predictions,
                       classifier_names, palette)

    # Top right: accuracy
    accuracy_axis = plt.subplot(2, 2, 2)
    plot_accuracy_comparison(accuracy_axis, channel_results, classifier_names, palette)

    # Bottom left: ROC AUC
    roc_axis = plt.subplot(2, 2, 3)
    plot_roc_comparison(roc_axis, channel_results, classifier_names, palette)

    # Bottom right: per-patient results
    patient_axis = plt.subplot(2, 2, 4)
    plot_patient_wise_results(patient_axis, cv_results, palette)

    plt.tight_layout(pad=3.0)

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight', facecolor='white')
        print(f"📊 Results visualization saved to: {save_path}")

    return fig

def plot_f1_comparison(ax, channel_results, cv_results, all_predictions, classifiers, colors):
    """Draw grouped bars comparing channel-level and patient-level F1 per classifier.

    Args:
        ax: matplotlib axes to draw on.
        channel_results: classifier name -> metrics dict with an 'f1' entry;
            may be ``None`` or empty.
        cv_results: classifier name -> list of per-fold dicts with an 'f1'
            entry; may be ``None`` or empty.
        all_predictions: unused; kept so existing callers keep working.
        classifiers: classifier names, in plotting order.
        colors: dict providing the 'channel' and 'patient' bar colours.

    Classifiers without a result are drawn with height 0.  Patient-level bars
    show the mean over folds with the standard deviation as error bar; the
    channel-level bars have no spread.
    """
    # Missing result sets are treated as "no results" instead of raising.
    channel_results = channel_results or {}
    cv_results = cv_results or {}

    # Channel-level F1: a single value per classifier, hence zero spread
    channel_f1s = [
        channel_results[clf]['f1'] if clf in channel_results else 0
        for clf in classifiers
    ]
    channel_stds = [0] * len(channel_f1s)

    # Patient-level F1: mean and standard deviation over the CV folds
    patient_f1s = []
    patient_stds = []
    for clf in classifiers:
        fold_f1s = [fold['f1'] for fold in cv_results[clf]] if clf in cv_results else []
        if fold_f1s:
            patient_f1s.append(np.mean(fold_f1s))
            patient_stds.append(np.std(fold_f1s))
        else:
            patient_f1s.append(0)
            patient_stds.append(0)

    # Grouped bars
    x = np.arange(len(classifiers))
    width = 0.35

    ax.bar(x - width/2, channel_f1s, width,
           label='Channel-Level', color=colors['channel'], alpha=0.8,
           yerr=channel_stds, capsize=5)

    ax.bar(x + width/2, patient_f1s, width,
           label='Patient-Level', color=colors['patient'], alpha=0.8,
           yerr=patient_stds, capsize=5)

    ax.set_xlabel('Classifiers')
    ax.set_ylabel('F1 Score')
    ax.set_title('F1 Score by Classifier (with std)')
    ax.set_xticks(x)
    ax.set_xticklabels(classifiers, rotation=45, ha='right')
    ax.legend()
    ax.grid(True, alpha=0.3)

    # 10 % headroom above the tallest bar.  With no classifiers or only zero
    # scores that would give an empty or zero-height range, so fall back to
    # the natural F1 range [0, 1].
    tallest = max(channel_f1s + patient_f1s, default=0)
    ax.set_ylim(0, tallest * 1.1 if tallest > 0 else 1)

def plot_accuracy_comparison(ax, channel_results, classifiers, colors):
    """Draw one bar per classifier showing its channel-level accuracy.

    Classifiers missing from ``channel_results`` get a zero-height bar and no
    value label; every positive accuracy is annotated with three decimals
    just above its bar.
    """
    scores = [
        channel_results[name]['accuracy'] if name in channel_results else 0
        for name in classifiers
    ]
    rects = ax.bar(classifiers, scores, color=colors['accuracy'], alpha=0.8)

    # Axis decoration
    ax.set_title('Accuracy by Classifier')
    ax.set_xlabel('Classifiers')
    ax.set_ylabel('Accuracy')
    ax.set_ylim(0, 1)
    ax.tick_params(axis='x', rotation=45)
    ax.grid(True, alpha=0.3)

    # Value labels, centred over each non-empty bar
    for rect, score in zip(rects, scores):
        if not score > 0:
            continue
        label_x = rect.get_x() + rect.get_width()/2
        label_y = rect.get_height() + 0.01
        ax.text(label_x, label_y, f'{score:.3f}',
                ha='center', va='bottom', fontsize=9)

def plot_roc_comparison(ax, channel_results, classifiers, colors):
    """Draw one bar per classifier showing its channel-level ROC AUC.

    Classifiers missing from ``channel_results`` are drawn at 0.5.  A dashed
    red line at 0.5 is added and labelled 'Random' in the legend.
    """
    chance_level = 0.5
    auc_values = [
        channel_results[name]['roc_auc'] if name in channel_results else chance_level
        for name in classifiers
    ]

    ax.bar(classifiers, auc_values, color=colors['roc'], alpha=0.8)

    # Axis decoration
    ax.set_title('ROC AUC by Classifier')
    ax.set_xlabel('Classifiers')
    ax.set_ylabel('ROC AUC')
    ax.set_ylim(0, 1)
    ax.tick_params(axis='x', rotation=45)
    ax.grid(True, alpha=0.3)

    # Reference line
    ax.axhline(y=chance_level, color='red', linestyle='--', alpha=0.5, label='Random')
    ax.legend()

def plot_patient_wise_results(ax, cv_results, colors):
    """Plot the per-fold F1 scores of the best classifier on ``ax``.

    The "best" classifier is the one with the highest mean fold F1 in
    ``cv_results`` (ties keep the first one encountered). Each fold is drawn
    as one bar labelled with the fold's ``'test_patient'`` value, and a
    dashed line marks the mean F1 across folds.

    If no classifier has any folds, a placeholder message is drawn instead.

    Args:
        ax: Axes object to draw on.
        cv_results: Mapping of classifier name to a sequence of fold
            mappings; each fold provides ``'f1'`` and ``'test_patient'``.
        colors: Mapping that provides the bar color under ``'patient'``.
    """
    best_clf = None
    # Start from -inf rather than 0 so that a classifier whose mean F1 is
    # exactly 0 is still selected; otherwise real (if poor) results would be
    # reported as "no results available".
    best_score = float('-inf')

    for clf_name, folds in cv_results.items():
        if len(folds) == 0:
            continue
        avg_f1 = np.mean([fold['f1'] for fold in folds])
        if avg_f1 > best_score:
            best_score = avg_f1
            best_clf = clf_name

    if best_clf is None:
        ax.text(0.5, 0.5, 'No patient-wise results available',
                ha='center', va='center', transform=ax.transAxes)
        ax.set_title('Patient-wise F1 Score')
        return

    # One bar per fold of the selected classifier.
    folds = cv_results[best_clf]
    patients = [fold['test_patient'] for fold in folds]
    f1_scores = [fold['f1'] for fold in folds]

    ax.bar(patients, f1_scores, color=colors['patient'], alpha=0.8)

    ax.set_xlabel('Patients')
    ax.set_ylabel('F1 Score')
    ax.set_title(f'Patient-wise F1 Score ({best_clf})')
    ax.tick_params(axis='x', rotation=45)
    ax.grid(True, alpha=0.3)
    ax.set_ylim(0, 1)

    # Dashed reference line at the mean F1 across folds.
    mean_f1 = np.mean(f1_scores)
    ax.axhline(y=mean_f1, color='red', linestyle='--', alpha=0.7,
               label=f'Mean: {mean_f1:.3f}')
    ax.legend()

def plot_classification_results_simple(analysis_results, save_path=None, figsize=(15, 10)):
    """Render a 2x2 overview figure from an ``analyze_results`` output.

    Panels:
        1. F1 score per classifier (channel-level, plus patient-level with
           std error bars when a patient summary is available).
        2. Channel-level accuracy per classifier.
        3. Channel-level ROC AUC per classifier, with a 0.5 reference line.
        4. Horizontal bars of the five channel-level metrics for the
           classifier with the highest channel-level F1.

    Panels are left empty when there are no channel-level results.

    Args:
        analysis_results: Mapping read as
            ``analysis_results['channel']['results']`` and
            ``analysis_results['patient']['results']``; missing keys are
            treated as empty. Channel entries provide ``'f1'``,
            ``'accuracy'``, ``'precision'``, ``'recall'`` and ``'roc_auc'``;
            patient entries provide ``'overall_f1'`` and ``'cv_f1_std'``.
        save_path: If truthy, the figure is written to this path.
        figsize: Figure size passed to ``plt.subplots``.

    Returns:
        The created figure.
    """
    channel_results = analysis_results.get('channel', {}).get('results', {})
    patient_summary = analysis_results.get('patient', {}).get('results', {})

    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=figsize)

    classifiers = list(channel_results.keys())

    if channel_results:
        # --- Panel 1: F1 score ------------------------------------------
        channel_f1s = [channel_results[clf]['f1'] for clf in classifiers]
        x = np.arange(len(classifiers))
        width = 0.35

        if patient_summary:
            # Classifiers absent from the patient summary get a zero bar.
            patient_f1s = [patient_summary[clf]['overall_f1'] if clf in patient_summary else 0
                           for clf in classifiers]
            patient_stds = [patient_summary[clf]['cv_f1_std'] if clf in patient_summary else 0
                            for clf in classifiers]

            ax1.bar(x - width/2, channel_f1s, width, label='Channel-Level',
                    color='#3498db', alpha=0.8)
            ax1.bar(x + width/2, patient_f1s, width, label='Patient-Level',
                    color='#9b59b6', alpha=0.8, yerr=patient_stds, capsize=5)
            ax1.set_title('F1 Score by Classifier (with std)')
        else:
            # No patient-level summary: still show the channel-level F1
            # scores instead of leaving the panel blank.
            ax1.bar(x, channel_f1s, width, label='Channel-Level',
                    color='#3498db', alpha=0.8)
            ax1.set_title('F1 Score by Classifier')

        ax1.set_xlabel('Classifiers')
        ax1.set_ylabel('F1 Score')
        ax1.set_xticks(x)
        ax1.set_xticklabels(classifiers, rotation=45, ha='right')
        ax1.legend()
        ax1.grid(True, alpha=0.3)

        # --- Panel 2: accuracy ------------------------------------------
        accuracies = [channel_results[clf]['accuracy'] for clf in classifiers]
        ax2.bar(classifiers, accuracies, color='#f39c12', alpha=0.8)
        ax2.set_xlabel('Classifiers')
        ax2.set_ylabel('Accuracy')
        ax2.set_title('Accuracy by Classifier')
        ax2.tick_params(axis='x', rotation=45)
        ax2.grid(True, alpha=0.3)
        ax2.set_ylim(0, 1)

        # --- Panel 3: ROC AUC -------------------------------------------
        roc_aucs = [channel_results[clf]['roc_auc'] for clf in classifiers]
        ax3.bar(classifiers, roc_aucs, color='#2ecc71', alpha=0.8)
        ax3.set_xlabel('Classifiers')
        ax3.set_ylabel('ROC AUC')
        ax3.set_title('ROC AUC by Classifier')
        ax3.tick_params(axis='x', rotation=45)
        ax3.grid(True, alpha=0.3)
        ax3.set_ylim(0, 1)
        ax3.axhline(y=0.5, color='red', linestyle='--', alpha=0.5)

        # --- Panel 4: metrics of the best model --------------------------
        # "Best" means highest channel-level F1. A horizontal bar chart is
        # used here in place of a radar chart.
        metrics = ['F1', 'Accuracy', 'Precision', 'Recall', 'ROC AUC']
        best_clf = max(channel_results,
                       key=lambda name: channel_results[name]['f1'])
        best = channel_results[best_clf]
        values = [
            best['f1'],
            best['accuracy'],
            best['precision'],
            best['recall'],
            best['roc_auc'],
        ]

        ax4.barh(metrics, values, color='#e74c3c', alpha=0.8)
        ax4.set_xlabel('Score')
        ax4.set_title(f'Best Model Performance ({best_clf})')
        ax4.set_xlim(0, 1)
        ax4.grid(True, alpha=0.3)

    # Operate on this figure explicitly rather than on pyplot's implicit
    # "current figure", so the right figure is laid out and saved.
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight', facecolor='white')
        print(f"📊 Results visualization saved to: {save_path}")

    return fig

def create_summary_table(analysis_results, save_path=None):
    """Build a per-classifier summary table of validation metrics.

    One row is produced for every classifier found under
    ``analysis_results['channel']['results']``. Patient-level columns are
    filled only for classifiers that also appear under
    ``analysis_results['patient']['results']``.

    Args:
        analysis_results: Analysis results mapping; missing ``'channel'`` /
            ``'patient'`` / ``'results'`` keys are treated as empty.
        save_path: If truthy, the table is also written to this path as CSV
            (without the index).

    Returns:
        pd.DataFrame: The summary table.
    """
    # (output column, source metric key) pairs, in output column order.
    channel_columns = (
        ('Channel_F1', 'f1'),
        ('Channel_Accuracy', 'accuracy'),
        ('Channel_Precision', 'precision'),
        ('Channel_Recall', 'recall'),
        ('Channel_ROC_AUC', 'roc_auc'),
        ('Channel_Balanced_Acc', 'balanced_acc'),
    )
    patient_columns = (
        ('Patient_F1_Mean', 'cv_f1_mean'),
        ('Patient_F1_Std', 'cv_f1_std'),
        ('Patient_Overall_F1', 'overall_f1'),
        ('Patient_Overall_Acc', 'overall_acc'),
    )

    per_channel = analysis_results.get('channel', {}).get('results', {})
    per_patient = analysis_results.get('patient', {}).get('results', {})

    rows = []
    for name, channel_metrics in per_channel.items():
        record = {'Classifier': name}
        for column, key in channel_columns:
            record[column] = channel_metrics[key]

        if name in per_patient:
            patient_metrics = per_patient[name]
            for column, key in patient_columns:
                record[column] = patient_metrics[key]

        rows.append(record)

    summary_df = pd.DataFrame(rows)

    if save_path:
        summary_df.to_csv(save_path, index=False)
        print(f"📋 Summary table saved to: {save_path}")

    return summary_df

In [None]:
# Notebook cell: run validation, plot the results, and export a summary.
# NOTE(review): `classifier` is not defined in this cell; presumably a
# MultiPatientClassifier instance created in an earlier cell - confirm.
channel_results = classifier.channel_level_validation()
patient_results = classifier.patient_level_validation()
# Combine both validation outputs into the structure consumed by
# plot_classification_results_simple and create_summary_table below.
analysis_results = classifier.analyze_results(channel_results, patient_results)

# 2x2 overview figure built from the combined analysis, saved as PNG.
fig2 = plot_classification_results_simple(
    analysis_results,
    save_path='classification_results_simple.png'
)
# Second figure built directly from the raw validation outputs
# (plot_classification_results_comprehensive is defined outside this cell).
fig = plot_classification_results_comprehensive(
    channel_results, patient_results,
    save_path='detailed_results.png'
)
# Per-classifier metrics table, also written to CSV.
summary_table = create_summary_table(
    analysis_results,
    save_path='results_summary.csv'
)

print(summary_table)