In [None]:
import numpy as np
# At the start of your notebook
from IPython.display import clear_output
import gc

# Clear this cell's displayed output and run a garbage-collection pass.
# NOTE(review): only imports have run before this point, so the two calls are
# effectively a no-op here; the pattern is meant to follow expensive cells.
clear_output(wait=True)
gc.collect()

In [None]:
# Notebook-level configuration: output folders and the architectures to run.
# NOTE(review): the later cells visible in this notebook pass literal folder and
# model names instead of these constants — confirm whether they are still used.
RESULT_FOLDER = "result"
MODEL_FOLDER = "model"
model_names = ['Wavenet']  # names listed by the author: 'CNN1D', 'Wavenet', 'S4', 'Resnet'

In [None]:
from steps import extract_sEEG_features
from datasetConstruct import load_seizure_across_patients

# Load the seizure dataset from the 'data' folder (presumably pooled across
# patients, judging by the loader's name — TODO confirm).
# NOTE(review): `dataset` is not referenced by any later cell visible here.
dataset = load_seizure_across_patients(data_folder='data')

# Disabled: per-seizure feature extraction, kept for reference only.
# for seizure in dataset:
#     seizure_new = extract_sEEG_features(seizure, sampling_rate=seizure.samplingRate)

In [None]:
from models import EnhancedResNet

# EnhancedResNet 单元测试
这个 Notebook 包含了对 EnhancedResNet 模型的单元测试，专门设计用于验证模型在癫痫发作数据上的性能。
## 测试内容包括
1. 模型初始化测试
2. 前向传播测试
3. 损失函数测试 (解剖约束损失、癫痫通道损失)
4. 模型收敛测试
5. 真实数据测试

In [None]:
import unittest
import torch
import numpy as np
import os
import tempfile
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt

# Import the necessary modules from your code
from steps import train_using_optimizer_with_masks

In [None]:
class SyntheticSeizureDataset(Dataset):
    """Map-style dataset that serves one row of synthetic data as a dict of tensors.

    All six inputs are converted to tensors once, at construction time, and must
    share the same leading dimension. Indexing returns a dict with the keys
    'data', 'label', 'seizure_mask', 'grey_matter_values', 'channel_idx' and
    'time_idx', each holding the ``idx``-th entry of the matching tensor.
    """

    # (key exposed by __getitem__, attribute that stores the tensor)
    _FIELDS = (
        ('data', 'data'),
        ('label', 'labels'),
        ('seizure_mask', 'seizure_mask'),
        ('grey_matter_values', 'grey_matter_values'),
        ('channel_idx', 'channel_idx'),
        ('time_idx', 'time_idx'),
    )

    def __init__(self, data, labels, seizure_mask, grey_matter_values, channel_idx, time_idx):
        # Attribute name -> (raw values, dtype the tensor is stored with).
        conversions = {
            'data': (data, torch.float32),
            'labels': (labels, torch.long),
            'seizure_mask': (seizure_mask, torch.bool),
            'grey_matter_values': (grey_matter_values, torch.float32),
            'channel_idx': (channel_idx, torch.long),
            'time_idx': (time_idx, torch.long),
        }
        for attr_name, (values, dtype) in conversions.items():
            setattr(self, attr_name, torch.tensor(values, dtype=dtype))

    def __len__(self):
        # Number of rows, i.e. the leading dimension of the data tensor.
        return len(self.data)

    def __getitem__(self, idx):
        # 'data' keeps every trailing dimension of the stored tensor; the
        # generator in this notebook produces (feature_dim, time_steps) rows.
        return {key: getattr(self, attr_name)[idx] for key, attr_name in self._FIELDS}


