In [1]:
import h5py
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, Subset
from sklearn.model_selection import KFold
from sklearn.preprocessing import StandardScaler
import torch.nn as nn 
import torch.optim as optim 
import os

# Step 1: Load and process the MEG data
file_path = r"D:\BTP\sub-1_ses-1_task-bcimici_meg.mat"

with h5py.File(file_path, 'r') as f:
    data = f['dataMAT']
    trial_refs = data['trial']
    trial_info_array = data['trialinfo'][:]  # Directly load trialinfo
    
    trials_list = []
    
    for i in range(200):
        trial_ref = trial_refs[i,0]
        trial_data = f[trial_ref][2001:7001]  # (channels, timepoints)
        trials_list.append(trial_data)

# Convert trials list to array
trials_array = np.array(trials_list)
print(trial_info_array.shape)  # Should now show (200, 1)

# Save to .npy files
np.save(r"meg_trials.npy", trials_array)
np.save(r"meg_trial_info.npy", trial_info_array)

# Step 2: Create a custom dataset class
class EEGDataset(Dataset):
    def __init__(self, trials, labels):
        # Trials: (num_trials, channels, timepoints)
        # Labels: (num_trials,)
        
        # Normalize data per channel
        self.data = torch.tensor(trials, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.long)
        
        # Map original labels to 0-3
        self.label_mapping = {1: 0, 2: 1, 3: 2, 4: 3}
        self.labels = torch.tensor([self.label_mapping[int(x)] for x in labels])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

# Step 3: Load the processed data
trials = np.load('meg_trials.npy')  # Shape (200, 5000, 306)
labels = np.load('meg_trial_info.npy')
print(f"Trials shape: {trials.shape}")
print(f"Labels original shape: {labels.shape}")
labels = labels.reshape(-1)  # Flatten labels for stratification
print(f"Labels flattened shape: {labels.shape}")

# Step 4: Define the LTC network model
class LTC_Cell(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(LTC_Cell, self).__init__()
        self.hidden_dim = hidden_dim
        self.W_xh = nn.Linear(input_dim, hidden_dim)
        self.W_hh = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.W_tau = nn.Linear(hidden_dim, hidden_dim)
        
    def forward(self, x, h):
        tau = torch.sigmoid(self.W_tau(h)) + 0.1
        dh = -h / tau + torch.tanh(self.W_xh(x) + self.W_hh(h))
        return h + 0.1 * dh

class LTC_Transformer(nn.Module):
    def __init__(self, input_dim=22, ltc_hidden_dim=64, num_classes=4, 
                 num_ltc_layers=1, num_transformer_layers=1, nhead=2, dropout=0.5):
        super().__init__()
        self.input_dim = input_dim
        self.ltc_hidden_dim = ltc_hidden_dim
        self.num_ltc_layers = num_ltc_layers
        
        # LTC layers
        self.ltc_layers = nn.ModuleList([
            LTC_Cell(input_dim if i==0 else ltc_hidden_dim, ltc_hidden_dim)
            for i in range(num_ltc_layers)
        ])
        
        # Transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=ltc_hidden_dim, 
            nhead=nhead, 
            dim_feedforward=ltc_hidden_dim*2,
            dropout=dropout,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(
            encoder_layer, 
            num_layers=num_transformer_layers
        )
        
        # Classifier head
        self.classifier = nn.Sequential(
            nn.Linear(ltc_hidden_dim, 32),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(32, num_classes)
        )
    
    def forward(self, x):
        # x shape: [batch_size, seq_len, input_dim]
        batch_size, seq_len, _ = x.size()
        
        # Initialize hidden states
        hiddens = [torch.zeros(batch_size, self.ltc_hidden_dim, device=x.device) 
                  for _ in range(self.num_ltc_layers)]
        
        # Process sequence through LTC layers
        all_hidden = []
        for t in range(seq_len):
            x_t = x[:, t, :]
            for layer_idx in range(self.num_ltc_layers):
                hiddens[layer_idx] = self.ltc_layers[layer_idx](
                    x_t if layer_idx == 0 else hiddens[layer_idx-1],
                    hiddens[layer_idx]
                )
            all_hidden.append(hiddens[-1])
        
        # Stack hidden states for transformer input
        hidden_stack = torch.stack(all_hidden, dim=1)  # [batch_size, seq_len, ltc_hidden_dim]
        
        # Pass through transformer
        transformer_out = self.transformer(hidden_stack)
        
        # Global average pooling
        pooled = transformer_out.mean(dim=1)
        
        # Classification
        return self.classifier(pooled)

# Step 5: Prepare for training with 5-fold cross-validation
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Prepare the full dataset
full_dataset = EEGDataset(trials, labels)

# KFold cross-validation
kf = KFold(n_splits=5, shuffle=True, random_state=42)

# Setup
num_epochs = 100
batch_size = 32
criterion = nn.CrossEntropyLoss()

# Store results across folds
fold_results = []

# Step 6: Implement 5-fold cross-validation training
for fold, (train_idx, val_idx) in enumerate(kf.split(np.arange(len(full_dataset)))):
    print(f'Fold {fold+1}/5')
    print(f'Train size: {len(train_idx)}, Validation size: {len(val_idx)}')
    
    # Create data subsets for this fold
    train_subset = Subset(full_dataset, train_idx)
    val_subset = Subset(full_dataset, val_idx)
    
    # Create dataloaders
    train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_subset, batch_size=batch_size)
    
    # Initialize a fresh model for this fold
    model = LTC_Transformer(
    input_dim=306,               # Number of EEG channels
    ltc_hidden_dim=64,          # LTC hidden dimension (reduced)
    num_classes=4,              # Number of classes to predict
    num_ltc_layers=1,           # Single LTC layer for simplicity
    num_transformer_layers=1,   # Single transformer layer
    nhead=2,                    # Fewer attention heads
    dropout=0.5                 # Dropout for regularization
    ).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5)
    
    # Training variables
    best_val_acc = 0
    previous_val_acc = None
    same_acc_streak = 0
    val_acc_list = []
    
    # Training loop for this fold
    for epoch in range(num_epochs):
        model.train()
        train_loss = 0
        
        for data, labels in train_loader:
            data, labels = data.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(data)
            loss = criterion(outputs, labels)
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            train_loss += loss.item()
        
        train_loss /= len(train_loader)
        
        # Validation
        model.eval()
        val_loss, correct, total = 0, 0, 0
        
        with torch.no_grad():
            for data, labels in val_loader:
                data, labels = data.to(device), labels.to(device)
                outputs = model(data)
                val_loss += criterion(outputs, labels).item()
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        
        val_acc = 100 * correct / total
        val_loss /= len(val_loader)
        scheduler.step(val_loss)
        val_acc_list.append(val_acc)
        
        # Early stopping logic
        if previous_val_acc is None or val_acc != previous_val_acc:
            same_acc_streak = 0
        else:
            same_acc_streak += 1
        
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            # Save model for this fold
            torch.save(model.state_dict(), f'meg_ltc_TRAN_model_fold{fold+1}.pth')
        
        print(f'Epoch {epoch+1}/{num_epochs} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}% | Same accuracy streak: {same_acc_streak}/10')
        
        if same_acc_streak >= 10:
            print(f'Early stopping triggered after {epoch+1} epochs!')
            break
        
        previous_val_acc = val_acc
    
    # Calculate and store the average of the last 5 validation accuracies for this fold
    avg_val_acc = sum(val_acc_list[-5:]) / min(5, len(val_acc_list))
    fold_results.append(best_val_acc)
    print(f'Fold {fold+1} completed with best validation accuracy: {best_val_acc:.2f}%')
    print(f'Average of last 5 epochs: {avg_val_acc:.2f}%')

