In [1]:
import os
import time
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import math
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from torch import einsum
from tqdm import tqdm
from einops import rearrange, repeat
from transformers import get_linear_schedule_with_warmup
from sklearn.metrics import confusion_matrix, classification_report
try:
    from torchinfo import summary
except:
    !pip install -q torchinfo
    from torchinfo import summary

In [2]:
class PatchEmbedding3D(nn.Module):
    
    def __init__(self, in_channels:int, embedding_dim:int, patch_size:int):
        super().__init__()
        self.patch_size = patch_size
        self.proj = nn.Conv3d(
            in_channels, embedding_dim,
            kernel_size=(1, patch_size, patch_size),
            stride=(1, patch_size, patch_size)
        )
    
    def forward(self, x):
        x = x.permute(0, 2, 1, 3, 4).contiguous()   
        
        x = self.proj(x)

        B, D, T, H, W = x.shape
        x = rearrange(x, 'b d t h w -> b (t h w) d', b = B, d = D, t = T, h = H, w = W)
        
        return x              

In [3]:
class LayerNormalization(nn.Module):

    def __init__(self, embedding_dim: int, eps:float=10**-6) -> None:
        
        super().__init__()
        self.eps = eps
        self.alpha = nn.Parameter(torch.ones(embedding_dim)) 
        self.bias = nn.Parameter(torch.zeros(embedding_dim)) 

    def forward(self, x):

        mean = x.mean(dim = -1, keepdim = True)
        std = x.std(dim = -1, keepdim = True) 
        
        return self.alpha * (x - mean) / (std + self.eps) + self.bias

In [4]:
class FeedForwardBlock(nn.Module):
    
    def __init__(self, embedding_dim: int, mlp_size: int, dropout: float):
        
        super().__init__()
        self.linear_1 = nn.Linear(embedding_dim, mlp_size) 
        self.dropout = nn.Dropout(dropout)
        self.linear_2 = nn.Linear(mlp_size, embedding_dim) 

    def forward(self, x):
        return self.linear_2(self.dropout(F.gelu(self.linear_1(x))))

In [5]:
class Attention(nn.Module):

    def __init__(self, embedding_dim: int, num_heads: int):

        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.d_k = embedding_dim // num_heads
        assert embedding_dim % num_heads == 0, "d_model must be divisible by h"

        self.qkv = nn.Linear(embedding_dim, embedding_dim * 3, bias=False)
        self.w_o = nn.Linear(embedding_dim, embedding_dim, bias=False)

    def forward(self, x):
        
        B, N, D = x.shape
        
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, D // self.num_heads).permute(2, 0, 3, 1, 4)
        query, key, value = qkv[0], qkv[1], qkv[2]
        
        attention = (query @ key.transpose(-2, -1)) / math.sqrt(self.d_k)
        attention = F.softmax(attention, dim=-1)

        x = (attention @ value).transpose(1, 2).reshape(B, N, D)
        
        return self.w_o(x)        