def create_synthetic_seizure_data(n_samples=64, n_channels=40, time_steps=128, batch_size=16, feature_dim=11, seed=None):
    """
    Generate a synthetic seizure dataset and wrap it in train/validation loaders.

    The first ``n_samples // 2`` samples are labelled 1 and built from two
    sinusoids with 20 and 40 cycles per window; the remaining samples are
    labelled 0 and use 5 and 10 cycles per window. Both get additive Gaussian
    noise. Every (sample, channel) pair becomes one row of shape
    ``(feature_dim, time_steps)`` whose feature rows are the sample's base
    signal plus small independent noise.

    Per-row metadata:
      * seizure mask: True for the first 8 channels of label-1 samples;
      * grey matter value: 1.0 for channels 0-19, 0.5 for 20-24, 0.0 otherwise;
      * channel index and the index of the originating sample.

    Rows are shuffled and then split 80/20 row-wise, so channels of the same
    sample can appear in both loaders; validation data is therefore not
    independent of the training data.

    Args:
        n_samples: Number of base signals to generate (must be >= 1).
        n_channels: Number of channels derived from each base signal (>= 1).
        time_steps: Length of each signal.
        batch_size: Batch size used by both loaders.
        feature_dim: Number of feature rows per channel.
        seed: Optional integer seed. When given, data generation, the row
            shuffle and the train loader's batch order are reproducible. When
            None (default), NumPy's and PyTorch's global RNGs are used as before.

    Returns:
        tuple: ``(train_loader, val_loader)`` over ``SyntheticSeizureDataset``.

    Raises:
        ValueError: If ``n_samples`` or ``n_channels`` is smaller than 1.
    """
    if n_samples < 1 or n_channels < 1:
        raise ValueError(
            f"n_samples and n_channels must be >= 1, got n_samples={n_samples}, n_channels={n_channels}"
        )

    # seed=None keeps the legacy behaviour of drawing from NumPy's global RNG;
    # otherwise a private RandomState leaves the global state untouched.
    rng = np.random if seed is None else np.random.RandomState(seed)

    grey_matter_map = np.zeros(n_channels)
    grey_matter_map[:20] = 1.0
    grey_matter_map[20:25] = 0.5

    n_seizure_samples = n_samples // 2
    n_seizure_channels = 8
    t = np.linspace(0, 1, time_steps)  # identical for every sample, so computed once

    per_sample_rows = []
    for sample_idx in range(n_samples):
        if sample_idx < n_seizure_samples:
            base_signal = np.sin(2 * np.pi * 20 * t) + 0.5 * np.sin(2 * np.pi * 40 * t) + 0.2 * rng.randn(time_steps)
        else:
            base_signal = 0.5 * np.sin(2 * np.pi * 5 * t) + 0.2 * np.sin(2 * np.pi * 10 * t) + 0.3 * rng.randn(time_steps)
        # One draw covers every channel and feature row of this sample; the
        # base signal broadcasts over the two leading axes.
        per_sample_rows.append(base_signal + 0.01 * rng.randn(n_channels, feature_dim, time_steps))

    # Rows are ordered sample-major, channel-minor.
    all_samples = np.concatenate(per_sample_rows)
    all_time_indices = np.repeat(np.arange(n_samples), n_channels)
    all_channel_indices = np.tile(np.arange(n_channels), n_samples)
    all_labels = (all_time_indices < n_seizure_samples).astype(np.int64)
    all_seizure_masks = (all_labels == 1) & (all_channel_indices < n_seizure_channels)
    all_grey_matter_values = grey_matter_map[all_channel_indices]

    # Apply one shared permutation so the fields stay aligned row by row.
    order = rng.permutation(len(all_samples))
    shuffled_fields = [
        field[order]
        for field in (
            all_samples,
            all_labels,
            all_seizure_masks,
            all_grey_matter_values,
            all_channel_indices,
            all_time_indices,
        )
    ]

    split = int(0.8 * len(order))
    train_dataset = SyntheticSeizureDataset(*(field[:split] for field in shuffled_fields))
    val_dataset = SyntheticSeizureDataset(*(field[split:] for field in shuffled_fields))

    # A dedicated generator makes the train loader's shuffling reproducible
    # without reseeding PyTorch's global RNG.
    loader_generator = None
    if seed is not None:
        loader_generator = torch.Generator()
        loader_generator.manual_seed(seed)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, generator=loader_generator)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader


# Build notebook-level train/validation loaders from synthetic data
# (feature_dim is left at the generator's default).
# NOTE(review): the test functions visible below create their own loaders, so
# these globals do not appear to be consumed by later cells — confirm.
train_loader, val_loader = create_synthetic_seizure_data(
    n_samples=64, 
    n_channels=40, 
    time_steps=128,
    batch_size=16
)

