In [1]:
!pip install torch_scatter torcheeg 


Collecting torch_scatter
  Downloading torch_scatter-2.1.2.tar.gz (108 kB)
[2K     [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m108.0/108.0 kB[0m [31m3.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting torcheeg
  Downloading torcheeg-1.1.3.tar.gz (251 kB)
[2K     [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m251.4/251.4 kB[0m [31m9.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting scipy<=1.10.1,>=1.7.3 (from torcheeg)
  Downloading scipy-1.10.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (58 kB)
[2K     [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m58.9/58.9 kB[0m [31m4.4 MB/s[0m eta

In [2]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset
from torcheeg.datasets import SEEDIVDataset
from torcheeg import transforms
import scipy.signal as signal
import random
import numpy as np

In [3]:
def set_seed(seed=42):
    """Seed every RNG this notebook relies on and pin the cuDNN flags.

    Seeds Python's ``random``, NumPy's legacy global RNG, and PyTorch
    (``manual_seed`` plus ``cuda.manual_seed_all``), then sets
    ``cudnn.deterministic = True`` and ``cudnn.benchmark = False``.

    Args:
        seed: Integer seed handed to each seeding function (default 42).
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


set_seed(42)

In [4]:
# Use the GPU when CUDA is available, otherwise stay on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


In [5]:
def BandPassFilter(eeg_data, low=1.0, high=75.0, fs=200, order=4):
    """Apply a forward-backward Butterworth band-pass filter along the last axis.

    Args:
        eeg_data: Array-like signal; filtering runs over ``axis=-1``.
        low: Lower cut-off frequency, in the same units as ``fs`` (default 1.0).
        high: Upper cut-off frequency, in the same units as ``fs`` (default 75.0).
        fs: Sampling frequency passed to ``signal.butter`` (default 200).
        order: Filter order passed to ``signal.butter`` (default 4).

    Returns:
        A C-contiguous ``numpy.ndarray`` with the same shape as the input.
    """
    b, a = signal.butter(order, Wn=[low, high], btype='bandpass', fs=fs)
    filtered = signal.filtfilt(b, a, eeg_data, axis=-1)
    # filtfilt can hand back a reversed (negative-stride) view of its work buffer;
    # torch.from_numpy refuses such arrays, so materialise a contiguous result.
    return np.ascontiguousarray(filtered)

In [6]:
def Notch(eeg_data, w0=50.0, Q=30.0, fs=200):
    """Apply a forward-backward IIR notch filter along the last axis.

    Args:
        eeg_data: Array-like signal; filtering runs over ``axis=-1``.
        w0: Frequency to remove, in the same units as ``fs`` (default 50.0).
        Q: Quality factor passed to ``signal.iirnotch`` (default 30.0).
        fs: Sampling frequency passed to ``signal.iirnotch`` (default 200).

    Returns:
        A C-contiguous ``numpy.ndarray`` with the same shape as the input.
    """
    b, a = signal.iirnotch(w0=w0, Q=Q, fs=fs)
    filtered = signal.filtfilt(b, a, eeg_data, axis=-1)
    # filtfilt can hand back a reversed (negative-stride) view of its work buffer;
    # torch.from_numpy refuses such arrays, so materialise a contiguous result.
    return np.ascontiguousarray(filtered)

In [7]:
# Preprocessing chain handed to the dataset as its offline transform.
# Order: band-pass -> notch -> baseline removal -> mean/std normalisation -> To2d.
# NOTE(review): BaselineRemoval presumably needs a baseline signal to subtract;
# confirm the dataset supplies one, otherwise this step may have no effect.
t_transform = transforms.Compose(
    [
        # The two SciPy filters are plain functions, so wrap each in a Lambda transform.
        *(transforms.Lambda(filter_fn) for filter_fn in (BandPassFilter, Notch)),
        transforms.BaselineRemoval(),
        transforms.MeanStdNormalize(),
        transforms.To2d(),
    ]
)

In [8]:
import os
import shutil

# Remove a previously written dataset cache directory, if there is one.
cache_dir = './tmp_out/seed_iv_augmented'
if os.path.exists(cache_dir):
    shutil.rmtree(cache_dir)

In [9]:
# Sliding-window settings, expressed in samples.
window = 800                               # chunk length
overlap_ratio = 0.3                        # fraction of a chunk shared with the next one
step = int(window * (1 - overlap_ratio))   # hop between chunk starts (560 for these values)
overlap_samples = window - step            # samples shared by consecutive chunks (240)

# Build the chunked SEED-IV dataset; t_transform runs as the offline transform and
# the label transform keeps only the 'emotion' field.
dataset = SEEDIVDataset(
    root_path='/kaggle/input/seed-iv/eeg_raw_data',
    io_path='./tmp_out/seed_iv_augmented',
    io_mode='memory',
    num_worker=1,
    chunk_size=window,
    overlap=overlap_samples,
    offline_transform=t_transform,
    label_transform=transforms.Compose([transforms.Select('emotion')]),
)

[2025-12-02 19:26:32] INFO (torcheeg/MainThread) üîç | Processing EEG data. Processed EEG data has been cached to [92m./tmp_out/seed_iv_augmented[0m.
[2025-12-02 19:26:32] INFO (torcheeg/MainThread) ‚è≥ | Monitoring the detailed processing of a record for debugging. The processing of other records will only be reported in percentage to keep it clean.
[PROCESS]:   0%|          | 0/45 [00:00<?, ?it/s]
[RECORD /kaggle/input/seed-iv/eeg_raw_data/1/4_20151111.mat]: 0it [00:00, ?it/s][A
[RECORD /kaggle/input/seed-iv/eeg_raw_data/1/4_20151111.mat]: 1it [00:04,  4.54s/it][A
[RECORD /kaggle/input/seed-iv/eeg_raw_data/1/4_20151111.mat]: 30it [00:04,  9.06it/s][A
[RECORD /kaggle/input/seed-iv/eeg_raw_data/1/4_20151111.mat]: 59it [00:04, 20.78it/s][A
[RECORD /kaggle/input/seed-iv/eeg_raw_data/1/4_20151111.mat]: 89it [00:04, 36.43it/s][A
[RECORD /kaggle/input/seed-iv/eeg_raw_data/1/4_20151111.mat]: 119it [00:04, 55.90it/s][A
[RECORD /kaggle/input/seed-iv/eeg_raw_data/1/4_20151111.mat]: 148

In [10]:
import pandas as pd

# Metadata table exposed by the dataset (one row per segment).
df = dataset.info

# Segment counts per emotion label, ordered by label id.
# 0: Neutral, 1: Sad, 2: Fear, 3: Happy
total = len(df)
counts = df['emotion'].value_counts().sort_index()

print(f"Total Segments: {total}")
print("-" * 30)
print("Count per Emotion:")
print(counts)

# Same counts expressed as a share of all segments.
percentages = (counts / total) * 100
print("-" * 30)
print("Percentage per Emotion:")
print(percentages.round(2))

# Flag imbalance when the largest and smallest class shares differ by more than
# 10 percentage points.
max_pct = percentages.max()
min_pct = percentages.min()

if (max_pct - min_pct) > 10:
    print(f"\n‚ö†Ô∏è WARNING: Data is IMBALANCED (Diff: {max_pct - min_pct:.2f}%)")
    print("Consider using a WeightedRandomSampler.")

Total Segments: 53235
------------------------------
Count per Emotion:
emotion
0    14445
1    14460
2    13095
3    11235
Name: count, dtype: int64
------------------------------
Percentage per Emotion:
emotion
0    27.13
1    27.16
2    24.60
3    21.10
Name: count, dtype: float64


In [11]:
from torch.utils.data import Subset
import random

# SEED-IV has 24 trials (videos) per session
all_trial_ids = list(range(1, 25))

# Hold out whole trials so that overlapping chunks of one trial never end up on
# both sides of the split.
random.seed(42)
test_trial_ids = random.sample(all_trial_ids, 5)  # 5 trials for test
train_trial_ids = [t for t in all_trial_ids if t not in test_trial_ids]  # 19 trials for train

# Subset (and the later df.iloc lookups) address samples by POSITION, so derive
# positional indices from the boolean masks instead of reading df's index labels,
# which only coincide with positions when the index is a default 0..n-1 range.
train_mask = df['trial_id'].isin(train_trial_ids).to_numpy()
test_mask = df['trial_id'].isin(test_trial_ids).to_numpy()
train_indices = np.flatnonzero(train_mask).tolist()
test_indices = np.flatnonzero(test_mask).tolist()

# Create Subsets
train_set = Subset(dataset, train_indices)
test_set = Subset(dataset, test_indices)

print(f"Train samples: {len(train_set)}, Test samples: {len(test_set)}")

Train samples: 41700, Test samples: 11535


In [12]:
from torch.utils.data import DataLoader, WeightedRandomSampler
import torch

# Labels of the training samples, fetched with one positional column selection
# rather than one df.iloc[i] row lookup per sample.
train_labels = df['emotion'].iloc[train_indices].tolist()
class_counts = pd.Series(train_labels).value_counts().to_dict()

# Inverse class-frequency weight per training sample.
weights = [1.0 / class_counts[label] for label in train_labels]

# Draw len(weights) samples per epoch, with replacement, proportionally to the weights.
sampler = WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)

train_loader = DataLoader(train_set, batch_size=32, sampler=sampler)
test_loader = DataLoader(test_set, batch_size=32, shuffle=False)

In [13]:
import torch.nn as nn
from torchvision.ops import SqueezeExcitation

# Probe one sample to size the electrode-spanning convolution.
input_tensor = dataset[0][0]
n_channels, samples = input_tensor.shape[1], input_tensor.shape[2]

n_classes = 4
print(f"Electrodes (Height): {n_channels}, Time (Width): {samples}")


def _conv_bn_elu(in_ch, out_ch, kernel, **conv_kwargs):
    """Return [Conv2d(bias=False), BatchNorm2d, ELU] as a list of layers."""
    return [
        nn.Conv2d(in_ch, out_ch, kernel, bias=False, **conv_kwargs),
        nn.BatchNorm2d(out_ch),
        nn.ELU(),
    ]


model = nn.Sequential(
    # Convolution along the time axis only (kernel height 1), stride 2 in time
    *_conv_bn_elu(1, 24, (1, 128), stride=(1, 2), padding=(0, 64)),
    # Depthwise convolution spanning all n_channels rows (one filter per feature map)
    *_conv_bn_elu(24, 24, (n_channels, 1), groups=24),
    # Pointwise mixing
    *_conv_bn_elu(24, 48, (1, 1)),
    nn.Dropout(0.3),
    # Two dilated convolutions along the time axis
    *_conv_bn_elu(48, 64, (1, 32), dilation=(1, 2), padding=(0, 32)),
    *_conv_bn_elu(64, 128, (1, 16), dilation=(1, 4), padding=(0, 30)),
    nn.Dropout(0.4),
    # Global average pooling, then the classifier head
    nn.AdaptiveAvgPool2d((1, 1)),
    nn.Flatten(),
    nn.Linear(128, 64),
    nn.BatchNorm1d(64),
    nn.ELU(),
    nn.Dropout(0.5),
    nn.Linear(64, n_classes),
).to(device)

Electrodes (Height): 62, Time (Width): 800


In [14]:
import torch.optim as optim

# Loss, optimiser and learning-rate schedule for the training loop.
criterion = nn.CrossEntropyLoss()

num_epochs = 40   # upper bound on training epochs; also the cosine schedule's T_max
patience = 10     # early-stopping patience, in epochs

optimizer = optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)
scheduler = optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=num_epochs,
    eta_min=1e-6,
)

In [15]:
# Train for up to num_epochs, checkpointing the best model and stopping early
# once `patience` consecutive epochs fail to beat the best test accuracy.
# NOTE(review): checkpoint selection and early stopping both use test_loader, so
# best_acc is an optimistic estimate; consider a separate validation split.
best_acc = 0.0
# Must be defined before the loop: the `else` branch below increments it, and the
# patience check reads it, even when the first epoch does not set a new record.
epochs_no_improve = 0

for epoch in range(num_epochs):
    # ---- Training pass ----
    model.train()
    train_loss = 0
    correct = 0
    total = 0

    for X, y in train_loader:
        X, y = X.to(device).float(), y.to(device)
        optimizer.zero_grad()
        outputs = model(X)
        loss = criterion(outputs, y)
        loss.backward()
        optimizer.step()

        # Scale the batch loss by the batch size so the epoch figure is a per-sample average
        train_loss += loss.item() * X.size(0)
        _, predicted = outputs.max(1)
        total += y.size(0)
        correct += predicted.eq(y).sum().item()

    train_acc = 100.*correct/total
    train_loss /= total

    # ---- Validation / Test pass (no gradients) ----
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for X, y in test_loader:
            X, y = X.to(device).float(), y.to(device)
            outputs = model(X)
            loss = criterion(outputs, y)

            test_loss += loss.item() * X.size(0)
            _, predicted = outputs.max(1)
            total += y.size(0)
            correct += predicted.eq(y).sum().item()

    test_acc = 100.*correct/total
    test_loss /= total

    # Advance the learning-rate schedule once per epoch
    scheduler.step()

    print(f"Epoch [{epoch+1}/{num_epochs}] Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}% | Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%")
    if test_acc > best_acc:
        # New record: checkpoint the weights and reset the patience counter
        best_acc = test_acc
        torch.save(model.state_dict(), 'best_model.pth')
        print(f"Saved New Best Model! (New Record: {best_acc:.2f}%)")
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1

    if epochs_no_improve >= patience:
        print(f"Early stopping triggered! No improvement in {patience} epochs.")
        break
    



Epoch [1/40] Train Loss: 1.2806, Train Acc: 40.52% | Test Loss: 1.3182, Test Acc: 42.41%
Saved New Best Model! (New Record: 42.41%)
Epoch [2/40] Train Loss: 1.1338, Train Acc: 50.88% | Test Loss: 1.3793, Test Acc: 40.16%
Epoch [3/40] Train Loss: 1.0730, Train Acc: 54.62% | Test Loss: 1.3316, Test Acc: 45.12%
Saved New Best Model! (New Record: 45.12%)
Epoch [4/40] Train Loss: 1.0297, Train Acc: 56.74% | Test Loss: 1.3052, Test Acc: 45.07%
Epoch [5/40] Train Loss: 0.9998, Train Acc: 58.41% | Test Loss: 1.2399, Test Acc: 47.28%
Saved New Best Model! (New Record: 47.28%)
Epoch [6/40] Train Loss: 0.9736, Train Acc: 59.64% | Test Loss: 1.2041, Test Acc: 48.93%
Saved New Best Model! (New Record: 48.93%)
Epoch [7/40] Train Loss: 0.9482, Train Acc: 61.08% | Test Loss: 1.2742, Test Acc: 46.17%
Epoch [8/40] Train Loss: 0.9275, Train Acc: 62.04% | Test Loss: 1.2513, Test Acc: 47.91%
Epoch [9/40] Train Loss: 0.9065, Train Acc: 63.23% | Test Loss: 1.1871, Test Acc: 50.78%
Saved New Best Model! (New 