In [None]:
import numpy as np
# At the start of your notebook
from IPython.display import clear_output
import gc
import os

# After heavy computations
# NOTE(review): this is the first cell, so nothing heavy has run yet when it
# executes under Restart & Run All — presumably meant to be re-run by hand later.
# Clear this cell's rendered output; wait=True defers the clear until new
# output is produced, which avoids flicker.
clear_output(wait=True)
# Trigger a full garbage-collection pass. As the last expression of the cell,
# its integer return value is what the notebook displays.
gc.collect()

In [None]:
# Output locations, relative to the notebook's working directory.
RESULT_FOLDER = "result"
MODEL_FOLDER = "checkpoints"

# Short labels for the architectures being compared.
# Labels noted alongside the original list: 'CNN1D', 'Wavenet', 'S4', 'Resnet',
# 'Transformer' — presumably the full set of supported options; TODO confirm.
model_names = [
    'Transformer',
    'Wavenet',
    'ResNet',
]

In [None]:
from steps import extract_sEEG_features
from datasetConstruct import load_seizure_across_patients, load_single_seizure

# Load every seizure recording found under the 'data' folder.
dataset = load_seizure_across_patients(data_folder='data')

# Run sEEG feature extraction on each recording, using that recording's own
# sampling rate.
# NOTE(review): `seizure_new` is rebound on every iteration and never read in
# the visible cells, so this loop only has an effect if extract_sEEG_features
# works through side effects (e.g. writing results to disk or mutating
# `seizure`) — TODO confirm.
for seizure in dataset:
    seizure_new = extract_sEEG_features(seizure, sampling_rate=seizure.samplingRate)
    
# Load the dataset a second time, replacing the objects iterated above.
# NOTE(review): presumably this reload is what picks up the extracted features
# (the later cells read `ictal_transformed` / `interictal_transformed`);
# verify against datasetConstruct before removing either call.
dataset = load_seizure_across_patients(data_folder='data')

In [None]:
import torch
from torch.utils.data import DataLoader, random_split
from datasetConstruct import CustomDataset
from models import (
    universal_train_model, 
    universal_hyperparameter_search,
    create_model_with_best_params,
    universal_evaluate_model
)

# 导入models.py中的模型
from models import ResNet, Wavenet, CustomTransformerSeizurePredictor

In [None]:
# Build one sample per (segment, channel) from every recording in `dataset`,
# label it with a two-element soft label, then wrap everything in a
# CustomDataset and split it into train / validation DataLoaders.

SPLIT_SEED = 42    # fixed seed so the train/validation partition is reproducible
BATCH_SIZE = 1024  # batch size shared by the train and validation loaders
SEIZURE_LABEL = (0.0, 1.0)      # [non-seizure_prob, seizure_prob]
NON_SEIZURE_LABEL = (1.0, 0.0)  # [non-seizure_prob, seizure_prob]

X = []
y = []

for patient_idx, seizure_obj in enumerate(dataset):
    print(f"Processing patient {patient_idx + 1}/{len(dataset)}: {seizure_obj.patNo}")

    # Transformed features. The original annotation gives the layout as
    # (segments, channels, features, time_steps) — TODO confirm.
    # NOTE(review): every per-channel slice is transposed with (1, 0) before
    # being stored, and the original comments also described the *transposed*
    # result as (features, time_steps). Both cannot be true; verify which axis
    # ends up at X.shape[1], because the next cell uses it as `input_dim`.
    ictal = seizure_obj.ictal_transformed
    interictal = seizure_obj.interictal_transformed

    print(f"  Ictal shape: {ictal.shape}")
    print(f"  Interictal shape: {interictal.shape}")

    # Per-seizure bookkeeping, e.g. {SZ1: 84, SZ2: 120, ...} and
    # {SZ1: [0, 1, 2, 3, 4, 5], ...} according to the original annotations.
    seizure_segment_counts = seizure_obj.seizure_segment_counts
    seizure_specific_channels = getattr(seizure_obj, 'seizure_specific_channels', {})

    # Channel-level annotation is used only if at least one seizure lists channels.
    has_specific_channels = any(len(channels) > 0 for channels in seizure_specific_channels.values())

    if has_specific_channels:
        print("  Using specific channel annotations")
        n_ictal_segments = ictal.shape[0]
        current_segment = 0   # running index into `ictal` across all seizures
        missing_segments = 0  # segments announced by the counts but absent from `ictal`

        for seizure_name, segment_count in seizure_segment_counts.items():
            # A seizure without an entry contributes only non-seizure labels.
            seizure_channels = seizure_specific_channels.get(seizure_name, [])

            for _ in range(segment_count):
                if current_segment >= n_ictal_segments:
                    # Counts overrun the available data: skip, but keep track
                    # so the mismatch is reported instead of passing silently.
                    missing_segments += 1
                    continue

                for ch_idx, channel_data in enumerate(ictal[current_segment]):
                    X.append(channel_data.transpose(1, 0))
                    # Only channels annotated for this seizure are ictal; the
                    # remaining channels of the segment are treated as interictal.
                    y.append(SEIZURE_LABEL if ch_idx in seizure_channels else NON_SEIZURE_LABEL)

                current_segment += 1

        if missing_segments:
            print(f"  WARNING: seizure_segment_counts lists {missing_segments} more segment(s) "
                  f"than ictal data contains; the extras were ignored")
        elif current_segment < n_ictal_segments:
            print(f"  WARNING: {n_ictal_segments - current_segment} ictal segment(s) are not covered "
                  f"by seizure_segment_counts and were not added to the dataset")
    else:
        print("  No specific channel annotations found, treating all ictal as seizure")
        # Without channel annotations every channel of every ictal segment is a seizure sample.
        for seg in ictal:
            for ch in seg:
                X.append(ch.transpose(1, 0))
                y.append(SEIZURE_LABEL)

    # Interictal data is always labelled non-seizure.
    for seg in interictal:
        for ch in seg:
            X.append(ch.transpose(1, 0))
            y.append(NON_SEIZURE_LABEL)