In [None]:
def test_model_initialization():
    """Build an EnhancedResNet with a fixed configuration and verify its dimensions.

    Checks that the constructed model exposes the ``input_dim`` and
    ``output_dim`` it was created with.

    Returns:
        The constructed EnhancedResNet instance.

    Raises:
        AssertionError: If either stored dimension differs from the requested one.
    """
    input_dim = 11  # Number of features
    output_dim = 2  # Binary classification

    model = EnhancedResNet(
        input_dim=input_dim,
        output_dim=output_dim,
        base_filters=32,
        n_blocks=3,
        kernel_size=16,
        dropout=0.2,
        lr=0.001,
        weight_decay=1e-5,
    )

    assert model.input_dim == input_dim, f"Expected input_dim {input_dim}, got {model.input_dim}"
    assert model.output_dim == output_dim, f"Expected output_dim {output_dim}, got {model.output_dim}"
    print("Model initialization test passed!")

    return model

# Run the initialization test and keep the returned model in the global `model`.
# NOTE(review): later cells rebind `model` before using it.
model = test_model_initialization()

In [None]:
def test_convergence_on_synthetic_data(model, epochs=10):
    """Train ``model`` on freshly generated synthetic data and check that it learns.

    Checkpoints are written to a temporary directory that is removed when
    training finishes, including when training or an assertion raises.

    Args:
        model: Model instance passed to ``train_using_optimizer_with_masks``.
        epochs: Number of training epochs.

    Returns:
        tuple: ``(train_loss, val_loss, val_accuracy)`` as returned by
        ``train_using_optimizer_with_masks`` (presumably one entry per epoch —
        TODO confirm).

    Raises:
        AssertionError: If the last training loss is not below the first one,
            or the last validation accuracy is not above 0.5.
    """
    n_channels = 40
    time_steps = 128
    batch_size = 16

    # The context manager removes the directory recursively on every exit path.
    with tempfile.TemporaryDirectory() as temp_dir:
        print(f"Using temporary directory: {temp_dir}")

        train_loader, val_loader = create_synthetic_seizure_data(
            n_samples=64,
            n_channels=n_channels,
            time_steps=time_steps,
            batch_size=batch_size
        )

        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        print(f"Using device: {device}")

        train_loss, val_loss, val_accuracy = train_using_optimizer_with_masks(
            model=model,
            trainloader=train_loader,
            valloader=val_loader,
            save_location=temp_dir,
            epochs=epochs,
            device=device,
            patience=5,
            scheduler_patience=3,
            checkpoint_freq=5
        )

    # Report the metrics before asserting so a failing run still shows them.
    print(f"Initial training loss: {train_loss[0]:.4f}")
    print(f"Final training loss: {train_loss[-1]:.4f}")
    print(f"Initial validation accuracy: {val_accuracy[0]:.4f}")
    print(f"Final validation accuracy: {val_accuracy[-1]:.4f}")

    assert train_loss[-1] < train_loss[0], f"Training loss didn't decrease: {train_loss[0]} -> {train_loss[-1]}"
    # NOTE(review): 0.5 assumes accuracy is reported as a fraction in [0, 1] — confirm.
    assert val_accuracy[-1] > 0.5, f"Validation accuracy didn't improve above random: {val_accuracy[-1]}"

    print("Convergence test passed!")

    return train_loss, val_loss, val_accuracy

# Plotting
def plot_training_curves(train_loss, val_loss, val_accuracy):
    """Show loss and validation-accuracy curves side by side.

    Args:
        train_loss: Sequence of training-loss values, plotted against its index.
        val_loss: Sequence of validation-loss values, plotted against its index.
        val_accuracy: Sequence of validation-accuracy values, plotted against its index.
    """
    fig, (loss_ax, accuracy_ax) = plt.subplots(1, 2, figsize=(12, 5))

    loss_ax.plot(train_loss, label='Training Loss')
    loss_ax.plot(val_loss, label='Validation Loss')
    loss_ax.legend()
    loss_ax.set_title('Loss Curves')
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Loss')

    accuracy_ax.plot(val_accuracy, label='Validation Accuracy')
    accuracy_ax.legend()
    accuracy_ax.set_title('Validation Accuracy')
    accuracy_ax.set_xlabel('Epoch')
    accuracy_ax.set_ylabel('Accuracy')

    fig.tight_layout()
    plt.show()


# Example run
model = EnhancedResNet(input_dim=11, output_dim=2)  # Example model
train_loss, val_loss, val_accuracy = test_convergence_on_synthetic_data(model, epochs=5)

# The helper was previously defined but never invoked, so no curves were shown.
plot_training_curves(train_loss, val_loss, val_accuracy)

