In [None]:
import numpy as np
# Helpers for notebook housekeeping: output clearing and manual garbage collection.
from IPython.display import clear_output
import gc

# Clear this cell's output and force a garbage-collection pass.
# NOTE(review): this executes at the very top of the notebook, before any heavy
# computation has run; re-run the cell later if memory needs to be reclaimed.
clear_output(wait=True)
gc.collect()

In [None]:
# Output folders used by the notebook.
RESULT_FOLDER = 'result'
MODEL_FOLDER = 'checkpoints'

# Architecture names; names noted by the author: 'CNN1D', 'Wavenet', 'S4', 'Resnet', 'Transformer'
model_names = [
    'Transformer',
    'CNN1D',
    'Wavenet',
    'ResNet',
]

In [None]:
# from steps import extract_sEEG_features
from datasetConstruct import load_seizure_across_patients, load_single_seizure

# Load the seizure objects from the 'data' folder. Later cells iterate over
# `dataset` and read per-item attributes such as `patNo`, `ictal_transformed`
# and `interictal_transformed`.
dataset = load_seizure_across_patients(data_folder='data')

# Optional feature-extraction pass, currently disabled:
# for seizure in dataset:
#     seizure_new = extract_sEEG_features(seizure, sampling_rate=seizure.samplingRate)

In [None]:
import torch
from torch.utils.data import DataLoader, random_split
from models import CustomTransformerSeizurePredictor, evaluate_model
from datasetConstruct import CustomDataset

In [None]:
# Build one sample per (segment, channel) pair together with a two-element
# probability label [non-seizure_prob, seizure_prob], then wrap everything in
# train/validation DataLoaders.
SPLIT_SEED = 42  # fixed so the train/validation split is identical across runs
SEIZURE_LABEL = (0.0, 1.0)      # [non-seizure_prob, seizure_prob]
NON_SEIZURE_LABEL = (1.0, 0.0)  # [non-seizure_prob, seizure_prob]

X = []
y = []

for patient_idx, seizure_obj in enumerate(dataset):
    print(f"Processing patient {patient_idx + 1}/{len(dataset)}: {seizure_obj.patNo}")

    # Transformed features. The original notes describe both arrays as
    # (segments, channels, features, time_steps) -- TODO confirm against datasetConstruct.
    ictal = seizure_obj.ictal_transformed
    interictal = seizure_obj.interictal_transformed

    print(f"  Ictal shape: {ictal.shape}")
    print(f"  Interictal shape: {interictal.shape}")

    # Per-seizure bookkeeping
    seizure_segment_counts = seizure_obj.seizure_segment_counts  # e.g. {SZ1: 84, SZ2: 120, ...}
    seizure_specific_channels = getattr(seizure_obj, 'seizure_specific_channels', {})  # e.g. {SZ1: [0,1,2,3,4,5], ...}

    # Channel-level annotations count as present when at least one seizure lists a channel
    has_specific_channels = any(len(channels) > 0 for channels in seizure_specific_channels.values())

    if has_specific_channels:
        print(f"  Using specific channel annotations")
        # Ictal segments are consumed in the iteration order of seizure_segment_counts:
        # the first `segment_count` segments belong to the first seizure, and so on.
        n_ictal_segments = ictal.shape[0]
        current_segment = 0
        missing_segments = 0
        for seizure_name, segment_count in seizure_segment_counts.items():
            # Set lookup instead of a list scan for every channel of every segment
            seizure_channels = set(seizure_specific_channels.get(seizure_name, []))

            for _ in range(segment_count):
                if current_segment >= n_ictal_segments:
                    # The counts claim more segments than the array holds
                    missing_segments += 1
                    continue

                segment_data = ictal[current_segment]
                for ch_idx in range(segment_data.shape[0]):
                    # Annotated channels are labelled seizure; every other channel
                    # of the same segment is treated as non-seizure.
                    # NOTE(review): transpose(1, 0) swaps the two per-channel axes; the
                    # original comments disagree on which axis ends up first -- confirm
                    # against the layout the model expects.
                    X.append(segment_data[ch_idx].transpose(1, 0))
                    y.append(SEIZURE_LABEL if ch_idx in seizure_channels else NON_SEIZURE_LABEL)

                current_segment += 1

        # Report mismatches between the bookkeeping and the array instead of hiding them
        if missing_segments:
            print(f"  Warning: seizure_segment_counts exceeds the ictal array by {missing_segments} segment(s); those were skipped")
        if current_segment < n_ictal_segments:
            print(f"  Warning: {n_ictal_segments - current_segment} ictal segment(s) not covered by seizure_segment_counts were left out")
    else:
        print(f"  No specific channel annotations found, treating all ictal as seizure")
        # Without channel annotations every channel of every ictal segment is a seizure sample
        for seg in ictal:
            for ch in seg:
                X.append(ch.transpose(1, 0))
                y.append(SEIZURE_LABEL)

    # Interictal data is always non-seizure
    for seg in interictal:
        for ch in seg:
            X.append(ch.transpose(1, 0))
            y.append(NON_SEIZURE_LABEL)