if not X:
    # Fail with a clear message instead of an IndexError from y[:, 1] below.
    raise ValueError("No samples were collected from `dataset`; check the loaded recordings")

X = np.array(X)
y = np.array(y, dtype=np.float32)  # soft (probability-distribution) labels

print("\nFinal dataset statistics:")
print(f"Total samples: {len(X)}")
print(f"X shape: {X.shape}")
print(f"y shape: {y.shape}")
print(f"Seizure samples: {np.sum(y[:, 1] == 1.0)}")
print(f"Non-seizure samples: {np.sum(y[:, 0] == 1.0)}")

# Wrap the arrays in the project Dataset and split 80 / 20.
dataset_torch = CustomDataset(X, y)
print(f"Dataset size: {len(dataset_torch)}")

train_size = int(0.8 * len(dataset_torch))
val_size = len(dataset_torch) - train_size
# NOTE(review): the split is over individual (segment, channel) samples, so
# channels of the same segment and patient can land on both sides; validation
# accuracy may therefore be optimistic compared with a patient-level split.
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_dataset, val_dataset = random_split(
    dataset_torch, [train_size, val_size], generator=split_generator
)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

print(f"Train samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")

In [None]:
# Model I/O sizes and compute device shared by every architecture below.
output_dim = 2  # two classes: seizure / non-seizure

# Size of axis 1 of the stacked sample array.
# NOTE(review): originally annotated as "features"; confirm against the
# transpose applied when the samples were collected.
input_dim = X.shape[1]

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
# Architectures to compare. Each one goes through the same pipeline:
# hyperparameter search -> rebuild with the best parameters -> final training.
models_to_compare = [
    CustomTransformerSeizurePredictor,  # the project's Transformer model
    ResNet,
    Wavenet
]

results = {}

for model_class in models_to_compare:
    model_name = model_class.__name__
    print(f"处理模型: {model_name}")

    # 1. Hyperparameter search.
    # `device`, `output_dim` and `MODEL_FOLDER` come from the configuration
    # cells; using them here (instead of literal 'cuda' / 2 / 'checkpoints')
    # keeps the CPU fallback working on machines without a GPU.
    best_params, study = universal_hyperparameter_search(
        model_class=model_class,
        train_loader=train_loader,
        val_loader=val_loader,
        input_dim=input_dim,
        output_dim=output_dim,
        n_trials=20,
        device=device,
        model_folder=f'{MODEL_FOLDER}/{model_name}'
    )

    # 2. Build the final model from the best parameters found.
    final_model = create_model_with_best_params(
        model_class, best_params, input_dim, output_dim
    )

    # 3. Train the final model.
    train_losses, val_losses, val_accuracies = universal_train_model(
        model=final_model,
        train_loader=train_loader,
        val_loader=val_loader,
        save_location=f'{MODEL_FOLDER}/{model_name}_final',
        epochs=50,
        device=device
    )

    # 4. Record the outcome. `default` keeps an empty accuracy history from
    #    aborting the whole comparison; it then shows up as nan in the summary.
    # NOTE(review): every trained model is kept referenced here, so its memory
    # is not released until `results` is dropped.
    results[model_name] = {
        'best_params': best_params,
        'best_accuracy': max(val_accuracies, default=float('nan')),
        'model': final_model
    }

# Summary: best validation accuracy per architecture.
for model_name, result in results.items():
    print(f"{model_name}: 准确率 {result['best_accuracy']:.4f}")