In [None]:
def test_real_seizure_data_convergence(
    model, 
    patient_no=65, 
    seizure_no=3, 
    data_folder="data", 
    epochs=5, 
    batch_size=128, 
    input_type='transformed',
    clean_up=True
):
    """Test model convergence on a real seizure data sample"""
    import tempfile, os
    try:
        from steps import load_single_seizure, create_dataset
        
        temp_dir = tempfile.mkdtemp()
        print(f"Using temporary directory: {temp_dir}")
    
        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        print(f"Using device: {device}")
        
        single_seizure_folder = os.path.join(data_folder, f"P{patient_no}")
    
        if not os.path.exists(single_seizure_folder):
            print(f"No seizure data found for patient {patient_no}")
            return None, None, None
    
        print(f"Loading seizure {seizure_no} for patient {patient_no}")
        seizure_obj = load_single_seizure(single_seizure_folder, seizure_no)
    
        # create dataset
        train_loader, val_loader = create_dataset(
            seizure=seizure_obj,
            train_percentage=0.8,
            batch_size=batch_size,
            input_type=input_type
        )
    
        # Optional: reset model weights if needed
        model.train()
    
        # Train
        train_loss, val_loss, val_accuracy = train_using_optimizer_with_masks(
            model=model,
            trainloader=train_loader,
            valloader=val_loader,
            save_location=temp_dir,
            epochs=epochs,
            device=device,
            patience=3,
            scheduler_patience=2,
            checkpoint_freq=2
        )
    
        # Assertions
        assert train_loss[-1] < train_loss[0], f"Training loss didn't decrease: {train_loss[0]} -> {train_loss[-1]}"
        assert val_accuracy[-1] > 0.5, f"Validation accuracy didn't improve above random: {val_accuracy[-1]}"
    
        print(f"Initial Training Loss: {train_loss[0]:.4f} -> Final: {train_loss[-1]:.4f}")
        print(f"Initial Validation Accuracy: {val_accuracy[0]:.4f} -> Final: {val_accuracy[-1]:.4f}")
        print("Real data convergence test passed!")
    
        if clean_up:
            for file in os.listdir(temp_dir):
                os.remove(os.path.join(temp_dir, file))
            os.rmdir(temp_dir)
    
        return train_loss, val_loss, val_accuracy
        
    except (ImportError, ModuleNotFoundError) as e:
        print(f"Required modules not available for real data testing: {e}")
        return None, None, None
    except FileNotFoundError as e:
        print(f"Real seizure data not available for testing: {e}")
        return None, None, None
    finally:
        # Clean up temporary directory
        if os.path.exists(temp_dir):
            for file in os.listdir(temp_dir):
                os.remove(os.path.join(temp_dir, file))
            os.rmdir(temp_dir)
# Fresh model for the real-data run. This rebinds the notebook-global `model`.
# input_dim=11 matches the feature count used by the other tests in this notebook.
model = EnhancedResNet(
    input_dim=11,
    output_dim=2,
    base_filters=32,
    n_blocks=3,
    kernel_size=16,
    dropout=0.2,
    lr=0.001,
    weight_decay=1e-5
)

# Run the test; all three results are None when modules or data are unavailable.
real_train_loss, real_val_loss, real_val_accuracy = test_real_seizure_data_convergence(model)

# Plot training curves if data is available (the test returns None values otherwise).
# NOTE(review): this duplicates plot_training_curves apart from the "(Real Data)" titles.
if real_train_loss is not None:
    plt.figure(figsize=(12, 5))
    # Left panel: training and validation loss.
    plt.subplot(1, 2, 1)
    plt.plot(real_train_loss, label='Training Loss')
    plt.plot(real_val_loss, label='Validation Loss')
    plt.legend()
    plt.title('Loss Curves (Real Data)')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')

    # Right panel: validation accuracy.
    plt.subplot(1, 2, 2)
    plt.plot(real_val_accuracy, label='Validation Accuracy')
    plt.legend()
    plt.title('Validation Accuracy (Real Data)')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')

    plt.tight_layout()
    plt.show()

In [None]:
from steps import setup_and_train_models