In [6]:
class EncoderBlock(nn.Module):

    def __init__(self, embedding_dim:int, num_heads: int, mlp_size: int, batch_size: int, num_frames: int, num_patches: int, dropout: float):
        
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.mlp_size = mlp_size
        self.batch_size = batch_size
        self.num_frames = num_frames
        self.num_patches = num_patches
        self.dropout = nn.Dropout(dropout)

        self.norm1 = LayerNormalization(embedding_dim)
        self.spatial_attention = Attention(embedding_dim, num_heads)
        self.temporal_attention = Attention(embedding_dim, num_heads)
        self.dropout_1 = nn.Dropout(dropout)
        
        self.norm2 = LayerNormalization(embedding_dim)  
        self.ffn = FeedForwardBlock(embedding_dim, mlp_size, dropout)
        self.dropout_2 = nn.Dropout(dropout)
        

    def forward(self, x): 

        B = self.batch_size
        T = self.num_frames
        S = self.num_patches
        
        residual = x
        x = self.norm1(x)
        
        X = x[:,1:,:]
        
        init_cls_token = x[:,0,:].unsqueeze(1)
        
        # Spatial Attention
        t_cls_token = init_cls_token.repeat(1, T, 1)
        t_cls_token = rearrange(t_cls_token, 'b t m -> (b t) m', b = B, t = T).unsqueeze(1)

        xs = rearrange(X, 'b (s t) d -> (b t) s d', b = B, t = T, s = S)
        xs = torch.cat((t_cls_token, xs), 1)

        spatial_attention = self.spatial_attention(xs)

        t_cls_token = spatial_attention[:,0,:]
        t_cls_token = rearrange(t_cls_token, '(b t) m -> b t m', b = B, t = T)
        t_cls_token = torch.mean(t_cls_token, 1, True)

        spatial_attention = spatial_attention[:,1:,:]
        spatial_attention = rearrange(spatial_attention, '(b t) s d -> b (t s) d', b = B, t = T, s = S)
        spatial_attention = torch.cat((t_cls_token, spatial_attention), 1)
        
        # Temporal Attention
        s_cls_token = init_cls_token.repeat(1, S, 1)
        s_cls_token = rearrange(s_cls_token, 'b s m -> (b s) m', b = B, s = S).unsqueeze(1)

        xt = rearrange(X, 'b (s t) d -> (b s) t d', b = B, s = S, t = T)
        xt = torch.cat((s_cls_token, xt), 1)

        temporal_attention = self.temporal_attention(xt)
        
        s_cls_token = temporal_attention[:,0,:]
        s_cls_token = rearrange(s_cls_token, '(b s) m -> b s m', b = B, s = S)
        s_cls_token = torch.mean(s_cls_token, 1, True)

        temporal_attention = temporal_attention[:,1:,:]
        temporal_attention = rearrange(temporal_attention, '(b s) t d -> b (t s) d', b = B, t = T, s = S)
        temporal_attention = torch.cat((s_cls_token, temporal_attention),1)
        

        combined = spatial_attention + temporal_attention
        attn_out = combined + residual
        attn_out = self.dropout_1(attn_out)
        
        out = attn_out + self.dropout_2(self.ffn(self.norm2(attn_out)))
        
        return out

In [7]:
class Encoder(nn.Module):
    
    def __init__(self, embedding_dim: int, num_layers: int, num_heads: int, mlp_size: int, batch_size: int, num_frames: int, num_patches: int, dropout: float = 0.1):

        super().__init__()
        self.layers = nn.ModuleList([
            EncoderBlock(
                embedding_dim = embedding_dim,
                num_heads = num_heads,
                mlp_size = mlp_size,
                batch_size = batch_size,
                num_frames = num_frames,
                num_patches = num_patches,
                dropout = dropout
            ) for _ in range(num_layers)
        ])
        self.norm = LayerNormalization(embedding_dim) 

    def forward(self, x):

        for layer in self.layers:
            x = layer(x)

        return self.norm(x)

In [9]:
class SpatioTemporalTransformer(nn.Module):

    def __init__(self, 
                 feature_size:int = 64,
                 num_frames:int = 50,
                 batch_size:int = 16,
                 in_channels:int = 1,
                 patch_size:int = 16,
                 num_layers:int = 12,
                 embedding_dim:int = 768,
                 mlp_size:int = 3072,
                 num_heads:int = 12,
                 dropout:float = 0.1,
                 embed_dropout:float = 0.1,
                 num_classes:int = 6):    
        super().__init__()

        assert feature_size % patch_size == 0, f'Image size {feature_size} is not divisible by the patch size {patch_size}'

        self.num_patches = (feature_size * feature_size) // patch_size**2
        self.batch_size = batch_size
        self.num_frames = num_frames
        
        self.class_embedding = nn.Parameter(data = torch.randn(1, 1, embedding_dim), requires_grad = True)
        
        self.patch_embedding = PatchEmbedding3D(
            in_channels = in_channels, 
            embedding_dim = embedding_dim, 
            patch_size = patch_size
        )
        
        self.positional_embedding = nn.Parameter(data = torch.randn(1, (self.num_patches * self.num_frames) + 1, embedding_dim), requires_grad = True)
        
        self.embedding_dropout = nn.Dropout(embed_dropout)
        
        self.transformer_encoder = Encoder(
            embedding_dim = embedding_dim,
            num_layers = num_layers,
            num_heads = num_heads,
            mlp_size = mlp_size,
            batch_size = batch_size,
            num_frames = num_frames,
            num_patches = self.num_patches,
            dropout = dropout
        )
        
        self.classifier = nn.Sequential(
            nn.Linear(in_features = embedding_dim, out_features = num_classes)
        )


    def forward(self, x):
        
        x = self.patch_embedding(x)
        
        class_token = self.class_embedding.expand(self.batch_size, -1, -1)
        
        x = torch.cat((class_token, x), dim=1)
        
        x = self.positional_embedding + x
        
        x = self.embedding_dropout(x)
        
        x = self.transformer_encoder(x)

        x = self.classifier(x[:,0,:])

        return x