X = np.array(X)
y = np.array(y, dtype=np.float32)  # probability-distribution labels

print(f"\nFinal dataset statistics:")
print(f"Total samples: {len(X)}")
print(f"X shape: {X.shape}")
print(f"y shape: {y.shape}")
print(f"Seizure samples: {np.sum(y[:, 1] == 1.0)}")
print(f"Non-seizure samples: {np.sum(y[:, 0] == 1.0)}")

# Build the Dataset and DataLoaders
dataset_torch = CustomDataset(X, y)
print(f"Dataset size: {len(dataset_torch)}")

train_size = int(0.8 * len(dataset_torch))
val_size = len(dataset_torch) - train_size
# A seeded generator keeps the 80/20 split stable, so results from different
# runs (and different hyperparameter trials) are compared on the same validation set.
train_dataset, val_dataset = random_split(
    dataset_torch,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
train_loader = DataLoader(train_dataset, batch_size=2048, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=2048)

print(f"Train samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")

In [None]:
# Model input width is read from axis 1 of the sample array X.
# NOTE(review): the inline label calls this axis "features"; confirm that axis 1
# really is the feature axis after the transpose applied while building X.
input_dim = X.shape[1]  # features
output_dim = 2  # seizure/non-seizure
# Fall back to CPU when no CUDA device is available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Only the dimensions and device are passed here; every other constructor
# argument is left at its default.
model = CustomTransformerSeizurePredictor(input_dim=input_dim, output_dim=output_dim, device=device)
# model.random_init()

In [None]:
# KL-divergence evaluation helper
def evaluate_model_kl(model, dataloader, device='cuda'):
    """
    Evaluate a model with KL-divergence loss against probability labels.

    The model output is clamped to [1e-8, 1.0] and renormalised along dim 1
    before the log is taken, so it is treated as a (possibly unnormalised)
    probability vector rather than as logits.

    Args:
        model: torch module; called as ``model(data)``.
        dataloader: iterable yielding ``(data, labels)`` batches, where
            ``labels`` holds one probability vector per sample.
        device: device the batches are moved to before the forward pass.

    Returns:
        ``(avg_loss, accuracy)``: the mean of the per-batch ``batchmean`` KL
        losses over the batches that produced a finite loss, and the argmax
        accuracy in percent over the samples of those batches. Returns
        ``(float('inf'), 0.0)`` when no batch produced a finite loss.
    """
    model.eval()
    criterion = torch.nn.KLDivLoss(reduction='batchmean')

    total_loss = 0.0
    valid_batches = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for data, labels in dataloader:
            data = data.to(device)
            labels = labels.to(device)  # labels stay floating point

            outputs = model(data)

            # Force a valid probability distribution before taking the log
            outputs = torch.clamp(outputs, min=1e-8, max=1.0)
            outputs = outputs / outputs.sum(dim=1, keepdim=True)

            # KLDivLoss expects log-probabilities as its first argument
            loss = criterion(torch.log(outputs), labels)

            # Non-finite batches are excluded from both the loss and the accuracy
            if torch.isnan(loss) or torch.isinf(loss):
                continue

            total_loss += loss.item()
            valid_batches += 1

            predicted = torch.argmax(outputs, dim=1)
            true_labels = torch.argmax(labels, dim=1)

            total += labels.size(0)
            correct += (predicted == true_labels).sum().item()

    if valid_batches == 0 or total == 0:
        return float('inf'), 0.0

    # Average over the batches that actually contributed; dividing by
    # len(dataloader) would understate the loss whenever a batch was skipped.
    avg_loss = total_loss / valid_batches
    accuracy = 100 * correct / total
    return avg_loss, accuracy

In [None]:
# 添加hyperparameter search相关的函数
import optuna
import os
import torch.optim as optim
def train_model_for_optuna(
        model,
        train_loader,
        val_loader,
        device='cuda',
        epochs=15,
        patience=5,
        save_location='checkpoints',
        lr=1e-5,
        weight_decay=1e-4,
        trial=None
):
    """
    Train a Transformer for one Optuna trial with KL-divergence loss.

    Args:
        model: torch module to train; moved to ``device`` here.
        train_loader: iterable of ``(data, labels)`` training batches.
        val_loader: iterable of ``(data, labels)`` validation batches,
            evaluated with ``evaluate_model_kl`` after every epoch.
        device: device used for training and evaluation.
        epochs: maximum number of epochs.
        patience: number of consecutive epochs without a new best validation
            loss after which training stops early.
        save_location: directory receiving one checkpoint per completed epoch.
        lr: Adam learning rate. Defaults to the previously hard-coded value.
        weight_decay: Adam weight decay. Defaults to the previously
            hard-coded value.
        trial: optional Optuna trial. When given, the validation accuracy of
            every epoch is reported to it and the study's pruner may stop the
            trial.

    Returns:
        ``(train_losses, val_losses, val_accuracies)``, one entry per epoch run.

    Raises:
        optuna.TrialPruned: when ``trial`` is given and the pruner decides to
            stop the trial.
    """
    model = model.to(device)
    # The optimizer must use the sampled values, otherwise searching over
    # lr / weight_decay has no effect on training.
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    criterion = torch.nn.KLDivLoss(reduction='batchmean')
    os.makedirs(save_location, exist_ok=True)

    train_losses = []
    val_losses = []
    val_accuracies = []
    best_val_loss = float('inf')
    patience_counter = 0

    for epoch in range(epochs):
        # Training phase
        model.train()
        total_loss = 0
        num_batches = 0

        for data, labels in train_loader:
            data = data.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(data)

            # Force a valid probability distribution before taking the log
            outputs = torch.clamp(outputs, min=1e-8, max=1.0)
            outputs = outputs / outputs.sum(dim=1, keepdim=True)
            log_outputs = torch.log(outputs)

            # Skip numerically broken batches rather than poisoning the weights
            if torch.isnan(log_outputs).any() or torch.isinf(log_outputs).any():
                continue

            loss = criterion(log_outputs, labels)
            if torch.isnan(loss) or torch.isinf(loss):
                continue

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1

        # An epoch without a single usable batch is recorded as an infinite loss
        avg_train_loss = total_loss / num_batches if num_batches > 0 else float('inf')
        train_losses.append(avg_train_loss)

        # Validation phase
        val_loss, val_acc = evaluate_model_kl(model, val_loader, device=device)
        val_losses.append(val_loss)
        val_accuracies.append(val_acc)

        # Feed the pruner; without these reports a configured pruner never fires
        if trial is not None:
            trial.report(val_acc, epoch)
            if trial.should_prune():
                raise optuna.TrialPruned()

        # Early stopping on validation loss
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
        else:
            patience_counter += 1

        if patience_counter >= patience:
            break

        # Save a checkpoint for this epoch (the epoch that triggers early
        # stopping is not saved, as before)
        checkpoint_path = os.path.join(save_location, f"Transformer_epoch_{epoch+1}.pth")
        torch.save(model.state_dict(), checkpoint_path)

    return train_losses, val_losses, val_accuracies

def hyperparameter_search_for_transformer(
        train_loader,
        val_loader,
        channels: int,
        time_steps: int,
        n_trials: int = 20,
        device: str = 'cuda',
        model_folder: str = 'checkpoints'
):
    """
    Run an Optuna search over Transformer hyperparameters.

    Args:
        train_loader: training batches passed to ``train_model_for_optuna``.
        val_loader: validation batches passed to ``train_model_for_optuna``.
        channels: forwarded to the model as ``input_dim``.
        time_steps: forwarded to the model as ``max_seq_len``.
        n_trials: number of Optuna trials.
        device: device used for every trial.
        model_folder: parent directory; each trial writes its checkpoints to
            ``<model_folder>/optuna_trial_<n>``.

    Returns:
        ``(best_params, study)``: the parameter dict of the best trial
        (highest validation accuracy) and the Optuna study.
    """

    def objective(trial):
        # Search space
        lr = trial.suggest_float('lr', 1e-6, 1e-3, log=True)
        d_model = trial.suggest_categorical('d_model', [512, 1024, 2048])
        n_heads = trial.suggest_categorical('n_heads', [8, 16, 32])
        n_layers = trial.suggest_categorical('n_layers', [2,3,4])
        d_ff = trial.suggest_categorical('d_ff', [2048, 4096, 8192])
        dropout = trial.suggest_float('dropout', 0.1, 0.5)
        weight_decay = trial.suggest_float('weight_decay', 1e-6, 1e-4, log=True)

        model = CustomTransformerSeizurePredictor(
            input_dim=channels,
            output_dim=2,
            d_model=d_model,
            n_heads=n_heads,
            n_layers=n_layers,
            d_ff=d_ff,
            dropout=dropout,
            max_seq_len=time_steps,
            lr=lr,
            weight_decay=weight_decay,
            device=device
        )

        try:
            # Per-trial checkpoint directory
            temp_dir = os.path.join(model_folder, f"optuna_trial_{trial.number}")
            os.makedirs(temp_dir, exist_ok=True)

            # Train with the sampled optimizer settings and let the pruner see progress
            train_losses, val_losses, val_accuracies = train_model_for_optuna(
                model=model,
                train_loader=train_loader,
                val_loader=val_loader,
                device=device,
                epochs=15,
                patience=5,
                save_location=temp_dir,
                lr=lr,
                weight_decay=weight_decay,
                trial=trial
            )

            # The trial's score is its best validation accuracy
            if val_accuracies:
                best_acc = max(val_accuracies)
                return best_acc
            else:
                return 0.0

        except optuna.TrialPruned:
            # Pruning is a normal outcome; let Optuna record it as pruned
            # instead of scoring it as a failed trial.
            raise
        except Exception as e:
            # Best effort: a failing configuration scores 0 so the search goes on
            print(f"Trial failed: {str(e)}")
            return 0.0

    # Create the study
    study = optuna.create_study(
        direction="maximize",
        sampler=optuna.samplers.TPESampler(seed=42),
        pruner=optuna.pruners.MedianPruner(n_startup_trials=5, n_warmup_steps=5)
    )

    # Run the optimisation
    study.optimize(objective, n_trials=n_trials)

    # Report the best result
    print(f"\nBest trial for Transformer:")
    trial = study.best_trial
    print(f"  Value: {trial.value}")
    print(f"  Params: {trial.params}")

    return trial.params, study

In [None]:
# 执行hyperparameter search
print("Starting hyperparameter search for Transformer...")

# Checkpoint root; created here if it is missing
model_folder = 'checkpoints'
os.makedirs(model_folder, exist_ok=True)

# Axes 1 and 2 of the sample array give the model input width and sequence length
channels, time_steps = X.shape[1], X.shape[2]

# Launch the search with 20 trials on the selected device
best_params, study = hyperparameter_search_for_transformer(
    train_loader=train_loader,
    val_loader=val_loader,
    channels=channels,
    time_steps=time_steps,
    n_trials=20,
    device=device,
)

print(f"Best hyperparameters found: {best_params}")

In [None]:
# 使用最佳参数创建和训练模型
print("Training Transformer with best hyperparameters...")

# Build the model from the best trial's parameters
model = CustomTransformerSeizurePredictor(
    input_dim=channels,
    output_dim=2,
    d_model=best_params.get('d_model', 512),
    n_heads=best_params.get('n_heads', 8),
    n_layers=best_params.get('n_layers', 6),
    d_ff=best_params.get('d_ff', 2048),
    dropout=best_params.get('dropout', 0.1),
    max_seq_len=time_steps,
    lr=best_params['lr'],
    weight_decay=best_params.get('weight_decay', 1e-5),
    device=device
)
# Batches are moved to `device` below, so the weights must live there too
model = model.to(device)

# The optimizer uses the same lr / weight_decay handed to the model above;
# hard-coded values here would silently discard the search result.
optimizer = optim.Adam(
    model.parameters(),
    lr=best_params['lr'],
    weight_decay=best_params.get('weight_decay', 1e-5)
)
criterion = torch.nn.KLDivLoss(reduction='batchmean')

num_epochs = 10
train_losses = []
val_losses = []
val_accuracies = []

for epoch in range(num_epochs):
    # Training phase
    model.train()
    total_loss = 0
    num_batches = 0

    for batch_idx, batch in enumerate(train_loader):
        data, labels = batch
        data = data.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(data)

        # Force a valid probability distribution before taking the log
        outputs = torch.clamp(outputs, min=1e-8, max=1.0)
        outputs = outputs / outputs.sum(dim=1, keepdim=True)
        log_outputs = torch.log(outputs)

        # Skip numerically broken batches, but say so
        if torch.isnan(log_outputs).any() or torch.isinf(log_outputs).any():
            print(f"Warning: NaN or Inf in log_outputs at epoch {epoch+1}, batch {batch_idx}")
            continue

        loss = criterion(log_outputs, labels)

        if torch.isnan(loss) or torch.isinf(loss):
            print(f"Warning: NaN or Inf loss at epoch {epoch+1}, batch {batch_idx}")
            continue

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1

    # Always record one entry per epoch so train_losses stays aligned with
    # val_losses; an epoch without usable batches is recorded as infinity.
    avg_train_loss = total_loss / num_batches if num_batches > 0 else float('inf')
    train_losses.append(avg_train_loss)

    # Validation phase
    val_loss, val_acc = evaluate_model_kl(model, val_loader, device=device)
    val_losses.append(val_loss)
    val_accuracies.append(val_acc)
    print(f"Epoch {epoch+1}/{num_epochs}")
    print(f"  Train Loss: {avg_train_loss:.4f}")
    print(f"  Val Loss: {val_loss:.4f}, Val Accuracy: {val_acc:.2f}%")
    print("-" * 40)

# Final evaluation
model.eval()
val_loss, val_acc = evaluate_model_kl(model, val_loader, device=device)
print(f"Final Validation Loss: {val_loss:.4f}")
print(f"Final Validation Accuracy: {val_acc:.2f}%")

In [None]:
# Save the trained Transformer weights to <model_folder>/Transformer_best.pth.
# The file holds the bare state_dict returned by model.state_dict(); it is not
# wrapped in a dict with extra metadata.
model_save_path = os.path.join(model_folder, 'Transformer_best.pth')
torch.save(model.state_dict(), model_save_path)

In [None]:
# 导入通用函数
from models import (
    universal_train_model, 
    universal_hyperparameter_search,
    create_model_with_best_params,
    universal_evaluate_model
)

# 导入models.py中的模型
from models import ResNet, Wavenet, CNN1D, CustomTransformerSeizurePredictor

In [None]:
# 要比较的模型
# Model classes to compare
models_to_compare = [
    ResNet,
    Wavenet,
    CustomTransformerSeizurePredictor  # the Transformer model
]

results = {}

for model_class in models_to_compare:
    model_name = model_class.__name__
    print(f"处理模型: {model_name}")

    # 1. Hyperparameter search. `device` is the notebook-level setting with a
    #    CPU fallback; hard-coding 'cuda' fails on machines without a GPU.
    best_params, study = universal_hyperparameter_search(
        model_class=model_class,
        train_loader=train_loader,
        val_loader=val_loader,
        input_dim=input_dim,
        output_dim=output_dim,
        n_trials=20,
        device=device,
        model_folder=f'checkpoints/{model_name}'
    )

    # 2. Build the final model from the best parameters, using the same
    #    input/output dimensions the search was run with.
    final_model = create_model_with_best_params(
        model_class, best_params, input_dim, output_dim
    )

    # 3. Train the final model
    train_losses, val_losses, val_accuracies = universal_train_model(
        model=final_model,
        train_loader=train_loader,
        val_loader=val_loader,
        save_location=f'checkpoints/{model_name}_final',
        epochs=50,
        device=device
    )

    # 4. Record the outcome. `default` keeps the summary alive if training
    #    returned no accuracies; the study is kept so each search can be
    #    inspected after the loop has moved on.
    results[model_name] = {
        'best_params': best_params,
        'best_accuracy': max(val_accuracies, default=0.0),
        'model': final_model,
        'study': study
    }

# Compare the results
for model_name, result in results.items():
    print(f"{model_name}: 准确率 {result['best_accuracy']:.4f}")

In [None]:
# Use heatmap to plot the hyperparameter search results
import matplotlib.pyplot as plt
import seaborn as sns

def plot_hyperparameter_search_results(study):
    """
    Draw a heatmap of mean trial value per (d_ff, d_model, n_heads) combination.

    Only completed trials are used. Rows are ``d_ff``; columns are the
    ``(d_model, n_heads)`` pairs. Combinations that no trial covered are filled
    with the mean value over all completed trials, so they sit in the middle of
    the colour scale instead of looking like a score of zero.

    Args:
        study: Optuna study whose ``trials`` carry ``params``, ``state`` and
            ``value``.

    Returns:
        None. Prints a message and returns early when there are no completed
        trials; otherwise shows the figure.
    """
    import pandas as pd

    # One record per completed trial; building the frame from dicts keeps the
    # integer hyperparameters as integers, so tick labels read "2048" rather
    # than "2048.0".
    records = [
        {
            'd_ff': trial.params.get('d_ff', 2048),
            'd_model': trial.params.get('d_model', 512),
            'n_heads': trial.params.get('n_heads', 8),
            'Accuracy': trial.value,
        }
        for trial in study.trials
        if trial.state == optuna.trial.TrialState.COMPLETE
    ]
    if not records:
        print("No completed trials found.")
        return

    trials_df = pd.DataFrame.from_records(records)
    mean_accuracy = trials_df['Accuracy'].mean()

    heatmap_df = trials_df.pivot_table(
        index='d_ff',
        columns=['d_model', 'n_heads'],
        values='Accuracy',
        aggfunc='mean'
    )
    # Untested combinations get the overall mean (not zero)
    heatmap_df = heatmap_df.fillna(mean_accuracy)

    fig, ax = plt.subplots(figsize=(12, 8))
    sns.heatmap(heatmap_df, annot=True, fmt=".2f", cmap='viridis', cbar_kws={'label': 'Accuracy'}, ax=ax)
    ax.set_title('Hyperparameter Search Results Heatmap')
    ax.set_xlabel('d_model and n_heads')
    ax.set_ylabel('d_ff')
    ax.tick_params(axis='x', labelrotation=45)
    fig.tight_layout()
    plt.show()

plot_hyperparameter_search_results(study)


In [None]:
import torch
import os
import json

# Load the saved best model of every architecture from MODEL_FOLDER.
model_list = ['Transformer', 'Wavenet', 'ResNet']
MODEL_FOLDER = 'checkpoints/BestModels'
MODEL_CLASSES = {
    'Transformer': CustomTransformerSeizurePredictor,
    'Wavenet': Wavenet,
    'ResNet': ResNet
}

loaded_models = {}
device = 'cuda' if torch.cuda.is_available() else 'cpu'

for model_name in model_list:
    model_path = os.path.join(MODEL_FOLDER, f"{model_name}_best.pth")
    config_path = os.path.join(MODEL_FOLDER, f"{model_name}_config.json")

    # Say which files are missing instead of skipping the model silently
    missing_files = [path for path in (model_path, config_path) if not os.path.exists(path)]
    if missing_files:
        print(f"Skipping {model_name}: missing {', '.join(missing_files)}")
        continue

    print(f"Loading {model_name}...")

    try:
        # Load the configuration and the checkpoint
        with open(config_path, 'r', encoding='utf-8') as f:
            config_data = json.load(f)

        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        checkpoint = torch.load(model_path, map_location=device)

        model_class = MODEL_CLASSES[model_name]
        model_params = config_data['parameters']

        # Fill in the dimensions when the config omits them. These are written
        # into the config dict itself, so the stored 'config' reflects the
        # values the model was built with.
        model_params.setdefault('input_dim', 29)   # adjust to match your data
        model_params.setdefault('output_dim', 2)   # adjust to match your task

        model = model_class(**model_params)

        # Accept both checkpoint layouts: a dict wrapping the weights under
        # 'model_state_dict', or a bare state_dict as written by
        # torch.save(model.state_dict(), ...).
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        else:
            state_dict = checkpoint
        model.load_state_dict(state_dict)
        model.to(device)
        model.eval()

        best_acc = config_data['results']['best_val_accuracy']
        loaded_models[model_name] = {
            'model': model,
            'config': config_data,
            'best_acc': best_acc
        }

        print(f"  ✓ Success (Acc: {best_acc:.4f})")

    except Exception as e:
        # Best effort: one broken checkpoint must not stop the others from loading
        print(f"  ✗ Error: {e}")

print(f"\nLoaded {len(loaded_models)} models successfully.")