# Step 7: Report overall cross-validation results
avg_acc = sum(fold_results) / len(fold_results)
std_acc = np.std(fold_results)
print("\n===== Cross-Validation Results =====")
print(f'Individual fold accuracies: {[f"{acc:.2f}%" for acc in fold_results]}')
print(f'Average validation accuracy across 5 folds: {avg_acc:.2f}% ± {std_acc:.2f}%')
print(f'Best fold accuracy: {max(fold_results):.2f}%')


(1, 200)
Trials shape: (200, 4999, 306)
Labels original shape: (1, 200)
Labels flattened shape: (200,)
Using device: cuda
Fold 1/5
Train size: 160, Validation size: 40
Epoch 1/100 | Train Loss: 1.4263 | Val Loss: 1.5151 | Val Acc: 10.00% | Same accuracy streak: 0/10
Epoch 2/100 | Train Loss: 1.4180 | Val Loss: 1.4242 | Val Acc: 10.00% | Same accuracy streak: 1/10
Epoch 3/100 | Train Loss: 1.3965 | Val Loss: 1.3991 | Val Acc: 32.50% | Same accuracy streak: 0/10
Epoch 4/100 | Train Loss: 1.3953 | Val Loss: 1.4086 | Val Acc: 32.50% | Same accuracy streak: 1/10
Epoch 5/100 | Train Loss: 1.3843 | Val Loss: 1.4220 | Val Acc: 10.00% | Same accuracy streak: 0/10
Epoch 6/100 | Train Loss: 1.3823 | Val Loss: 1.4374 | Val Acc: 10.00% | Same accuracy streak: 1/10
Epoch 7/100 | Train Loss: 1.3946 | Val Loss: 1.4091 | Val Acc: 10.00% | Same accuracy streak: 2/10
Epoch 8/100 | Train Loss: 1.3881 | Val Loss: 1.3938 | Val Acc: 27.50% | Same accuracy streak: 0/10
Epoch 9/100 | Train Loss: 1.3958 | Val L