In [8]:
class EEGDataset(Dataset):
    def __init__(self, npz_files):
        self.npz_files = npz_files

    def __len__(self):
        return len(self.npz_files)

    def __getitem__(self, idx):
        data = np.load(self.npz_files[idx])
        features = data['features']         
        label = data['label']      
        features = torch.tensor(features, dtype=torch.float32)
        label = torch.tensor(label, dtype=torch.long).squeeze()
        return features, label

In [9]:
train_dir = "/kaggle/input/eeg-spatio-temporal-feature-map-dataset/Train/"
val_dir = "/kaggle/input/eeg-spatio-temporal-feature-map-dataset/Validation/"
test_dir = "/kaggle/input/eeg-spatio-temporal-feature-map-dataset/Test/"

train_files = [os.path.join(train_dir, f) for f in sorted(os.listdir(train_dir)) if f.endswith(".npz")]
test_files = [os.path.join(test_dir, f) for f in sorted(os.listdir(test_dir)) if f.endswith(".npz")]
val_files = [os.path.join(val_dir, f) for f in sorted(os.listdir(val_dir)) if f.endswith(".npz")]

train_dataset = EEGDataset(train_files)
test_dataset = EEGDataset(test_files)
val_dataset = EEGDataset(val_files)

batch_size = 16
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size)
val_loader = DataLoader(val_dataset, batch_size=batch_size)

In [24]:
class Trainer:
    
    def __init__(self, model, train_loader, test_loader, device, config):
        
        self.model = model.to(device)
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.device = device
        self.config = config
        
        self.criterion = nn.CrossEntropyLoss()
        
        self.optimizer = optim.AdamW(
            model.parameters(),
            lr = config['learning_rate'],
            betas = (0.9, 0.999),
            weight_decay = 1e-2
        )
        
        total_steps = config['num_epochs'] * len(train_loader)
        warmup_steps = int(config.get('warmup_ratio', 0.1) * total_steps)
        
        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps = warmup_steps,
            num_training_steps = total_steps
        )
        
        self.train_loss_history = []
        self.test_loss_history = []
        self.train_acc_history = []
        self.test_acc_history = []
    
        self.save_dir = config['save_dir']
        os.makedirs(self.save_dir, exist_ok=True)

    
    def train_epoch(self):
        
        self.model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        
        for batch_idx, (data, targets) in enumerate(tqdm(self.train_loader, desc = "Train")):
            data = data.to(self.device)
            targets = targets.to(self.device)
            
            outputs = self.model(data)
            loss = self.criterion(outputs, targets)
            
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            self.scheduler.step()
            
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
        
        epoch_loss = running_loss / len(self.train_loader)
        epoch_acc = 100. * correct / total
        
        self.train_loss_history.append(epoch_loss)
        self.train_acc_history.append(epoch_acc)
        
        return epoch_loss, epoch_acc

    
    def test_epoch(self):
        
        self.model.eval()
        running_loss = 0.0
        correct = 0
        total = 0
        
        with torch.no_grad():
            for batch_idx, (data, targets) in enumerate(tqdm(self.test_loader, desc = "Test")):
                data = data.to(self.device)
                targets = targets.to(self.device)
                
                outputs = self.model(data)
                loss = self.criterion(outputs, targets)
                
                running_loss += loss.item()
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()
        
        epoch_loss = running_loss / len(self.test_loader)
        epoch_acc = 100. * correct / total
        
        self.test_loss_history.append(epoch_loss)
        self.test_acc_history.append(epoch_acc)
        
        return epoch_loss, epoch_acc

    
    def save_checkpoint(self, epoch):
        
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'train_loss_history': self.train_loss_history,
            'test_loss_history': self.test_loss_history,
            'train_acc_history': self.train_acc_history,
            'test_acc_history': self.test_acc_history,
            'config': self.config
        }
        
        torch.save(checkpoint, os.path.join(self.save_dir, f'checkpoint_epoch_{epoch}.pth'))
            
        metrics = {
            'train_loss': self.train_loss_history,
            'test_loss': self.test_loss_history,
            'train_acc': self.train_acc_history,
            'test_acc': self.test_acc_history
        }
        with open(os.path.join(self.save_dir, 'training_metrics.json'), 'w') as f:
            json.dump(metrics, f)

    
    def train(self, num_epochs):
        
        start_time = time.time()
        
        for epoch in range(4, num_epochs + 1):
            print(f"\nEpoch {epoch}/{num_epochs}")
            
            train_loss, train_acc = self.train_epoch()
            test_loss, test_acc = self.test_epoch()
            
            print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
            print(f"Test Loss: {test_loss:.4f} | Test Acc: {test_acc:.2f}%")

            self.save_checkpoint(epoch)
            
        total_time = time.time() - start_time
        print(f"\nTraining completed in {total_time//60:.0f}m {total_time%60:.0f}s")   

