In [1]:
!pip install torch_scatter torcheeg 


Collecting torch_scatter
  Downloading torch_scatter-2.1.2.tar.gz (108 kB)
[2K     [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m108.0/108.0 kB[0m [31m4.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting torcheeg
  Downloading torcheeg-1.1.3.tar.gz (251 kB)
[2K     [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m251.4/251.4 kB[0m [31m13.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting scipy<=1.10.1,>=1.7.3 (from torcheeg)
  Downloading scipy-1.10.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (58 kB)
[2K     [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m58.9/58.9 kB[0m [31m4.1 MB/s[0m eta 

In [2]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset
from torcheeg.datasets import SEEDIVDataset
from torcheeg import transforms
import scipy.signal as signal
import random
import numpy as np

In [3]:
def set_seed(seed=42):
    """Seed Python, NumPy and PyTorch (CPU + all CUDA devices) with `seed`.

    Also switches cuDNN to deterministic mode and disables its autotuner.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


set_seed(42)

In [4]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')


In [5]:
def BandPassFilter(eeg_data, low=1.0, high=75.0, fs=200, order=4):
    """Zero-phase Butterworth band-pass filter along the last axis.

    Args:
        eeg_data: array-like signal; filtering is applied along the last axis.
        low: lower cut-off frequency in Hz (default 1.0).
        high: upper cut-off frequency in Hz (default 75.0).
        fs: sampling rate in Hz of `eeg_data` (default 200).
        order: order passed to `scipy.signal.butter` (default 4).

    Returns:
        numpy.ndarray with the same shape as `eeg_data`.

    With the default arguments the result is identical to the previous
    hard-coded 1-75 Hz, 4th-order filter at 200 Hz.
    """
    b, a = signal.butter(order, Wn=[low, high], btype='bandpass', fs=fs)
    # filtfilt runs the filter forward and backward, so no phase shift is added.
    return signal.filtfilt(b, a, eeg_data, axis=-1)

In [6]:
def Notch(eeg_data, freq=50.0, Q=30.0, fs=200):
    """Zero-phase IIR notch filter along the last axis.

    Args:
        eeg_data: array-like signal; filtering is applied along the last axis.
        freq: frequency in Hz to remove (default 50.0).
        Q: quality factor passed to `scipy.signal.iirnotch` (default 30.0).
        fs: sampling rate in Hz of `eeg_data` (default 200).

    Returns:
        numpy.ndarray with the same shape as `eeg_data`.

    With the default arguments the result is identical to the previous
    hard-coded 50 Hz / Q=30 notch at 200 Hz.
    """
    b, a = signal.iirnotch(w0=freq, Q=Q, fs=fs)
    # Forward-backward filtering keeps the output phase-aligned with the input.
    return signal.filtfilt(b, a, eeg_data, axis=-1)

In [7]:
# Offline preprocessing applied to every EEG chunk, in this order:
# band-pass -> notch -> baseline removal -> mean/std normalisation -> To2d.
# NOTE(review): BaselineRemoval needs a baseline signal to subtract; confirm
# that SEEDIVDataset actually provides one, otherwise this step has no effect.
_preprocessing_steps = [
    transforms.Lambda(BandPassFilter),
    transforms.Lambda(Notch),
    transforms.BaselineRemoval(),
    transforms.MeanStdNormalize(),
    transforms.To2d(),
]
t_transform = transforms.Compose(_preprocessing_steps)

In [8]:
import os
import shutil

# Remove any cache left over from a previous run so the dataset cell below
# re-processes the raw recordings with the current transform settings.
_stale_cache = './tmp_out/seed_iv_augmented'
if os.path.exists(_stale_cache):
    shutil.rmtree(_stale_cache)

In [9]:
# Chunking parameters: `window` samples per segment, consecutive segments
# sharing `overlap_ratio` of their length (800 samples with 30% overlap gives
# a hop of 560 samples and 240 shared samples).
window = 800
overlap_ratio = 0.3
step = int(window * (1 - overlap_ratio))
overlap_samples = window - step

# Each segment's label is the integer stored under the 'emotion' key.
_label_transform = transforms.Compose([
    transforms.Select('emotion'),
])

_dataset_kwargs = dict(
    io_path='./tmp_out/seed_iv_augmented',
    root_path='/kaggle/input/seed-iv/eeg_raw_data',
    offline_transform=t_transform,
    label_transform=_label_transform,
    chunk_size=window,
    overlap=overlap_samples,
    io_mode='memory',
    num_worker=1,
)
dataset = SEEDIVDataset(**_dataset_kwargs)

[2025-12-02 13:39:38] INFO (torcheeg/MainThread) 🔍 | Processing EEG data. Processed EEG data has been cached to ./tmp_out/seed_iv_augmented.
[2025-12-02 13:39:38] INFO (torcheeg/MainThread) ⏳ | Monitoring the detailed processing of a record for debugging. The processing of other records will only be reported in percentage to keep it clean.
[PROCESS]:   0%|          | 0/45 [00:00<?, ?it/s]
[RECORD /kaggle/input/seed-iv/eeg_raw_data/1/4_20151111.mat]: 0it [00:00, ?it/s][A
[RECORD /kaggle/input/seed-iv/eeg_raw_data/1/4_20151111.mat]: 1it [00:03,  3.98s/it][A
[RECORD /kaggle/input/seed-iv/eeg_raw_data/1/4_20151111.mat]: 25it [00:04,  8.55it/s][A
[RECORD /kaggle/input/seed-iv/eeg_raw_data/1/4_20151111.mat]: 49it [00:04, 19.45it/s][A
[RECORD /kaggle/input/seed-iv/eeg_raw_data/1/4_20151111.mat]: 75it [00:04, 34.55it/s][A
[RECORD /kaggle/input/seed-iv/eeg_raw_data/1/4_20151111.mat]: 101it [00:04, 53.05it/s][A
[RECORD /kaggle/input/seed-iv/eeg_raw_data/1/4_20151111.mat]: 128

In [10]:
import pandas as pd

# Per-segment metadata table exposed by the dataset (one row per chunk).
df = dataset.info
total = len(df)

# Segment count and share per emotion label.
# Label ids: 0 = Neutral, 1 = Sad, 2 = Fear, 3 = Happy
counts = df['emotion'].value_counts().sort_index()
percentages = (counts / total) * 100

print(f"Total Segments: {total}")
_report_sections = (
    ("Count per Emotion:", counts),
    ("Percentage per Emotion:", percentages.round(2)),
)
for _heading, _table in _report_sections:
    print("-" * 30)
    print(_heading)
    print(_table)

# Flag imbalance: a spread of more than 10 percentage points between the most
# and least frequent class suggests using a weighted sampler.
max_pct, min_pct = percentages.max(), percentages.min()

if (max_pct - min_pct) > 10:
    print(f"\n‚ö†Ô∏è WARNING: Data is IMBALANCED (Diff: {max_pct - min_pct:.2f}%)")
    print("Consider using a WeightedRandomSampler.")

Total Segments: 53235
------------------------------
Count per Emotion:
emotion
0    14445
1    14460
2    13095
3    11235
Name: count, dtype: int64
------------------------------
Percentage per Emotion:
emotion
0    27.13
1    27.16
2    24.60
3    21.10
Name: count, dtype: float64


In [11]:
from torch.utils.data import Subset
import random

# SEED-IV has 24 trials (videos) per session
all_trial_ids = list(range(1, 25))

# Hold out 5 whole trials so overlapping windows cut from the same trial never
# end up on both sides of the split. The split is by trial id only, so every
# subject/session contributes to both train and test.
random.seed(42)
test_trial_ids = random.sample(all_trial_ids, 5)  # 5 trials for test
train_trial_ids = [t for t in all_trial_ids if t not in test_trial_ids]  # 19 trials for train

# Subset indexes the dataset by POSITION, so collect row positions rather than
# DataFrame index labels (the two only coincide for a default RangeIndex).
_in_train = df['trial_id'].isin(train_trial_ids).to_numpy()
_in_test = df['trial_id'].isin(test_trial_ids).to_numpy()
train_indices = np.flatnonzero(_in_train).tolist()
test_indices = np.flatnonzero(_in_test).tolist()

# Fail early with a clear message instead of a ZeroDivisionError during training.
assert train_indices and test_indices, "trial-id split produced an empty subset"
assert not (_in_train & _in_test).any(), "train and test subsets overlap"

# Create Subsets
train_set = Subset(dataset, train_indices)
test_set = Subset(dataset, test_indices)

print(f"Train samples: {len(train_set)}, Test samples: {len(test_set)}")

Train samples: 41700, Test samples: 11535


In [12]:
from torch.utils.data import DataLoader, WeightedRandomSampler
import torch

# Labels of the training segments, looked up positionally in one column
# access instead of materialising a full row per index.
train_labels = df['emotion'].iloc[train_indices].tolist()

# Inverse-frequency weights: every class gets the same total sampling mass.
class_counts = pd.Series(train_labels).value_counts().to_dict()
weights = [1.0 / class_counts[label] for label in train_labels]

# Draw len(train_set) samples per epoch with replacement, so minority-class
# segments are repeated and majority-class segments may be skipped.
sampler = WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)

train_loader = DataLoader(train_set, batch_size=64, sampler=sampler)
test_loader = DataLoader(test_set, batch_size=64, shuffle=False)

In [13]:
import torch.nn as nn
from torchvision.ops import SqueezeExcitation

# Read the input geometry from one sample: dim 1 is used as the electrode
# count (conv kernel height) and dim 2 as the number of time samples.
input_tensor = dataset[0][0]
n_channels = input_tensor.shape[1]
samples = input_tensor.shape[2]

n_classes = 4
print(f"Electrodes (Height): {n_channels}, Time (Width): {samples}")


def _conv_bn_act(in_ch, out_ch, kernel_size, slope=0.1, **conv_kwargs):
    """Return [Conv2d (no bias), BatchNorm2d, LeakyReLU(slope)] as a list."""
    return [
        nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size, bias=False, **conv_kwargs),
        nn.BatchNorm2d(out_ch),
        nn.LeakyReLU(slope),
    ]


# Layers are created in the same order as before so the Sequential indices
# (and therefore state_dict keys) and the weight-initialisation order are unchanged.
_layers = [
    # 1: temporal convolution, then a convolution spanning all electrodes
    *_conv_bn_act(1, 8, (1, 100), slope=0.2, padding=(0, 32), stride=(1, 4)),
    *_conv_bn_act(8, 16, (n_channels, 1), stride=(1, 2)),
    # 2
    *_conv_bn_act(16, 32, (1, 32), padding=(0, 4)),
    nn.AvgPool2d((1, 4)),
    nn.Dropout(0.3),
    # 3
    *_conv_bn_act(32, 64, (1, 16), padding=(0, 2)),
    nn.Dropout(0.5),
    # 4
    *_conv_bn_act(64, 128, (1, 4), padding=(0, 1)),
    # classifier head
    nn.AdaptiveAvgPool2d((1, 1)),
    nn.Flatten(),
    nn.Linear(128, 64),
    nn.LeakyReLU(0.1),
    nn.Dropout(0.3),
    nn.Linear(64, n_classes),
]

model = nn.Sequential(*_layers).to(device)

Electrodes (Height): 62, Time (Width): 800


In [14]:
import torch.optim as optim

# Training budget and early-stopping tolerance (in epochs).
num_epochs = 40
patience = 5

criterion = nn.CrossEntropyLoss()

# AdamW with decoupled weight decay over all model parameters.
optimizer = optim.AdamW(
    model.parameters(),
    lr=3e-4,
    weight_decay=1e-4,
)

# Cosine decay of the learning rate over the full epoch budget, down to 1e-6.
scheduler = optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=num_epochs,
    eta_min=1e-6,
)

In [15]:
def _run_epoch(loader, train):
    """Run one full pass over `loader` and return (mean loss, accuracy in %).

    When `train` is True the model is put in training mode and the optimizer
    is stepped after every batch; otherwise the model is put in eval mode and
    gradients are disabled.
    """
    model.train(train)
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    with torch.set_grad_enabled(train):
        for X, y in loader:
            X, y = X.to(device).float(), y.to(device)
            if train:
                optimizer.zero_grad()
            outputs = model(X)
            loss = criterion(outputs, y)
            if train:
                loss.backward()
                optimizer.step()

            # criterion returns the batch mean, so weight it by batch size
            # to get a correct per-sample average over the whole epoch.
            loss_sum += loss.item() * X.size(0)
            _, predicted = outputs.max(1)
            n_seen += y.size(0)
            n_correct += predicted.eq(y).sum().item()
    return loss_sum / n_seen, 100. * n_correct / n_seen


best_acc = 0.0
# Epochs since the last improvement of `best_acc`. It must be initialised
# under the same name the loop updates, otherwise a first epoch that does not
# improve raises NameError on a fresh kernel.
epochs_no_improve = 0

for epoch in range(num_epochs):
    train_loss, train_acc = _run_epoch(train_loader, train=True)

    # Validation / Test
    # NOTE(review): checkpoint selection and early stopping both use the test
    # loader, so the reported best accuracy is optimistically biased; a
    # separate validation split would be needed for an unbiased estimate.
    test_loss, test_acc = _run_epoch(test_loader, train=False)

    scheduler.step()

    print(f"Epoch [{epoch+1}/{num_epochs}] Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}% | Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%")
    if test_acc > best_acc:
        best_acc = test_acc
        torch.save(model.state_dict(), 'best_model.pth')
        print(f"Saved New Best Model! (New Record: {best_acc:.2f}%)")
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1

    if epochs_no_improve >= patience:
        print(f"Early stopping triggered! No improvement in {patience} epochs.")
        break
    



Epoch [1/40] Train Loss: 1.2760, Train Acc: 39.73% | Test Loss: 1.6339, Test Acc: 27.12%
Saved New Best Model! (New Record: 27.12%)
Epoch [2/40] Train Loss: 1.1221, Train Acc: 51.28% | Test Loss: 1.3374, Test Acc: 41.46%
Saved New Best Model! (New Record: 41.46%)
Epoch [3/40] Train Loss: 1.0509, Train Acc: 55.72% | Test Loss: 1.4243, Test Acc: 41.10%
Epoch [4/40] Train Loss: 1.0027, Train Acc: 58.33% | Test Loss: 1.4385, Test Acc: 44.49%
Saved New Best Model! (New Record: 44.49%)
Epoch [5/40] Train Loss: 0.9750, Train Acc: 59.73% | Test Loss: 1.3349, Test Acc: 44.03%
Epoch [6/40] Train Loss: 0.9457, Train Acc: 61.33% | Test Loss: 1.4912, Test Acc: 38.96%
Epoch [7/40] Train Loss: 0.9307, Train Acc: 61.91% | Test Loss: 1.3782, Test Acc: 45.63%
Saved New Best Model! (New Record: 45.63%)
Epoch [8/40] Train Loss: 0.9075, Train Acc: 62.99% | Test Loss: 1.4353, Test Acc: 42.72%
Epoch [9/40] Train Loss: 0.8882, Train Acc: 63.92% | Test Loss: 1.4706, Test Acc: 43.35%
Epoch [10/40] Train Loss: 0