# In [None]:
import os
import csv
import numpy as np
import scipy.io as sio
from scipy.signal import resample_poly
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, mean_absolute_error
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import ReduceLROnPlateau
import random

# Seed every random number generator the script touches so runs are repeatable
_SEED = 42
random.seed(_SEED)
np.random.seed(_SEED)
torch.manual_seed(_SEED)

# Query CUDA once; the GPU-only settings and the device choice share the answer
_cuda_available = torch.cuda.is_available()
if _cuda_available:
    torch.cuda.manual_seed_all(_SEED)
    # Ask cuDNN for deterministic kernels and switch its benchmark mode off
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

# Device used by the rest of the script
device = torch.device("cuda" if _cuda_available else "cpu")
print(f"Using device: {device}")

# --- Data Loading Functions ---
def load_data(data_dir="."):
    """Load the seizure / non-seizure arrays and assemble the model targets.

    Args:
        data_dir: Directory containing ``x_filtered_seizures.npy``,
            ``y_filtered_seizures.npy``, ``meta_filtered_seizures.npy``,
            ``x_1711_nonseizures.npy`` and ``y_1711_nonseizures.npy``.
            Defaults to the current working directory.

    Returns:
        A tuple ``(x_all, y_class_all, y_reg_all)``. Seizure samples come
        first, followed by the non-seizure samples. ``y_reg_all`` has shape
        ``(len(x_all), 2)``: seizure rows hold columns 1 and 2 of the seizure
        metadata (onset, offset); non-seizure rows stay at zero.

    Raises:
        ValueError: If the seizure metadata is not a 2-D array with one row
            per seizure sample and at least three columns.
    """
    def _load(file_name, **kwargs):
        # Resolve every file against data_dir instead of the process CWD
        return np.load(os.path.join(data_dir, file_name), **kwargs)

    x_seizure = _load("x_filtered_seizures.npy")
    y_seizure = _load("y_filtered_seizures.npy")
    # NOTE(security): allow_pickle=True unpickles arbitrary objects while
    # loading; only point this at metadata files from a trusted source.
    meta_seizure = _load("meta_filtered_seizures.npy", allow_pickle=True)
    x_non = _load("x_1711_nonseizures.npy")
    y_non = _load("y_1711_nonseizures.npy")
    # The non-seizure metadata file is intentionally not loaded: nothing
    # below uses it, and skipping it avoids a needless unpickling step.

    n_seizure = len(x_seizure)
    if (meta_seizure.ndim != 2
            or meta_seizure.shape[0] != n_seizure
            or meta_seizure.shape[1] < 3):
        raise ValueError(
            "meta_filtered_seizures.npy must be 2-D with one row per seizure "
            f"sample and at least 3 columns; got shape {meta_seizure.shape} "
            f"for {n_seizure} seizure samples"
        )

    # Combine datasets (seizure block first, then non-seizure block)
    x_all = np.concatenate([x_seizure, x_non])
    y_class_all = np.concatenate([y_seizure, y_non])

    # Create regression targets (onset, offset); non-seizure rows remain 0
    y_reg_all = np.zeros((len(x_all), 2))
    y_reg_all[:n_seizure] = meta_seizure[:, 1:3].astype(np.float64)

    # Print debug info
    print("=== Data Summary ===")
    print(f"Seizure samples: {x_seizure.shape[0]}")
    print(f"Non-seizure samples: {x_non.shape[0]}")
    print(f"Total samples: {x_all.shape[0]}")
    print(f"x_all shape: {x_all.shape}")
    print(f"y_class_all shape: {y_class_all.shape}")
    print(f"y_reg_all shape: {y_reg_all.shape}")
    print("====================")

    return x_all, y_class_all, y_reg_all