# Full training run for EnhancedResNet on the data under "data", with
# checkpoints written to "checkpoints".
# NOTE(review): 100 epochs makes this cell expensive on a Restart & Run All,
# and the folder and model names are literals rather than the notebook's
# MODEL_FOLDER / model_names constants — confirm this is intended.
results, models = setup_and_train_models(
    data_folder="data",
    model_folder="checkpoints",
    model_names=['EnhancedResNet'],
    train=True,
    input_type='transformed',  # 'transformed' or 'raw'
    params={'epochs': 100, 'batch_size': 2048, 'checkpoint_freq': 20},  # params: epochs, checkpoint_freq, lr, batch_size, device, patience, gradient_clip
    hyperparameter_search=False
)

In [None]:
def test_performance_with_and_without_constraints(epochs=5):
    """Compare model performance with and without anatomical constraints.

    Trains two EnhancedResNet models on the same synthetic loaders, one with
    ``gamma=0.5`` and one with ``gamma=0.0``, then prints and plots their
    training loss, validation loss and validation accuracy.

    Checkpoints go to a temporary directory that is removed on every exit path.
    The two models are initialised independently and no RNG seed is set, so
    differences between the curves include initialisation noise.

    Args:
        epochs: Number of training epochs for each model.

    Returns:
        tuple: ``((train_loss_with, val_loss_with, val_accuracy_with),
        (train_loss_without, val_loss_without, val_accuracy_without))``.
    """
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}")

    n_channels = 32
    time_steps = 128
    batch_size = 16
    # One value drives both the generated data and the models' input_dim. The
    # models previously used input_dim=12 against 11-feature data, unlike every
    # other test in this notebook.
    # NOTE(review): assumes gamma does not add an input channel — confirm.
    feature_dim = 11
    train_loader, val_loader = create_synthetic_seizure_data(
        n_samples=64,
        n_channels=n_channels,
        time_steps=time_steps,
        batch_size=batch_size,
        feature_dim=feature_dim
    )

    def _train_variant(gamma, save_location):
        """Train a fresh model with the given constraint weight and return its metrics."""
        variant_model = EnhancedResNet(
            input_dim=feature_dim,
            output_dim=2,
            base_filters=32,
            n_blocks=3,
            kernel_size=16,
            dropout=0.2,
            lr=0.001,
            weight_decay=1e-5,
            gamma=gamma  # Weight for anatomical constraint (0.0 disables it)
        )
        return train_using_optimizer_with_masks(
            model=variant_model,
            trainloader=train_loader,
            valloader=val_loader,
            save_location=save_location,
            epochs=epochs,
            device=device,
            patience=5,
            scheduler_patience=3,
            checkpoint_freq=epochs
        )

    # TemporaryDirectory removes the whole tree, nested folders included, even
    # when training raises.
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_dir_with = os.path.join(temp_dir, "with_constraints")
        temp_dir_without = os.path.join(temp_dir, "without_constraints")
        os.makedirs(temp_dir_with, exist_ok=True)
        os.makedirs(temp_dir_without, exist_ok=True)

        print("Training model WITH anatomical constraints...")
        train_loss_with, val_loss_with, val_accuracy_with = _train_variant(0.5, temp_dir_with)

        print("\nTraining model WITHOUT anatomical constraints...")
        train_loss_without, val_loss_without, val_accuracy_without = _train_variant(0.0, temp_dir_without)

    print("\n--- Performance Comparison ---")
    print(f"WITH constraints - Final val accuracy: {val_accuracy_with[-1]:.4f}")
    print(f"WITHOUT constraints - Final val accuracy: {val_accuracy_without[-1]:.4f}")

    # One panel per metric: (title, y-axis label, curve with, curve without).
    panels = (
        ('Training Loss', 'Loss', train_loss_with, train_loss_without),
        ('Validation Loss', 'Loss', val_loss_with, val_loss_without),
        ('Validation Accuracy', 'Accuracy', val_accuracy_with, val_accuracy_without),
    )
    fig, axes = plt.subplots(1, len(panels), figsize=(15, 5))
    for ax, (title, y_label, curve_with, curve_without) in zip(axes, panels):
        ax.plot(curve_with, label='With Constraints')
        ax.plot(curve_without, label='Without Constraints')
        ax.legend()
        ax.set_title(title)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(y_label)

    fig.tight_layout()
    plt.show()

    return (
        (train_loss_with, val_loss_with, val_accuracy_with),
        (train_loss_without, val_loss_without, val_accuracy_without)
    )

# Run the comparison test with fewer epochs for quicker testing; each result is
# a (train_loss, val_loss, val_accuracy) tuple for one of the two variants.
results_with, results_without = test_performance_with_and_without_constraints(epochs=5)