In [1]:
# Uncomment to delete the cached, preprocessed dataset under ./tmp_out and force it to be rebuilt:
# !rm -rf tmp_out

In [2]:
# %pip installs into the environment of the running kernel (a bare !pip may hit another one).
# torcheeg and torch-scatter are pinned to the versions this notebook was last run with;
# torchvision is left unpinned so it stays compatible with the preinstalled torch build.
%pip install torcheeg==1.1.3 torch-scatter==2.1.2 torchvision

Collecting torcheeg
  Downloading torcheeg-1.1.3.tar.gz (251 kB)
[2K     [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m251.4/251.4 kB[0m [31m5.8 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting torch-scatter
  Downloading torch_scatter-2.1.2.tar.gz (108 kB)
[2K     [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m108.0/108.0 kB[0m [31m7.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting scipy<=1.10.1,>=1.7.3 (from torcheeg)
  Downloading scipy-1.10.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (58 kB)
[2K     [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m58.9/58.9 kB[0m [31m3.5 MB/s[0m eta

In [3]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset
from torcheeg.datasets import SEEDIVDataset
from torcheeg import transforms
import scipy.signal as signal
import random
import copy
from torch import Tensor
from torchvision.models.googlenet import BasicConv2d
from typing import Callable, Optional


In [4]:
# Select the compute device: the GPU when CUDA is usable, the CPU otherwise.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

device(type='cuda')

In [5]:
def BandPassFilter(eeg_data):
    """Band-pass filter a signal between 1 Hz and 75 Hz along its last axis.

    The coefficients come from a Butterworth design (N=4) for a sampling
    rate of 200 Hz and are applied with ``filtfilt`` (forward and backward).

    Args:
        eeg_data: Array-like whose last axis is the one to filter
            (presumably time samples -- confirm against the caller).

    Returns:
        The filtered array, as returned by ``scipy.signal.filtfilt``.
    """
    numerator, denominator = signal.butter(
        N=4,
        Wn=[1.0, 75.0],
        btype='bandpass',
        fs=200,
    )
    filtered = signal.filtfilt(numerator, denominator, eeg_data, axis=-1)
    return filtered

In [6]:
def Notch(eeg_data):
    """Suppress the 50 Hz component of a signal along its last axis.

    An IIR notch (centre 50 Hz, Q=30, designed for a 200 Hz sampling rate)
    is applied with ``filtfilt`` (forward and backward).

    Args:
        eeg_data: Array-like whose last axis is the one to filter
            (presumably time samples -- confirm against the caller).

    Returns:
        The filtered array, as returned by ``scipy.signal.filtfilt``.
    """
    numerator, denominator = signal.iirnotch(w0=50.0, Q=30.0, fs=200)
    filtered = signal.filtfilt(numerator, denominator, eeg_data, axis=-1)
    return filtered

In [7]:
# Preprocessing pipeline, applied in the order listed.
preprocessing_steps = [
    transforms.Lambda(BandPassFilter),   # 1-75 Hz band-pass
    transforms.Lambda(Notch),            # 50 Hz notch
    transforms.BaselineRemoval(),        # NOTE(review): depends on a baseline signal being supplied -- confirm it has any effect here
    transforms.MeanStdNormalize(),
    transforms.To2d(),
]
t_transform = transforms.Compose(preprocessing_steps)

In [8]:
# Build the SEED-IV dataset. The preprocessing runs offline and the result is
# cached under io_path.
emotion_label_transform = transforms.Compose([
    transforms.Select('emotion'),
])

dataset = SEEDIVDataset(
    root_path='/kaggle/input/seed-iv/eeg_raw_data',
    io_path='./tmp_out/seed_iv',
    chunk_size=800,  # 4 seconds at the 200 Hz rate the filters are designed for
    offline_transform=t_transform,
    label_transform=emotion_label_transform,
    num_worker=4,
)

[2025-12-03 02:00:03] INFO (torcheeg/MainThread) üîç | Processing EEG data. Processed EEG data has been cached to [92m./tmp_out/seed_iv[0m.
[2025-12-03 02:00:03] INFO (torcheeg/MainThread) ‚è≥ | Monitoring the detailed processing of a record for debugging. The processing of other records will only be reported in percentage to keep it clean.
[PROCESS]:   0%|          | 0/45 [00:00<?, ?it/s]
[RECORD /kaggle/input/seed-iv/eeg_raw_data/1/4_20151111.mat]: 0it [00:00, ?it/s][A
[RECORD /kaggle/input/seed-iv/eeg_raw_data/1/4_20151111.mat]: 1it [00:03,  3.93s/it][A
[RECORD /kaggle/input/seed-iv/eeg_raw_data/1/4_20151111.mat]: 2it [00:04,  2.19s/it][A
[RECORD /kaggle/input/seed-iv/eeg_raw_data/1/4_20151111.mat]: 9it [00:05,  3.01it/s][A
[RECORD /kaggle/input/seed-iv/eeg_raw_data/1/4_20151111.mat]: 16it [00:05,  6.32it/s][A
[RECORD /kaggle/input/seed-iv/eeg_raw_data/1/4_20151111.mat]: 23it [00:05, 10.50it/s][A
[RECORD /kaggle/input/seed-iv/eeg_raw_data/1/4_20151111.mat]: 30it [00:05, 15.

In [9]:
# Per-segment metadata table exposed by the dataset.
df = dataset.info

# Segment count per emotion label, ordered by label.
# 0: Neutral, 1: Sad, 2: Fear, 3: Happy
total = len(df)
counts = df['emotion'].value_counts().sort_index()
percentages = counts.div(total).mul(100)

separator = "-" * 30

print(f"Total Segments: {total}")
print(separator)
print("Count per Emotion:")
print(counts)

print(separator)
print("Percentage per Emotion:")
print(percentages.round(2))

# Imbalance check: a spread of more than 10 percentage points between the
# most and least frequent class suggests a WeightedRandomSampler.
max_pct = percentages.max()
min_pct = percentages.min()
is_imbalanced = (max_pct - min_pct) > 10

if is_imbalanced:
    print(f"\n‚ö†Ô∏è WARNING: Data is IMBALANCED (Diff: {max_pct - min_pct:.2f}%)")
    print("Consider using a WeightedRandomSampler.")
else:
    print(f"\n‚úÖ Data is reasonably BALANCED (Diff: {max_pct - min_pct:.2f}%)")

Total Segments: 37575
------------------------------
Count per Emotion:
emotion
0    10170
1    10245
2     9225
3     7935
Name: count, dtype: int64
------------------------------
Percentage per Emotion:
emotion
0    27.07
1    27.27
2    24.55
3    21.12
Name: count, dtype: float64

‚úÖ Data is reasonably BALANCED (Diff: 6.15%)


In [10]:
# Split by Trial ID
# SEED-IV has 24 trials (videos) per session.
# 80% of VIDEOS for training (19 videos), 20% for testing (5 videos).
# All segments of a given trial id land on the same side of the split.
split_seed = 42
all_trial_ids = list(range(1, 25))

random.seed(split_seed)
test_trial_ids = random.sample(all_trial_ids, 5)
held_out_ids = set(test_trial_ids)
train_trial_ids = [t for t in all_trial_ids if t not in held_out_ids]

# Subset addresses the dataset by *position*, so collect row positions rather
# than DataFrame index labels (the two only coincide for a default RangeIndex).
train_indices = df['trial_id'].isin(train_trial_ids).to_numpy().nonzero()[0].tolist()
test_indices = df['trial_id'].isin(test_trial_ids).to_numpy().nonzero()[0].tolist()

# Fail early if the trial ids did not match anything or the two sides overlap.
assert train_indices and test_indices, "trial-id split produced an empty subset"
assert not set(train_indices) & set(test_indices), "train and test segments overlap"

# Create Subsets & Loaders
train_set = Subset(dataset, train_indices)
test_set = Subset(dataset, test_indices)

# A seeded generator makes the training shuffle order repeatable across runs.
train_loader = DataLoader(
    train_set,
    batch_size=64,
    shuffle=True,
    generator=torch.Generator().manual_seed(split_seed),
)
test_loader = DataLoader(test_set, batch_size=64, shuffle=False)

### Inception Class

In [11]:
class Inception(nn.Module):
    """Inception block: four parallel branches concatenated along dim 1.

    Args:
        in_channels: Channel count handed to the first layer of every branch.
        ch1x1: Output channels of the 1x1 branch.
        ch3x3red: Output channels of the 1x1 reduction in the second branch.
        ch3x3: Output channels of the 3x3 convolution in the second branch.
        ch5x5red: Output channels of the 1x1 reduction in the third branch.
        ch5x5: Output channels of the third branch.
        pool_proj: Output channels of the 1x1 projection after max pooling.
        conv_block: Factory called as ``conv_block(in, out, **conv_kwargs)``;
            ``BasicConv2d`` is used when this is None.

    The result of ``forward`` has ``ch1x1 + ch3x3 + ch5x5 + pool_proj``
    channels when ``conv_block`` honours its ``out`` argument.
    """

    def __init__(
        self,
        in_channels: int,
        ch1x1: int,
        ch3x3red: int,
        ch3x3: int,
        ch5x5red: int,
        ch5x5: int,
        pool_proj: int,
        conv_block: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        make_conv = BasicConv2d if conv_block is None else conv_block

        # Branch 1: plain 1x1 convolution.
        self.branch1 = make_conv(in_channels, ch1x1, kernel_size=1)

        # Branch 2: 1x1 reduction followed by a padded 3x3 convolution.
        reduce_3x3 = make_conv(in_channels, ch3x3red, kernel_size=1)
        expand_3x3 = make_conv(ch3x3red, ch3x3, kernel_size=3, padding=1)
        self.branch2 = nn.Sequential(reduce_3x3, expand_3x3)

        # Branch 3: 1x1 reduction followed by a padded convolution.
        # Here, kernel_size=3 instead of kernel_size=5 is a known bug.
        # Please see https://github.com/pytorch/vision/issues/906 for details.
        reduce_5x5 = make_conv(in_channels, ch5x5red, kernel_size=1)
        expand_5x5 = make_conv(ch5x5red, ch5x5, kernel_size=3, padding=1)
        self.branch3 = nn.Sequential(reduce_5x5, expand_5x5)

        # Branch 4: stride-1 max pooling followed by a 1x1 projection.
        pooling = nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=True)
        projection = make_conv(in_channels, pool_proj, kernel_size=1)
        self.branch4 = nn.Sequential(pooling, projection)

    def _forward(self, x: Tensor) -> list[Tensor]:
        """Run every branch on ``x`` and return the outputs in branch order."""
        branches = (self.branch1, self.branch2, self.branch3, self.branch4)
        return [branch(x) for branch in branches]

    def forward(self, x: Tensor) -> Tensor:
        """Concatenate the four branch outputs along dimension 1."""
        return torch.cat(self._forward(x), 1)

### GoogLeNetLighter Class

In [12]:
def cprint(x):
    """Debug print hook; it is switched off, so calls produce no output.

    Flip ``debug_enabled`` to True to have ``x`` printed.
    """
    debug_enabled = False
    if not debug_enabled:
        return
    print(x)


class GoogLeNetLighter(nn.Module):
    """Reduced GoogLeNet-style classifier.

    Layout: three convolution blocks with two max-pool stages, two inception
    blocks, global average pooling, dropout and a single linear head.

    Args:
        num_classes: Number of output logits.
        transform_input: Stored on the instance only; the forward pass of
            this class never reads it.
        init_weights: ``None`` is treated as ``True``. When true, Conv2d and
            Linear weights are drawn from a truncated normal (std 0.01) and
            BatchNorm2d layers are reset to weight 1 / bias 0.
        blocks: Optional ``[conv_block, inception_block]`` factories. The
            defaults are ``BasicConv2d`` and ``Inception``.
        dropout: Drop probability applied just before the linear head.

    Raises:
        ValueError: If ``blocks`` is given with a length other than 2.
    """

    __constants__ = ["transform_input"]

    def __init__(
        self,
        num_classes: int = 4,
        transform_input: bool = False,
        init_weights: Optional[bool] = None,
        blocks: Optional[list[Callable[..., nn.Module]]] = None,
        dropout: float = 0.4,  # Increased dropout slightly
    ) -> None:
        super().__init__()
        if blocks is None:
            conv_block = BasicConv2d
            inception_block = Inception
        else:
            if len(blocks) != 2:
                raise ValueError(f"blocks length should be 2 instead of {len(blocks)}")
            conv_block = blocks[0]
            inception_block = blocks[1]

        if init_weights is None:
            init_weights = True

        self.transform_input = transform_input

        # Stem: conv1 is built for a single input channel.
        self.conv1 = conv_block(1, 32, kernel_size=5, stride=1, padding=2)
        # Asymmetric pooling: stride 1 along dim 2, stride 2 along dim 3.
        self.maxpool1 = nn.MaxPool2d((3, 5), stride=(1, 2), ceil_mode=True)

        self.conv2 = conv_block(32, 32, kernel_size=1)
        self.conv3 = conv_block(32, 96, kernel_size=3, padding=1)

        self.maxpool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        # With the default Inception block: 32 + 64 + 16 + 16 = 128 channels out.
        self.inception3a = inception_block(96, 32, 48, 64, 8, 16, 16)
        # With the default Inception block: 64 + 96 + 48 + 32 = 240 channels out.
        self.inception3b = inception_block(128, 64, 64, 96, 16, 48, 32)

        # Kept as an attribute, but _forward does not apply it.
        self.maxpool3 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(p=dropout)

        # 240 must match the channel count produced by inception3b.
        self.fc = nn.Linear(240, num_classes)

        if init_weights:
            for m in self.modules():
                if isinstance(m, (nn.Conv2d, nn.Linear)):
                    torch.nn.init.trunc_normal_(m.weight, mean=0.0, std=0.01, a=-2, b=2)
                elif isinstance(m, nn.BatchNorm2d):
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)

    def _forward(self, x: Tensor) -> Tensor:
        """Compute the class logits for ``x``.

        Shapes after each stage are passed to ``cprint`` for debugging.

        Args:
            x: Input batch; conv1 is built for one input channel, so this is
                presumably ``(batch, 1, H, W)`` -- confirm against the data pipeline.

        Returns:
            The output of the final linear layer (one row of ``num_classes``
            logits per sample).
        """
        cprint("*" * 30)
        x = self.conv1(x)
        cprint(f"conv1       = {x.shape}")
        x = self.maxpool1(x)
        cprint(f"maxpool1    = {x.shape}")
        x = self.conv2(x)
        cprint(f"conv2       = {x.shape}")
        x = self.conv3(x)
        cprint(f"conv3       = {x.shape}")
        x = self.maxpool2(x)
        cprint(f"maxpool2    = {x.shape}")

        x = self.inception3a(x)
        cprint(f"inception3a = {x.shape}")
        x = self.inception3b(x)
        cprint(f"inception3b = {x.shape}")
        # maxpool3 is intentionally skipped here:
        # x = self.maxpool3(x)
        # cprint(f"maxpool3    = {x.shape}")

        # Global average pool to 1x1, then flatten to (batch, channels).
        x = self.avgpool(x)
        cprint(f"avgpool = {x.shape}")
        x = torch.flatten(x, 1)
        cprint(f"flatten = {x.shape}")
        x = self.dropout(x)
        cprint(f"dropout = {x.shape}")
        x = self.fc(x)
        cprint(f"fc = {x.shape}")

        return x

    def forward(self, x: Tensor) -> Tensor:
        """Return the class logits for ``x`` (delegates to ``_forward``)."""
        return self._forward(x)

In [13]:
# Build the classifier and move it to the selected device.
# NOTE(review): dropout=0.9 overrides the class default of 0.4 -- confirm this is intended.
model = GoogLeNetLighter(init_weights=True, dropout=0.9)
model = model.to(device)

In [14]:
# Loss function and optimizer used by the training loop.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=1e-4,
    weight_decay=1e-4,
)

In [15]:
# Training with early stopping on validation accuracy.
# NOTE(review): the validation pass iterates test_loader, so the "best" accuracy
# reported below is selected on the held-out data and is therefore optimistic.
max_epochs = 100
patience = 15
checkpoint_path = 'best_googlenet15_final.pth'

counter = 0
best_val_acc = 0.0
best_model_state = None


def _run_epoch(net, loader, loss_fn, optim=None):
    """Run one full pass over ``loader``.

    Trains when ``optim`` is given; otherwise evaluates with the module in
    eval mode and gradients disabled.

    Returns:
        (mean of the per-batch losses, accuracy in percent).
    """
    training = optim is not None
    net.train(training)  # train(False) is equivalent to eval()

    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    with torch.set_grad_enabled(training):
        for batch in loader:
            X = batch[0].to(device).float()
            y = batch[1].to(device).long()

            if training:
                optim.zero_grad()

            logits = net(X)
            loss = loss_fn(logits, y)

            if training:
                loss.backward()
                optim.step()

            loss_sum += loss.item()
            n_correct += (logits.argmax(dim=1) == y).sum().item()
            n_seen += y.size(0)

    return loss_sum / len(loader), (n_correct / n_seen) * 100


for epoch in range(max_epochs):
    avg_train_loss, train_acc = _run_epoch(model, train_loader, criterion, optimizer)
    avg_val_loss, val_acc = _run_epoch(model, test_loader, criterion)

    print(f"Epoch {epoch+1}: Train Loss={avg_train_loss:.4f} (Acc={train_acc:.2f}%) | Val Loss={avg_val_loss:.4f} (Acc={val_acc:.2f}%)")

    # --- EARLY STOPPING ---
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        # Snapshot the weights: state_dict() returns references that keep changing.
        best_model_state = copy.deepcopy(model.state_dict())
        torch.save(best_model_state, checkpoint_path)
        print(f"  --> New Best! {best_val_acc:.2f}%")
        counter = 0
    else:
        counter += 1
        if counter >= patience:
            print("  --> Early Stopping.")
            break

# Restore the best weights seen (None only if accuracy never rose above 0).
if best_model_state is not None:
    model.load_state_dict(best_model_state)
    print(f"Finished. Best Acc: {best_val_acc:.2f}%")


Epoch 1: Train Loss=1.3635 (Acc=30.60%) | Val Loss=1.4152 (Acc=26.67%)
  --> New Best! 26.67%
Epoch 2: Train Loss=1.3239 (Acc=34.92%) | Val Loss=1.3664 (Acc=31.23%)
  --> New Best! 31.23%
Epoch 3: Train Loss=1.2981 (Acc=37.70%) | Val Loss=1.7322 (Acc=17.78%)
Epoch 4: Train Loss=1.2831 (Acc=38.77%) | Val Loss=1.3995 (Acc=23.10%)
Epoch 5: Train Loss=1.2660 (Acc=40.25%) | Val Loss=1.3509 (Acc=32.11%)
  --> New Best! 32.11%
Epoch 6: Train Loss=1.2547 (Acc=41.17%) | Val Loss=1.3761 (Acc=29.22%)
Epoch 7: Train Loss=1.2410 (Acc=42.29%) | Val Loss=1.2607 (Acc=40.98%)
  --> New Best! 40.98%
Epoch 8: Train Loss=1.2255 (Acc=43.11%) | Val Loss=1.3474 (Acc=32.95%)
Epoch 9: Train Loss=1.2186 (Acc=43.70%) | Val Loss=1.4561 (Acc=26.74%)
Epoch 10: Train Loss=1.2089 (Acc=44.60%) | Val Loss=1.4427 (Acc=30.00%)
Epoch 11: Train Loss=1.1982 (Acc=45.35%) | Val Loss=1.6006 (Acc=25.50%)
Epoch 12: Train Loss=1.1875 (Acc=46.16%) | Val Loss=1.4623 (Acc=34.57%)
Epoch 13: Train Loss=1.1787 (Acc=46.32%) | Val Loss=1