# --- PyTorch Dataset ---
class EEGDataset(Dataset):
    """Dataset yielding ``(x, y_class, y_reg, reg_weight)`` float32 tensors.

    ``reg_weight`` is 1.0 for samples whose class label equals 1 and 0.0
    otherwise, so a caller can switch the regression loss off for the
    remaining samples.
    """

    def __init__(self, x_data, y_class, y_reg):
        """Store the raw (un-normalised) arrays and derive regression weights.

        Args:
            x_data: Indexable collection of input samples.
            y_class: Per-sample class labels (array or plain sequence).
            y_reg: Per-sample regression targets.

        Raises:
            ValueError: If the three inputs do not have the same length.
        """
        if not (len(x_data) == len(y_class) == len(y_reg)):
            raise ValueError(
                "x_data, y_class and y_reg must have the same length; got "
                f"{len(x_data)}, {len(y_class)} and {len(y_reg)}"
            )

        # Store raw data without normalization
        self.x_data = x_data
        self.y_class = y_class
        self.y_reg = y_reg

        # Create regression weights (1 for seizure, 0 for non-seizure).
        # np.asarray makes the comparison element-wise even for plain lists;
        # ``some_list == 1`` would be the scalar False and mark nothing.
        labels = np.asarray(y_class)
        self.reg_weights = np.zeros(len(labels))
        self.reg_weights[labels == 1] = 1.0

    def __len__(self):
        """Return the number of samples."""
        return len(self.x_data)

    def __getitem__(self, idx):
        """Return sample ``idx`` as four float32 tensors."""
        x = torch.tensor(self.x_data[idx], dtype=torch.float32)
        y_class = torch.tensor(self.y_class[idx], dtype=torch.float32)
        y_reg = torch.tensor(self.y_reg[idx], dtype=torch.float32)
        reg_weight = torch.tensor(self.reg_weights[idx], dtype=torch.float32)

        return x, y_class, y_reg, reg_weight

# --- Improved Hybrid Conv1D-LSTM Model ---
class SeizureModel(nn.Module):
    """Hybrid Conv1D + bidirectional LSTM network with two output heads.

    The input is a ``(batch, 6, time)`` tensor; the channel count of 6 is
    fixed by the first normalisation and convolution layers. ``forward``
    returns ``(class_out, reg_out)``: a ``(batch, 1)`` sigmoid probability
    and a ``(batch, 2)`` unbounded regression output.
    """

    def __init__(self):
        super(SeizureModel, self).__init__()

        # Normalization layer at the start (per sample, per channel)
        self.norm = nn.InstanceNorm1d(6, affine=True)

        # Convolutional stages: Conv -> BatchNorm -> ReLU -> MaxPool -> Dropout
        self.conv1 = self._conv_stage(6, 64, kernel_size=25, stride=3, pool=3, dropout=0.2)
        self.conv2 = self._conv_stage(64, 128, kernel_size=15, stride=2, pool=2, dropout=0.3)
        self.conv3 = self._conv_stage(128, 256, kernel_size=7, stride=2, pool=2, dropout=0.3)
        self.conv4 = self._conv_stage(256, 512, kernel_size=5, stride=2, pool=2, dropout=0.4)

        # Residual (shortcut) projections for stages 2-4. Each of those
        # stages shrinks the time axis by 4 (stride-2 conv followed by
        # MaxPool1d(2)), so the 1x1 shortcut uses stride 4 to stay aligned
        # in time with the main path. With stride 2 the shortcut would be
        # about twice as long, and cropping it to the main length would add
        # features from only the first half of the signal.
        self.residual1 = nn.Conv1d(64, 128, kernel_size=1, stride=4)
        self.residual2 = nn.Conv1d(128, 256, kernel_size=1, stride=4)
        self.residual3 = nn.Conv1d(256, 512, kernel_size=1, stride=4)

        # Bidirectional LSTM over the conv feature sequence
        # (2 directions x 256 hidden units -> 512 features per step)
        self.lstm = nn.LSTM(
            input_size=512,
            hidden_size=256,
            num_layers=3,
            batch_first=True,
            bidirectional=True,
            dropout=0.4
        )

        # Layer normalization after LSTM
        self.ln = nn.LayerNorm(512)

        # Attention mechanism: one score per time step, softmax over time
        self.attention = nn.Sequential(
            nn.Linear(512, 256),
            nn.Tanh(),
            nn.Linear(256, 1),
            nn.Softmax(dim=1)
        )

        # Classification head (outputs a probability, not a logit)
        self.class_head = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )

        # Regression head (two unbounded outputs)
        self.reg_head = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 2)
        )

    @staticmethod
    def _conv_stage(in_channels, out_channels, kernel_size, stride, pool, dropout):
        """Build one Conv1d -> BatchNorm1d -> ReLU -> MaxPool1d -> Dropout stage."""
        return nn.Sequential(
            # kernel_size // 2 padding keeps the conv "same"-style before striding
            nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size,
                      stride=stride, padding=kernel_size // 2),
            nn.BatchNorm1d(out_channels),
            nn.ReLU(),
            nn.MaxPool1d(pool),
            nn.Dropout(dropout)
        )

    @staticmethod
    def _add_shortcut(main, shortcut):
        """Add the shortcut branch to the main branch of a residual stage."""
        # The strided 1x1 conv rounds its length up while conv + pool rounds
        # down, so the shortcut can be one step longer; trim that tail.
        return main + shortcut[:, :, :main.size(2)]

    def forward(self, x):
        """Run the network.

        Args:
            x: Tensor of shape ``(batch, 6, time)``.

        Returns:
            ``(class_out, reg_out)`` with shapes ``(batch, 1)`` and
            ``(batch, 2)``.
        """
        # Apply instance normalization
        x = self.norm(x)

        # Convolutional layers with residual connections
        x1 = self.conv1(x)
        x2 = self._add_shortcut(self.conv2(x1), self.residual1(x1))
        x3 = self._add_shortcut(self.conv3(x2), self.residual2(x2))
        x4 = self._add_shortcut(self.conv4(x3), self.residual3(x3))

        # Prepare for LSTM - (batch, channels, time) -> (batch, time, channels)
        x = x4.permute(0, 2, 1)

        # LSTM layers
        x, _ = self.lstm(x)

        # Apply layer normalization
        x = self.ln(x)

        # Attention mechanism: weighted sum over the time axis
        attn_weights = self.attention(x)
        x = torch.sum(attn_weights * x, dim=1)

        # Classification output
        class_out = self.class_head(x)

        # Regression output
        reg_out = self.reg_head(x)

        return class_out, reg_out