In [22]:
config = {
    'feature_size': 64,
    'num_frames': 50,
    'batch_size': 16,
    'in_channels': 1,
    'patch_size': 16,
    'num_layers': 6, # 12
    'embedding_dim': 768,
    'mlp_size': 3072,
    'num_heads': 12,
    'dropout': 0.1,
    'embed_dropout': 0.1,
    'num_classes': 6,
    'learning_rate': 1e-3,
    'save_dir': '/kaggle/working/checkpoints',
    'num_epochs': 7,
    'warmup_ratio': 0.1
}
    
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = SpatioTemporalTransformer(
    feature_size = config['feature_size'],
    num_frames = config['num_frames'],
    batch_size = config['batch_size'],
    in_channels = config['in_channels'],
    patch_size = config['patch_size'],
    num_layers = config['num_layers'],
    embedding_dim = config['embedding_dim'],
    mlp_size = config['mlp_size'],
    num_heads = config['num_heads'],
    dropout = config['dropout'],
    embed_dropout = config['embed_dropout'],
    num_classes = config['num_classes']
)
    
trainer = Trainer(
    model = model,
    train_loader = train_loader,
    test_loader = test_loader,
    device = device,
    config = config
)

In [26]:
def load_checkpoint(path, model, optimizer = None, device = 'cpu'):
    checkpoint = torch.load(path, map_location = device)
    model.load_state_dict(checkpoint['model_state_dict'])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['epoch'], checkpoint['train_loss_history'], checkpoint['test_loss_history'], checkpoint['train_acc_history'], checkpoint['test_acc_history']


# model = SpatioTemporalTransformer(...)
# optimizer = optim.Adam(model.parameters(), lr = config['learning_rate'], betas = (0.9, 0.999), weight_decay = 0.1)
optimizer = optim.AdamW(
            model.parameters(),
            lr = config['learning_rate'],
            betas = (0.9, 0.999),
            weight_decay = 1e-2
        )

epoch, train_loss, test_loss, train_acc, test_acc = load_checkpoint('/kaggle/input/checkpoint3/checkpoint_epoch_3.pth', model, optimizer, device)

  checkpoint = torch.load(path, map_location = device)


In [27]:
device

device(type='cuda')

In [28]:
trainer.train(config['num_epochs'])


Epoch 4/7


Train: 100%|██████████| 1800/1800 [30:38<00:00,  1.02s/it]
Test: 100%|██████████| 450/450 [02:38<00:00,  2.83it/s]


Train Loss: 1.7525 | Train Acc: 22.66%
Test Loss: 1.8130 | Test Acc: 14.64%

Epoch 5/7


