In [None]:
# %% [markdown]
# # Multimodal CNN + TCN with Feature Fusion
# 
# This notebook implements a multimodal approach by combining a Convolutional Neural Network (CNN) trained on EEG spectrograms and a Temporal Convolutional Network (TCN) trained on raw EEG signals.
# 
# Due to limited computing resources, we'll use a three-stage approach:
# 1.  **Extract Features**: Load the trained models for each fold and extract feature embeddings from the layers preceding each model's classification head. Save these features to disk.
# 2.  **Create Fusion Dataset**: Build a new PyTorch Dataset that loads the pre-saved CNN and TCN features and concatenates them.
# 3.  **Train Classifier Head**: Train a simple MLP classifier head on the fused features.

# %%
import os
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from sklearn.model_selection import StratifiedGroupKFold
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm

# Add project root to system path
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "..", "..", "..")))
from src.datasets.eeg_dataset import EEGDataset
from src.datasets.multi_spectrogram import MultiSpectrogramDataset
from src.models.base_cnn import BaseCNN
from src.models.tcn import TCNModel
from src.utils.constants import Constants
from src.utils.k_folds_creator import KFoldCreator
from src.utils.utils import get_models_save_path

# Log in to W&B
# NOTE(review): `wandb` is not used anywhere else in this notebook; this call may
# prompt for an API key when none is configured. Drop it (and the import) unless
# run logging is planned.
wandb.login()

# %% [markdown]
# ## 1. Configuration
# 
# We'll define a configuration class `CFG` to hold all hyperparameters and paths for both models and the final classifier head.

# %%
class CFG:
    """Single namespace holding every tunable of this notebook.

    Read as class attributes (``CFG.seed``); the class is never instantiated.
    """

    # --- general -------------------------------------------------------------
    seed = 42
    n_splits = 5
    data_path = '../../../data/'

    # --- paths (derived from data_path where possible) -----------------------
    models_save_path = get_models_save_path()
    feature_save_path = Path(data_path) / 'features'
    DATA_PREPARATION_VOTE_METHOD = "sum_and_normalize"

    # --- CNN backbone (spectrogram branch) -----------------------------------
    cnn_model_name = 'tf_efficientnet_b0_ns'
    cnn_in_channels = 8
    cnn_img_size = (128, 256)
    cnn_eeg_spec_path = data_path + 'custom_eegs/cwt'

    # --- TCN backbone (raw-signal branch) ------------------------------------
    tcn_model_name = 'TCN'
    tcn_num_channels = 20
    tcn_num_tcn_channels = [64] + [128] * 2 + [256] * 2 + [512] * 3
    tcn_kernel_size = 21
    tcn_dropout = 0.35
    tcn_sequence_duration = 50  # presumably seconds (multiplied by a rate below)
    tcn_original_sampling_rate = 200
    tcn_downsample_factor = 3
    # NOTE(review): floor division gives 66 rather than 66.67, so the length below is
    # 3300 samples while 50 * 200 / 3 is ~3333 — confirm against EEGDataset.
    tcn_sampling_rate = tcn_original_sampling_rate // tcn_downsample_factor
    tcn_sequence_length = tcn_sampling_rate * tcn_sequence_duration

    # --- fusion classifier head ----------------------------------------------
    fusion_batch_size = 64
    fusion_num_workers = 4
    fusion_epochs = 10
    fusion_lr = 0.0005
    fusion_hidden_size = 512
    target_size = 6

# Ensure the root directory for the extracted .npy feature files exists
os.makedirs(CFG.feature_save_path, exist_ok=True)

def set_seed(seed, deterministic=False):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch (CPU and all CUDA devices).

    Args:
        seed: integer seed applied to every generator.
        deterministic: if True, also force deterministic cuDNN kernels
            (``cudnn.deterministic = True``, ``cudnn.benchmark = False``), trading
            GPU speed for run-to-run reproducibility. Defaults to False (previous behavior).
    """
    import random  # not imported at the top of the notebook

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # manual_seed_all covers the current device as well as every other GPU
        torch.cuda.manual_seed_all(seed)
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

set_seed(CFG.seed)

# Short aliases used throughout the notebook
TARGETS = Constants.TARGETS
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {DEVICE}")

# %% [markdown]
# ## 2. Data Preparation
# 
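# We load the processed data and recreate the same stratified group k-folds (grouped by `patient_id`) that were used to train the individual models. Reusing the identical split is essential: the fold-*k* backbones must never have seen the fold-*k* validation rows, otherwise the out-of-fold score is optimistically biased.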

# %%
print("Preparing data and creating folds...")
df = pd.read_csv(CFG.data_path + 'processed_data_sum_votes_window.csv')
# Highest-scoring target column per row; used only to stratify the folds
df['expert_consensus'] = df[TARGETS].idxmax(axis=1)

print('Train shape:', df.shape)

# Must reproduce the split the individual CNN/TCN models were trained with, so that
# fold k here is a validation set the fold-k backbones never trained on.
fold_creator = KFoldCreator(n_splits=CFG.n_splits, seed=CFG.seed)
df = fold_creator.create_folds(df, stratify_col='expert_consensus', group_col='patient_id')
# Later cells treat df's row order as the row order of the saved feature files and
# select rows via df.index, so force a clean positional 0..n-1 index.
df = df.reset_index(drop=True)
assert df['fold'].notna().all(), "every row must be assigned to a fold"
assert set(df['fold'].unique()) == set(range(CFG.n_splits)), "expected folds 0..n_splits-1"

print("Folds created. Value counts per fold:")
print(df['fold'].value_counts().sort_index())

# %% [markdown]
# ## 3. Stage 1: Feature Extraction
# 
# ### Modifying Models for Feature Extraction
# We need to adapt the `BaseCNN` and `TCNModel` to return feature vectors instead of classification logits. We'll create new wrapper classes for this purpose.

# %%
class CNNFeatureExtractor(nn.Module):
    """Wrap a ``BaseCNN`` backbone so it returns pooled feature vectors instead of logits.

    The timm backbone (``BaseCNN.model``) is registered as the submodule ``model``,
    so ``.to(device)`` moves its weights and ``load_state_dict`` can fill it from the
    ``model.*`` entries of a saved ``BaseCNN`` checkpoint.

    Args:
        model_name: timm model name passed to ``BaseCNN``.
        pretrained: whether ``BaseCNN`` loads pretrained weights.
        in_channels: number of input channels passed to ``BaseCNN``.
        num_classes: classifier size passed to ``BaseCNN`` (the logits are not used here).
    """

    def __init__(self, model_name, pretrained=True, in_channels=4, num_classes=6):
        super().__init__()
        original_model = BaseCNN(model_name, pretrained, in_channels, num_classes)
        # Register the backbone as a real submodule. Keeping only the bound method
        # ``forward_features`` leaves this module without parameters: ``.to(device)``
        # then moves nothing and ``load_state_dict(strict=False)`` matches no key.
        self.model = original_model.model

    def forward(self, x):
        """Return a ``(batch, features)`` tensor of globally average-pooled backbone features.

        NOTE(review): assumes x is (batch, channels, height, width) — TODO confirm
        against MultiSpectrogramDataset.
        """
        # Mirror BaseCNN's input processing (per the original author's note):
        # stack the channels vertically, then repeat to a 3-channel image.
        channels = torch.split(x, 1, dim=1)
        x_reshaped = torch.cat(channels, dim=2)
        x_3_channel = x_reshaped.repeat(1, 3, 1, 1)

        features = self.model.forward_features(x_3_channel)

        # Global average pooling; flatten(1) instead of squeeze() keeps the batch
        # dimension when the batch holds a single sample.
        return F.adaptive_avg_pool2d(features, 1).flatten(1)


class TCNFeatureExtractor(nn.Module):
    """Wrap a TCN backbone so it returns one feature vector per sequence instead of logits.

    FIXME(review): ``TCN`` is not imported or defined anywhere in this notebook
    (only ``TCNModel`` is imported, from ``src.models.tcn``), so constructing this
    class raises ``NameError``. Import the same TCN implementation that ``TCNModel``
    was trained with. The parameter names under ``self.tcn`` (``tcn.*``) must also
    match the keys of the saved ``TCNModel`` checkpoint for its weights to be loaded.

    Args:
        num_inputs: number of input channels (``CFG.tcn_num_channels`` at the call site).
        channel_sizes: per-level channel widths, passed to ``TCN`` as ``num_channels``.
        kernel_size: convolution kernel size.
        dropout: dropout probability.
    """
    def __init__(self, num_inputs, channel_sizes, kernel_size, dropout):
        super().__init__()
        # Non-causal TCN with skip connections; presumably mirrors the configuration
        # TCNModel was trained with — TODO confirm, or checkpoint weights will not fit.
        self.tcn = TCN(
            num_inputs=num_inputs,
            num_channels=channel_sizes,
            kernel_size=kernel_size,
            dropout=dropout,
            causal=False,
            use_skip_connections=True
        )

    def forward(self, x):
        """Return ``tcn_output[:, :, -1]``, the last position along dim 2 of the TCN output."""
        # NOTE(review): assumes x arrives as (batch, time, channels) and the TCN expects
        # (batch, channels, time) — TODO confirm against EEGDataset / TCNModel.
        x = x.permute(0, 2, 1)
        tcn_output = self.tcn(x)
        # Return the output of the last time step
        # NOTE(review): with causal=False the last step is presumably not a summary of
        # the whole sequence; consider pooling over the time axis instead.
        return tcn_output[:, :, -1]

# %% [markdown]
# ### Feature Extraction Loop
# 
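# This function iterates through a dataloader, passes each batch through a feature extractor model, and saves the resulting features and labels as `.npy` files. `eeg_id`s are not saved: row *i* of every saved array is assumed to correspond to row *i* of `df`, which relies on the datasets iterating `df` in row order and on the dataloaders using `shuffle=False`.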

# %%
def extract_and_save_features(model, dataloader, save_path, device, fold_id):
    """Run ``model`` over ``dataloader`` and save the stacked features and labels of one fold.

    Batches are consumed in dataloader order, so row ``i`` of the saved arrays is
    the ``i``-th sample yielded by the loader (use ``shuffle=False`` to keep it
    aligned with the dataset).

    Args:
        model: module mapping a batch of inputs to a batch of feature vectors.
        dataloader: iterable of ``(inputs, labels)`` tensor pairs.
        save_path: output directory (``str`` or ``Path``); created if missing.
        device: device the model and inputs are moved to.
        fold_id: fold index used in the output file names.

    Writes ``features_fold_{fold_id}.npy`` and ``labels_fold_{fold_id}.npy`` into ``save_path``.

    Raises:
        ValueError: if the dataloader yields no batches, or the number of feature
            rows differs from the number of label rows.
    """
    save_path = Path(save_path)
    save_path.mkdir(parents=True, exist_ok=True)

    model.to(device)
    model.eval()

    feature_batches, label_batches = [], []
    # Full-precision inference: no autocast, so input and weight dtypes always agree
    with torch.no_grad():
        for batch_data, batch_labels in tqdm(dataloader, desc=f"Extracting Fold {fold_id} Features"):
            features = model(batch_data.to(device))
            feature_batches.append(features.cpu().numpy())
            label_batches.append(batch_labels.cpu().numpy())

    if not feature_batches:
        raise ValueError(f"Dataloader for fold {fold_id} yielded no batches; nothing to save.")

    all_features = np.concatenate(feature_batches, axis=0)
    all_labels = np.concatenate(label_batches, axis=0)
    if len(all_features) != len(all_labels):
        raise ValueError(
            f"Fold {fold_id}: {len(all_features)} feature rows but {len(all_labels)} label rows; "
            "the model must return one feature vector per sample."
        )

    np.save(save_path / f'features_fold_{fold_id}.npy', all_features)
    np.save(save_path / f'labels_fold_{fold_id}.npy', all_labels)
    print(f"Saved features for fold {fold_id} to {save_path}")

# %% [markdown]
# ### Run CNN Feature Extraction

# %%
# CNN Feature Extraction: for each fold, load that fold's trained BaseCNN weights and
# embed the *entire* dataframe (train + valid rows) in df order.
cnn_feature_path = CFG.feature_save_path / "cnn"
cnn_feature_path.mkdir(exist_ok=True, parents=True)

# The dataset covers all rows and is identical for every fold, so it is built at
# most once, and only if some fold's features are not cached yet.
cnn_dataloader = None

for fold in range(CFG.n_splits):
    print(f"\n===== Extracting CNN Features for Fold {fold} =====")

    # Skip folds whose features are already cached on disk
    if (cnn_feature_path / f'features_fold_{fold}.npy').exists():
        print(f"CNN features for fold {fold} already exist. Skipping.")
        continue

    # 1. Build the extractor with the configured channel count and load this fold's weights
    cnn_extractor = CNNFeatureExtractor(
        CFG.cnn_model_name,
        pretrained=False,
        in_channels=CFG.cnn_in_channels,
        num_classes=CFG.target_size,
    )
    model_path = CFG.models_save_path / "MultiSpectCNN" / CFG.DATA_PREPARATION_VOTE_METHOD / f'best_model_fold{fold}.pth'
    original_state_dict = torch.load(model_path, map_location=DEVICE)

    # strict=False tolerates checkpoint entries the extractor does not own, but a load
    # that matches nothing silently leaves random weights in place; fail loudly instead.
    load_result = cnn_extractor.load_state_dict(original_state_dict, strict=False)
    n_used_keys = len(original_state_dict) - len(load_result.unexpected_keys)
    if load_result.missing_keys or n_used_keys == 0:
        raise RuntimeError(
            f"CNN checkpoint {model_path} does not match the feature extractor: "
            f"{len(load_result.missing_keys)} missing keys (e.g. {load_result.missing_keys[:3]}), "
            f"{n_used_keys}/{len(original_state_dict)} checkpoint keys used."
        )

    # 2. Dataset over the entire dataframe (train + valid rows), in df order
    if cnn_dataloader is None:
        # NOTE(review): mode='train' may enable training-time augmentation inside
        # MultiSpectrogramDataset — confirm; extracted features should be deterministic.
        cnn_dataset = MultiSpectrogramDataset(
            df.copy(), TARGETS, CFG.data_path, CFG.cnn_img_size, CFG.cnn_eeg_spec_path, mode='train'
        )
        cnn_dataloader = DataLoader(
            cnn_dataset, batch_size=CFG.fusion_batch_size, shuffle=False, num_workers=CFG.fusion_num_workers
        )

    # 3. Extract and save
    extract_and_save_features(cnn_extractor, cnn_dataloader, cnn_feature_path, DEVICE, fold)

    # Release this fold's weights before loading the next fold's
    del cnn_extractor, original_state_dict
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# %% [markdown]
# ### Run TCN Feature Extraction

# %%
# TCN Feature Extraction: for each fold, load that fold's trained TCN weights and
# embed the *entire* dataframe (train + valid rows) in df order.
tcn_feature_path = CFG.feature_save_path / "tcn"
tcn_feature_path.mkdir(exist_ok=True, parents=True)

# Same dataset for every fold; build it lazily, at most once.
tcn_dataloader = None

for fold in range(CFG.n_splits):
    print(f"\n===== Extracting TCN Features for Fold {fold} =====")

    if (tcn_feature_path / f'features_fold_{fold}.npy').exists():
        print(f"TCN features for fold {fold} already exist. Skipping.")
        continue

    # 1. Build the extractor and load this fold's trained weights
    tcn_extractor = TCNFeatureExtractor(
        num_inputs=CFG.tcn_num_channels,
        channel_sizes=CFG.tcn_num_tcn_channels,
        kernel_size=CFG.tcn_kernel_size,
        dropout=CFG.tcn_dropout,
    )
    model_path = CFG.models_save_path / "TCNModel" / CFG.DATA_PREPARATION_VOTE_METHOD / f'best_model_fold{fold}.pth'
    original_state_dict = torch.load(model_path, map_location=DEVICE)

    # strict=False tolerates extra checkpoint entries (e.g. classifier layers), but a
    # load that matches nothing would silently keep random weights; fail loudly instead.
    load_result = tcn_extractor.load_state_dict(original_state_dict, strict=False)
    n_used_keys = len(original_state_dict) - len(load_result.unexpected_keys)
    if load_result.missing_keys or n_used_keys == 0:
        raise RuntimeError(
            f"TCN checkpoint {model_path} does not match the feature extractor: "
            f"{len(load_result.missing_keys)} missing keys (e.g. {load_result.missing_keys[:3]}), "
            f"{n_used_keys}/{len(original_state_dict)} checkpoint keys used."
        )

    # 2. Dataset over the entire dataframe (train + valid rows), in df order
    if tcn_dataloader is None:
        # NOTE(review): mode='train' may enable training-time augmentation inside
        # EEGDataset — confirm; extracted features should be deterministic.
        tcn_dataset = EEGDataset(
            df=df.copy(), data_path=CFG.data_path, mode='train',
            downsample_factor=CFG.tcn_downsample_factor
        )
        tcn_dataloader = DataLoader(
            tcn_dataset, batch_size=CFG.fusion_batch_size, shuffle=False, num_workers=CFG.fusion_num_workers
        )

    # 3. Extract and save
    extract_and_save_features(tcn_extractor, tcn_dataloader, tcn_feature_path, DEVICE, fold)

    # Release this fold's weights before loading the next fold's
    del tcn_extractor, original_state_dict
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


# %% [markdown]
# ## 4. Stage 2: Feature Fusion Dataset
# 
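# Now we create a `Dataset` that serves the features we just saved, concatenating the CNN and TCN vectors of each row. It is fast because the NumPy arrays are loaded into memory once and no spectrograms or raw signals are decoded. Rows are selected by position, so the CNN and TCN feature files must list samples in the same order as `df`.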

# %%
class FusedFeatureDataset(Dataset):
    """Serve ``[cnn_features | tcn_features]`` vectors and their labels for one fold.

    The fold's CNN features, TCN features and labels are loaded into memory once.
    Row ``i`` of every file must describe the same sample; the extraction cells aim
    for this by iterating the same dataframe with ``shuffle=False``.

    Args:
        df: unused; kept for interface compatibility.
        fold_id: fold whose feature files are loaded.
        cnn_feature_path: directory (``str`` or ``Path``) holding
            ``features_fold_{fold_id}.npy`` and ``labels_fold_{fold_id}.npy``.
        tcn_feature_path: directory (``str`` or ``Path``) holding
            ``features_fold_{fold_id}.npy``; its labels file, if present, is cross-checked.
        indices: positional row numbers (``0 .. n_rows-1``) of the samples in this split.

    Raises:
        ValueError: if the CNN/TCN files disagree on row count or labels, or an
            index falls outside ``[0, n_rows)``.
    """

    def __init__(self, df, fold_id, cnn_feature_path, tcn_feature_path, indices):
        cnn_dir, tcn_dir = Path(cnn_feature_path), Path(tcn_feature_path)
        self.indices = np.asarray(indices, dtype=np.int64)

        # Load all features and labels for the given fold into memory
        self.cnn_features = np.load(cnn_dir / f'features_fold_{fold_id}.npy')
        self.tcn_features = np.load(tcn_dir / f'features_fold_{fold_id}.npy')
        self.labels = np.load(cnn_dir / f'labels_fold_{fold_id}.npy')
        self._validate(tcn_dir / f'labels_fold_{fold_id}.npy')

    def _validate(self, tcn_labels_file):
        """Check that both modalities describe the same rows and that ``indices`` fit them."""
        n_rows = len(self.labels)
        if len(self.cnn_features) != n_rows or len(self.tcn_features) != n_rows:
            raise ValueError(
                f"Feature files are misaligned: CNN has {len(self.cnn_features)} rows, "
                f"TCN has {len(self.tcn_features)} rows, labels have {n_rows} rows."
            )
        if tcn_labels_file.exists():
            tcn_labels = np.load(tcn_labels_file)
            if tcn_labels.shape != self.labels.shape or not np.allclose(tcn_labels, self.labels, equal_nan=True):
                raise ValueError(f"CNN and TCN label files differ ({tcn_labels_file}); rows are not aligned.")
        if self.indices.size and (self.indices.min() < 0 or self.indices.max() >= n_rows):
            raise ValueError(
                f"indices must lie in [0, {n_rows}); got range [{self.indices.min()}, {self.indices.max()}]. "
                "Pass positional row numbers, not dataframe index labels."
            )

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        """Return ``(fused_features, label)`` float32 tensors for the ``idx``-th sample of this split."""
        row = self.indices[idx]
        fused_features = torch.cat(
            [
                torch.tensor(self.cnn_features[row], dtype=torch.float32),
                torch.tensor(self.tcn_features[row], dtype=torch.float32),
            ],
            dim=0,
        )
        label = torch.tensor(self.labels[row], dtype=torch.float32)
        return fused_features, label

# %% [markdown]
# ## 5. Stage 3: Classifier Head and Training Loop
# 
# ### Define the Classifier Head
# A simple MLP will serve as our classifier head. It takes the concatenated feature vector as input.

# %%
class FusionClassifier(nn.Module):
    """Three-layer MLP head mapping a fused feature vector to class logits.

    Linear(input_size -> hidden_size) -> BatchNorm -> ReLU -> Dropout ->
    Linear(hidden_size -> hidden_size // 2) -> BatchNorm -> ReLU -> Dropout ->
    Linear(hidden_size // 2 -> num_classes).

    Args:
        input_size: length of the fused feature vector.
        hidden_size: width of the first hidden layer (the second is half of it).
        num_classes: number of output logits.
        p_drop1: dropout probability after the first hidden layer (default 0.5).
        p_drop2: dropout probability after the second hidden layer (default 0.3).

    The forward pass returns raw logits; callers apply ``log_softmax``/``softmax``.
    """

    def __init__(self, input_size, hidden_size, num_classes, p_drop1=0.5, p_drop2=0.3):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.bn1 = nn.BatchNorm1d(hidden_size)
        self.dropout1 = nn.Dropout(p_drop1)
        self.fc2 = nn.Linear(hidden_size, hidden_size // 2)
        self.bn2 = nn.BatchNorm1d(hidden_size // 2)
        self.dropout2 = nn.Dropout(p_drop2)
        self.fc3 = nn.Linear(hidden_size // 2, num_classes)

    def forward(self, x):
        """Return ``(batch, num_classes)`` logits for a ``(batch, input_size)`` input.

        In training mode BatchNorm1d needs more than one sample per batch.
        """
        x = self.dropout1(F.relu(self.bn1(self.fc1(x))))
        x = self.dropout2(F.relu(self.bn2(self.fc2(x))))
        return self.fc3(x)

# %% [markdown]
# ### Training and Evaluation Loop

# %%
def _fusion_run_epoch(model, loader, loss_fn, optimizer=None, desc=None):
    """Run one pass over ``loader`` and return the per-sample mean loss.

    Trains (forward, backward, step) when ``optimizer`` is given; otherwise
    evaluates with gradients disabled. ``loss_fn`` is expected to use
    ``reduction='batchmean'``, so each batch loss is re-weighted by its size and a
    smaller final batch does not skew the average.
    """
    training = optimizer is not None
    model.train(training)
    total_loss, n_samples = 0.0, 0
    batches = tqdm(loader, desc=desc) if desc else loader
    with torch.set_grad_enabled(training):
        for features, labels in batches:
            features, labels = features.to(DEVICE), labels.to(DEVICE)
            log_probs = F.log_softmax(model(features), dim=1)
            loss = loss_fn(log_probs, labels)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            total_loss += loss.item() * features.size(0)
            n_samples += features.size(0)
    if n_samples == 0:
        raise ValueError("loader yielded no samples")
    return total_loss / n_samples


def _fusion_predict_log_probs(model, loader):
    """Return ``(log_probs, labels)`` NumPy arrays for every sample of ``loader``, in eval mode."""
    model.eval()
    log_prob_batches, label_batches = [], []
    with torch.no_grad():
        for features, labels in loader:
            outputs = model(features.to(DEVICE))
            log_prob_batches.append(F.log_softmax(outputs, dim=1).cpu().numpy())
            label_batches.append(labels.numpy())
    return np.concatenate(log_prob_batches), np.concatenate(label_batches)


def run_fusion_training(df, cnn_feature_path, tcn_feature_path, checkpoint_dir='.'):
    """Train one fusion head per fold and return the overall out-of-fold KL divergence.

    For fold ``k`` the head trains on rows with ``df['fold'] != k`` and validates on
    rows with ``df['fold'] == k``, using the features extracted with the fold-``k``
    backbones. The epoch with the lowest validation loss is saved to
    ``checkpoint_dir/fusion_head_best_fold_{k}.pth`` and used for the OOF predictions.

    NOTE: the fold-k backbones were trained on the fold-k training rows, so those
    rows' features are likely more confident than the validation rows' features;
    part of the train/valid loss gap is not head overfitting.

    Args:
        df: dataframe with a ``fold`` column whose row order matches the saved feature files.
        cnn_feature_path: directory with the per-fold CNN feature/label ``.npy`` files.
        tcn_feature_path: directory with the per-fold TCN feature ``.npy`` files.
        checkpoint_dir: where the per-fold head weights are written (default: current directory).

    Returns:
        float: ``KLDivLoss(reduction='batchmean')`` between the OOF predictions and labels.

    Raises:
        RuntimeError: if no epoch of a fold produced a finite validation loss.
    """
    cnn_feature_path, tcn_feature_path = Path(cnn_feature_path), Path(tcn_feature_path)
    checkpoint_dir = Path(checkpoint_dir)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    # Positional row numbers index the saved feature arrays; df.index labels would
    # only coincide with them for a default RangeIndex.
    fold_ids = df['fold'].to_numpy()
    all_oof_log_probs, all_oof_labels = [], []

    for fold in range(CFG.n_splits):
        print(f"\n========== FOLD {fold} ==========")

        # Determine input size for the classifier head (memory-mapped: only the header is read)
        cnn_dim = np.load(cnn_feature_path / f'features_fold_{fold}.npy', mmap_mode='r').shape[1]
        tcn_dim = np.load(tcn_feature_path / f'features_fold_{fold}.npy', mmap_mode='r').shape[1]
        input_size = cnn_dim + tcn_dim
        print(f"Input feature size: {input_size} (CNN: {cnn_dim}, TCN: {tcn_dim})")

        train_indices = np.flatnonzero(fold_ids != fold)
        valid_indices = np.flatnonzero(fold_ids == fold)

        train_dataset = FusedFeatureDataset(df, fold, cnn_feature_path, tcn_feature_path, train_indices)
        valid_dataset = FusedFeatureDataset(df, fold, cnn_feature_path, tcn_feature_path, valid_indices)

        # drop_last keeps BatchNorm from ever seeing a single-sample training batch
        train_loader = DataLoader(
            train_dataset, batch_size=CFG.fusion_batch_size, shuffle=True,
            num_workers=CFG.fusion_num_workers, drop_last=True
        )
        valid_loader = DataLoader(
            valid_dataset, batch_size=CFG.fusion_batch_size, shuffle=False,
            num_workers=CFG.fusion_num_workers
        )

        model = FusionClassifier(input_size, CFG.fusion_hidden_size, CFG.target_size).to(DEVICE)
        optimizer = torch.optim.AdamW(model.parameters(), lr=CFG.fusion_lr)
        scheduler = CosineAnnealingLR(optimizer, T_max=CFG.fusion_epochs)
        loss_fn = nn.KLDivLoss(reduction='batchmean')
        checkpoint_path = checkpoint_dir / f'fusion_head_best_fold_{fold}.pth'

        best_val_loss = float('inf')
        best_state = None
        for epoch in range(CFG.fusion_epochs):
            train_loss = _fusion_run_epoch(
                model, train_loader, loss_fn, optimizer, desc=f"Epoch {epoch+1} Training"
            )
            valid_loss = _fusion_run_epoch(model, valid_loader, loss_fn)
            print(f"Epoch {epoch+1}: Train Loss = {train_loss:.4f}, Valid Loss = {valid_loss:.4f}")

            if valid_loss < best_val_loss:
                best_val_loss = valid_loss
                # Keep an in-memory copy so a stale checkpoint from an earlier run can never be reloaded
                best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
                torch.save(best_state, checkpoint_path)

            scheduler.step()

        if best_state is None:
            raise RuntimeError(f"Fold {fold}: no epoch produced a finite validation loss.")

        # OOF predictions with the best epoch's weights
        model.load_state_dict(best_state)
        fold_log_probs, fold_labels = _fusion_predict_log_probs(model, valid_loader)
        all_oof_log_probs.append(fold_log_probs)
        all_oof_labels.append(fold_labels)

    final_log_probs = torch.from_numpy(np.concatenate(all_oof_log_probs))
    final_labels = torch.from_numpy(np.concatenate(all_oof_labels))
    # Score from log-probabilities directly: log(softmax(...)) can underflow to log(0) = -inf
    score = nn.KLDivLoss(reduction='batchmean')(final_log_probs, final_labels).item()

    print(f"\nOverall OOF KL Score: {score:.4f}")
    return score

# %% [markdown]
# ## 6. Run the Final Training

# %%
if __name__ == '__main__':
    overall_oof_score = run_fusion_training(df, cnn_feature_path, tcn_feature_path)

# %%
# Guarded like the training call above: `overall_oof_score` only exists when that ran.
if __name__ == '__main__':
    print(f"Final OOF KL Score from Multimodal Fusion: {overall_oof_score:.4f}")

Using device: cuda
Preparing data and creating folds...
Train shape: (17089, 12)
Folds created. Value counts per fold:
fold
0    3741
1    3703
2    3527
4    3081
3    3037
Name: count, dtype: int64

===== Extracting CNN Features for Fold 0 =====


  model = create_fn(


Extracting Fold 0 Features:   0%|          | 0/268 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast(enabled=(device.type == 'cuda')):


RuntimeError: Input type (torch.cuda.HalfTensor) and weight type (torch.FloatTensor) should be the same