# --- Training Functions ---
def train_model(model, train_loader, val_loader, optimizer, scheduler, criterion_cls, criterion_reg, num_epochs,
                best_model_path='best_model.pth', final_model_path='final_model.pth'):
    """Train ``model`` and return the per-epoch history.

    Each batch is expected to be ``(inputs, labels_cls, labels_reg,
    reg_weights)``. The optimised loss is ``cls_loss + 0.5 * reg_loss``,
    where the regression loss is multiplied element-wise by ``reg_weights``
    and then averaged over all elements (zero-weighted ones included).

    Args:
        model: Network returning ``(class_out, reg_out)``.
        train_loader: Loader for the training data; must expose ``.dataset``.
        val_loader: Loader passed to ``evaluate_model`` after every epoch.
        optimizer: Optimizer over ``model``'s parameters.
        scheduler: Scheduler whose ``step`` accepts the validation loss.
        criterion_cls: Classification loss on ``(probabilities, labels)``.
        criterion_reg: Regression loss; expected to return per-element
            values (``reduction='none'``) so the weights can be applied.
        num_epochs: Number of passes over ``train_loader``.
        best_model_path: Where the state dict with the lowest validation
            loss is written.
        final_model_path: Where the state dict after the last epoch is
            written.

    Returns:
        Dict of lists keyed by ``loss``, ``val_loss``, ``cls_loss``,
        ``val_cls_loss``, ``reg_loss``, ``val_reg_loss``, ``acc`` and
        ``val_acc``, one entry per epoch.

    Raises:
        ValueError: If the training dataset is empty.
    """
    num_train = len(train_loader.dataset)
    if num_train == 0:
        raise ValueError("train_loader has an empty dataset; nothing to train on")

    # Run on whatever device the model already lives on; only a model
    # without parameters falls back to the module-level ``device``.
    try:
        run_device = next(model.parameters()).device
    except StopIteration:
        run_device = device

    best_val_loss = float('inf')
    train_history = {'loss': [], 'val_loss': [], 'cls_loss': [], 'val_cls_loss': [],
                     'reg_loss': [], 'val_reg_loss': [], 'acc': [], 'val_acc': []}

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        running_loss = 0.0
        running_cls_loss = 0.0
        running_reg_loss = 0.0
        correct = 0
        total = 0

        for inputs, labels_cls, labels_reg, reg_weights in train_loader:
            inputs = inputs.to(run_device)
            labels_cls = labels_cls.to(run_device).view(-1, 1)
            labels_reg = labels_reg.to(run_device)
            reg_weights = reg_weights.to(run_device).view(-1, 1)

            # Forward pass
            outputs_cls, outputs_reg = model(inputs)

            # Calculate losses; the (batch, 1) weights broadcast across both
            # regression outputs and zero out non-seizure samples
            cls_loss = criterion_cls(outputs_cls, labels_cls)
            reg_loss = criterion_reg(outputs_reg, labels_reg) * reg_weights
            reg_loss = reg_loss.mean()

            loss = cls_loss + 0.5 * reg_loss

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()

            # Gradient clipping
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            optimizer.step()

            # Statistics (batch means re-weighted by batch size)
            batch_size = inputs.size(0)
            running_loss += loss.item() * batch_size
            running_cls_loss += cls_loss.item() * batch_size
            running_reg_loss += reg_loss.item() * batch_size

            # Accuracy at a 0.5 probability threshold
            predicted = (outputs_cls > 0.5).float()
            total += labels_cls.size(0)
            correct += (predicted == labels_cls).sum().item()

        # Training statistics
        epoch_loss = running_loss / num_train
        epoch_cls_loss = running_cls_loss / num_train
        epoch_reg_loss = running_reg_loss / num_train
        epoch_acc = correct / total

        # Validation phase
        val_loss, val_cls_loss, val_reg_loss, val_acc = evaluate_model(model, val_loader, criterion_cls, criterion_reg)

        # Update scheduler and report any learning-rate change it made
        lr_before = optimizer.param_groups[0]['lr']
        scheduler.step(val_loss)
        lr_after = optimizer.param_groups[0]['lr']
        if lr_after != lr_before:
            print(f'Learning rate changed: {lr_before:.2e} -> {lr_after:.2e}')

        # Save history
        train_history['loss'].append(epoch_loss)
        train_history['val_loss'].append(val_loss)
        train_history['cls_loss'].append(epoch_cls_loss)
        train_history['val_cls_loss'].append(val_cls_loss)
        train_history['reg_loss'].append(epoch_reg_loss)
        train_history['val_reg_loss'].append(val_reg_loss)
        train_history['acc'].append(epoch_acc)
        train_history['val_acc'].append(val_acc)

        # Print progress
        print(f'Epoch [{epoch+1}/{num_epochs}], '
              f'Loss: {epoch_loss:.4f}, Cls Loss: {epoch_cls_loss:.4f}, Reg Loss: {epoch_reg_loss:.4f}, Acc: {epoch_acc:.4f}, '
              f'Val Loss: {val_loss:.4f}, Val Cls Loss: {val_cls_loss:.4f}, Val Reg Loss: {val_reg_loss:.4f}, Val Acc: {val_acc:.4f}')

        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), best_model_path)
            print('Best model saved!')

    # Save final model
    torch.save(model.state_dict(), final_model_path)
    print('Final model saved!')

    return train_history