Train: 100%|██████████| 1800/1800 [30:18<00:00,  1.01s/it]
Test: 100%|██████████| 450/450 [03:47<00:00,  1.98it/s]


Train Loss: 1.7935 | Train Acc: 17.92%
Test Loss: 1.7614 | Test Acc: 19.43%

Epoch 6/7


Train: 100%|██████████| 1800/1800 [31:25<00:00,  1.05s/it]
Test: 100%|██████████| 450/450 [03:05<00:00,  2.42it/s]


Train Loss: 1.7455 | Train Acc: 23.24%
Test Loss: 1.7236 | Test Acc: 25.35%

Epoch 7/7


Train: 100%|██████████| 1800/1800 [28:48<00:00,  1.04it/s]
Test: 100%|██████████| 450/450 [02:29<00:00,  3.01it/s]


Train Loss: 1.7484 | Train Acc: 22.83%
Test Loss: 1.7325 | Test Acc: 24.11%

Training completed in 133m 16s


In [17]:
!zip -r checkpoints_123.zip checkpoints

  adding: checkpoints/ (stored 0%)
  adding: checkpoints/checkpoint_epoch_3.pth (deflated 8%)
  adding: checkpoints/checkpoint_epoch_2.pth (deflated 8%)
  adding: checkpoints/checkpoint_epoch_1.pth (deflated 8%)
  adding: checkpoints/training_metrics.json (deflated 45%)


In [18]:
from IPython.display import FileLink 
FileLink(r'checkpoints_123.zip')

In [19]:
import os
import shutil

folder_path = '/kaggle/working/checkpoints'

if os.path.exists(folder_path):
    if os.path.isdir(folder_path):
        shutil.rmtree(folder_path)  # Deletes the folder and all its contents
        print(f"{folder_path} has been deleted.")
    else:
        print(f"{folder_path} is not a directory.")
else:
    print(f"{folder_path} does not exist.")

/kaggle/working/checkpoints has been deleted.


In [20]:
import os

file_path = '/kaggle/working/checkpoints_123.zip'

if os.path.exists(file_path):
    os.remove(file_path)
    print(f"{file_path} has been deleted.")
else:
    print(f"{file_path} does not exist.") 

/kaggle/working/checkpoints_123.zip has been deleted.


In [None]:
def load_checkpoint(path, model, optimizer = None, device = 'cpu'):
    checkpoint = torch.load(path, map_location = device)
    model.load_state_dict(checkpoint['model_state_dict'])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['epoch'], checkpoint['train_loss_history'], checkpoint['test_loss_history'], checkpoint['train_acc_history'], checkpoint['test_acc_history']


# model = SpatioTemporalTransformer(...)
# optimizer = optim.Adam(model.parameters(), lr = config['learning_rate'], betas = (0.9, 0.999), weight_decay = 0.1)

epoch, train_loss, test_loss, train_acc, test_acc = load_checkpoint('/kaggle/working/checkpoints/.. .pth', model, optimizer, device)

In [None]:
def validate_model(model, val_loader, device, num_classes = 6):
    
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    
    all_preds = []
    all_labels = []
    
    criterion = nn.CrossEntropyLoss()
    
    with torch.no_grad():
        for batch_idx, (data, targets) in enumerate(tqdm(val_loader, desc="Validation")):
            
            data = data.to(device)
            targets = targets.to(device)
            outputs = model(data)
            loss = criterion(outputs, targets)
            
            val_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(targets.cpu().numpy())
    
    val_loss /= len(val_loader)
    val_acc = 100. * correct / total
    
    print(f"\nValidation Results:")
    print(f"Loss: {val_loss:.4f} | Accuracy: {val_acc:.2f}%")
    
    print("\nClassification Report:")
    print(classification_report(all_labels, all_preds, target_names=[f'Class {i}' for i in range(num_classes)]))
    
    cm = confusion_matrix(all_labels, all_preds)
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
                xticklabels=[f'Class {i}' for i in range(num_classes)],
                yticklabels=[f'Class {i}' for i in range(num_classes)])
    plt.title('Confusion Matrix')
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.show()
    
    return val_loss, val_acc

# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# model = SpatioTemporalTransformer(...)

val_loss, val_acc = validate_model(model, val_loader, device)