def evaluate_model(model, loader, criterion_cls, criterion_reg):
    """Evaluate ``model`` on ``loader`` without updating any weights.

    Each batch is expected to be ``(inputs, labels_cls, labels_reg,
    reg_weights)``. The combined loss is ``cls_loss + 0.5 * reg_loss``,
    where the regression loss is multiplied element-wise by ``reg_weights``
    and then averaged over all elements (zero-weighted ones included).

    Args:
        model: Network returning ``(class_out, reg_out)``; it is switched to
            eval mode and left that way.
        loader: Loader to iterate; must expose ``.dataset``.
        criterion_cls: Classification loss on ``(probabilities, labels)``.
        criterion_reg: Regression loss; expected to return per-element
            values (``reduction='none'``) so the weights can be applied.

    Returns:
        ``(loss, cls_loss, reg_loss, acc)`` as Python floats: the three
        losses averaged per sample, and accuracy at a 0.5 threshold.

    Raises:
        ValueError: If the loader's dataset is empty (previously this ended
            in a bare ``ZeroDivisionError``).
    """
    model.eval()

    num_samples = len(loader.dataset)
    if num_samples == 0:
        raise ValueError("loader has an empty dataset; nothing to evaluate")

    # Run on whatever device the model already lives on; only a model
    # without parameters falls back to the module-level ``device``.
    try:
        run_device = next(model.parameters()).device
    except StopIteration:
        run_device = device

    running_loss = 0.0
    running_cls_loss = 0.0
    running_reg_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels_cls, labels_reg, reg_weights in loader:
            inputs = inputs.to(run_device)
            labels_cls = labels_cls.to(run_device).view(-1, 1)
            labels_reg = labels_reg.to(run_device)
            reg_weights = reg_weights.to(run_device).view(-1, 1)

            # Forward pass
            outputs_cls, outputs_reg = model(inputs)

            # Calculate losses; the (batch, 1) weights broadcast across both
            # regression outputs
            cls_loss = criterion_cls(outputs_cls, labels_cls)
            reg_loss = criterion_reg(outputs_reg, labels_reg) * reg_weights
            reg_loss = reg_loss.mean()

            loss = cls_loss + 0.5 * reg_loss

            # Statistics (batch means re-weighted by batch size)
            batch_size = inputs.size(0)
            running_loss += loss.item() * batch_size
            running_cls_loss += cls_loss.item() * batch_size
            running_reg_loss += reg_loss.item() * batch_size

            # Accuracy at a 0.5 probability threshold
            predicted = (outputs_cls > 0.5).float()
            total += labels_cls.size(0)
            correct += (predicted == labels_cls).sum().item()

    # Calculate metrics
    loss = running_loss / num_samples
    cls_loss = running_cls_loss / num_samples
    reg_loss = running_reg_loss / num_samples
    acc = correct / total

    return loss, cls_loss, reg_loss, acc

# --- Plotting Functions ---
def plot_history(history):
    """Draw the training curves as a 2x2 grid and save them to disk.

    ``history`` is indexed by ``loss``/``val_loss``, ``cls_loss``/
    ``val_cls_loss``, ``reg_loss``/``val_reg_loss`` and ``acc``/``val_acc``.
    Each panel overlays the training and validation series. The figure is
    written to ``training_history.png`` and then closed.
    """
    # (panel title, y-axis label, training key, validation key), in grid order
    panels = (
        ('Overall Loss', 'Loss', 'loss', 'val_loss'),
        ('Classification Loss', 'Loss', 'cls_loss', 'val_cls_loss'),
        ('Regression Loss', 'Loss', 'reg_loss', 'val_reg_loss'),
        ('Classification Accuracy', 'Accuracy', 'acc', 'val_acc'),
    )

    fig = plt.figure(figsize=(15, 10))
    for position, (panel_title, y_label, train_key, val_key) in enumerate(panels, start=1):
        ax = fig.add_subplot(2, 2, position)
        ax.plot(history[train_key], label='Train')
        ax.plot(history[val_key], label='Validation')
        ax.set_title(panel_title)
        ax.set_ylabel(y_label)
        ax.set_xlabel('Epoch')
        ax.legend()

    fig.tight_layout()
    fig.savefig('training_history.png')
    plt.close(fig)

def plot_regression(y_true, y_pred, title, filename):
    """Scatter predictions against true values and save the figure.

    A dashed red line from (0, 0) to (180, 180) is drawn as the
    perfect-prediction reference. The figure is written to ``filename``
    and then closed.
    """
    fig, ax = plt.subplots(figsize=(8, 6))
    ax.scatter(y_true, y_pred, alpha=0.5)
    # Identity line: points on it are exact predictions
    ax.plot([0, 180], [0, 180], 'r--')
    ax.set_xlabel('True Values')
    ax.set_ylabel('Predictions')
    ax.set_title(title)
    ax.grid(True)
    fig.savefig(filename)
    plt.close(fig)

# --- Main Function ---
def main():
    """Run the full pipeline: load, split, train, plot and test-set report.

    Side effects: prints progress and metrics, and writes the checkpoints
    produced by ``train_model`` plus ``training_history.png``,
    ``onset_prediction.png`` and ``offset_prediction.png`` (the last two
    only when the test set contains seizure samples).
    """
    # Load data
    x_all, y_class_all, y_reg_all = load_data()

    # Split data (stratified so both splits keep the class balance)
    x_train, x_test, y_class_train, y_class_test, y_reg_train, y_reg_test = train_test_split(
        x_all, y_class_all, y_reg_all, test_size=0.2, stratify=y_class_all, random_state=42
    )

    # Create datasets
    train_dataset = EEGDataset(x_train, y_class_train, y_reg_train)
    test_dataset = EEGDataset(x_test, y_class_test, y_reg_test)

    # Split into train and validation (85 / 15, fixed seed)
    train_size = int(0.85 * len(train_dataset))
    val_size = len(train_dataset) - train_size
    train_dataset, val_dataset = random_split(
        train_dataset, [train_size, val_size],
        generator=torch.Generator().manual_seed(42)
    )

    # Create data loaders
    batch_size = 16
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)

    # Initialize model
    model = SeizureModel().to(device)

    # Print model summary
    print("Model architecture:")
    print(model)

    # Print number of parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")

    # Loss functions
    criterion_cls = nn.BCELoss()
    criterion_reg = nn.SmoothL1Loss(reduction='none')  # Less sensitive to outliers

    # Optimizer with weight decay
    optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)

    # Learning rate scheduler. The ``verbose`` argument is deliberately not
    # passed: it is deprecated in PyTorch and rejected by newer releases.
    # The current rate can be read from ``optimizer.param_groups``.
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5, min_lr=1e-6)

    # Train model
    history = train_model(
        model,
        train_loader,
        val_loader,
        optimizer,
        scheduler,
        criterion_cls,
        criterion_reg,
        num_epochs=100
    )

    # Plot training history
    plot_history(history)

    # Evaluate on test set using the checkpoint with the best validation loss
    print("\nEvaluating on test set...")
    # map_location keeps the load working when the checkpoint was written
    # on a different device than the one in use now
    model.load_state_dict(torch.load('best_model.pth', map_location=device))
    model.eval()

    all_preds_cls = []
    all_preds_reg = []
    all_labels_cls = []
    all_labels_reg = []

    with torch.no_grad():
        for inputs, labels_cls, labels_reg, _reg_weights in test_loader:
            inputs = inputs.to(device)

            # Forward pass
            preds_cls, preds_reg = model(inputs)

            # Collect results
            all_preds_cls.append(preds_cls.cpu().numpy())
            all_preds_reg.append(preds_reg.cpu().numpy())
            all_labels_cls.append(labels_cls.numpy())
            all_labels_reg.append(labels_reg.numpy())

    # Concatenate results. reshape(-1) rather than squeeze(): squeeze would
    # collapse a single-sample test set to a 0-d array and break the report.
    all_preds_cls = np.concatenate(all_preds_cls).reshape(-1)
    all_preds_reg = np.concatenate(all_preds_reg)
    all_labels_cls = np.concatenate(all_labels_cls)
    all_labels_reg = np.concatenate(all_labels_reg)

    # Classification metrics
    preds_cls_binary = (all_preds_cls > 0.5).astype(int)
    print("\nClassification Report:")
    print(classification_report(all_labels_cls, preds_cls_binary))

    # Regression metrics (only on seizure samples)
    seizure_mask = (all_labels_cls == 1)
    if np.any(seizure_mask):
        seizure_preds_reg = all_preds_reg[seizure_mask]
        seizure_labels_reg = all_labels_reg[seizure_mask]

        onset_mae = mean_absolute_error(seizure_labels_reg[:, 0], seizure_preds_reg[:, 0])
        offset_mae = mean_absolute_error(seizure_labels_reg[:, 1], seizure_preds_reg[:, 1])

        print("\nRegression Performance (Seizure Samples Only):")
        print(f"Onset MAE: {onset_mae:.2f} seconds")
        print(f"Offset MAE: {offset_mae:.2f} seconds")

        # Plot regression results
        plot_regression(
            seizure_labels_reg[:, 0], seizure_preds_reg[:, 0],
            'Onset Prediction', 'onset_prediction.png'
        )
        plot_regression(
            seizure_labels_reg[:, 1], seizure_preds_reg[:, 1],
            'Offset Prediction', 'offset_prediction.png'
        )
    else:
        print("\nNo seizure samples in test set for regression evaluation")

if __name__ == "__main__":
    main()