# EEG Emotion Classification with Transformers (PyTorch)

In [None]:
!pip install torch numpy

In [None]:
import torch
import numpy as np
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset


In [None]:

class EEGDataset(Dataset):
    """Map-style dataset backed by a dict of arrays saved with ``np.save``.

    The file is expected to contain a pickled dict with key ``'X'``
    (features, converted to ``float32``) and key ``'y'`` (labels, converted
    to ``int64``/``torch.long``). Samples are indexed along the first axis.

    Args:
        data_path: Path to the ``.npy`` file.
    """

    def __init__(self, data_path):
        # SECURITY NOTE: allow_pickle=True unpickles the file contents, which
        # can execute arbitrary code -- only load files from trusted sources.
        stored = np.load(data_path, allow_pickle=True)
        arrays = stored.item()  # unwrap the 0-d object array holding the dict
        features = arrays['X']
        self.X = torch.tensor(features, dtype=torch.float32)
        labels = arrays['y']
        self.y = torch.tensor(labels, dtype=torch.long)

    def __len__(self):
        """Number of samples (size of the first axis of ``X``)."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the ``(features, label)`` pair at position ``idx``."""
        features = self.X[idx]
        label = self.y[idx]
        return features, label


In [None]:

class EEGTransformer(nn.Module):
    """Transformer-encoder classifier for multichannel time-series windows.

    Each time step's feature vector is linearly embedded to ``d_model``
    dimensions and passed through a stack of self-attention encoder layers
    that attend across time steps. The result is mean-pooled over time and
    mapped to class logits by a small MLP head.

    Args:
        input_dim: Number of features per time step (size of the last input axis).
        num_classes: Number of output logits.
        n_heads: Attention heads per encoder layer; must divide ``d_model``.
        n_layers: Number of stacked encoder layers.
        d_model: Embedding / encoder width (previously hard-coded to 64).

    Input:  float tensor of shape ``(batch, seq_len, input_dim)``.
    Output: logits of shape ``(batch, num_classes)``.
    """

    def __init__(self, input_dim, num_classes, n_heads=4, n_layers=2, d_model=64):
        super().__init__()
        self.embedding = nn.Linear(input_dim, d_model)
        # batch_first=True: inputs are (batch, seq, feature). With the default
        # (batch_first=False) dim 0 would be treated as the sequence axis, so
        # attention would mix *different samples* of a batch instead of the
        # time steps within one sample.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads, batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
        self.classifier = nn.Sequential(
            nn.Linear(d_model, 32),
            nn.ReLU(),
            nn.Linear(32, num_classes),
        )

    def forward(self, x):
        """Return class logits for ``x`` of shape (batch, seq_len, input_dim)."""
        # NOTE(review): no positional encoding is added, so encoder + mean
        # pooling is invariant to the order of time steps. Consider adding one
        # if temporal order matters for the task.
        x = self.embedding(x)    # (batch, seq_len, d_model)
        x = self.transformer(x)  # (batch, seq_len, d_model)
        x = x.mean(dim=1)        # average over time -> (batch, d_model)
        return self.classifier(x)


In [None]:

# Generate a synthetic, EEG-shaped dataset for the demo.
# X: (samples, time_steps, channels) standard-normal noise.
# y: uniform random labels in [0, classes).
# The labels are independent of X, so a model trained on this data cannot
# beat chance -- it only exercises the pipeline end to end.
samples, time_steps, channels, classes = 500, 128, 32, 3
SEED = 0  # fixed seed so repeated runs write the same file
rng = np.random.default_rng(SEED)
# float32 matches the dtype the dataset converts to and halves the file size.
X = rng.standard_normal((samples, time_steps, channels), dtype=np.float32)
y = rng.integers(0, classes, size=samples)
# np.save stores the dict as a pickled 0-d object array; loading it back
# requires np.load(..., allow_pickle=True).item().
np.save("demo_eeg.npy", {"X": X, "y": y})
print("Fake EEG dataset saved as demo_eeg.npy")


In [None]:

# Train EEGTransformer on demo_eeg.npy and report the mean loss per epoch.
torch.manual_seed(0)  # reproducible weight init and shuffle order
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_set = EEGDataset("demo_eeg.npy")
train_loader = DataLoader(train_set, batch_size=16, shuffle=True)

# Derive model sizes from the data instead of hard-coding them, so the model
# stays in sync with the dataset. Assumes labels are 0..num_classes-1.
input_dim = train_set.X.shape[-1]
num_classes = int(train_set.y.max()) + 1

model = EEGTransformer(input_dim=input_dim, num_classes=num_classes).to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

NUM_EPOCHS = 10
for epoch in range(NUM_EPOCHS):
    model.train()
    total_loss = 0.0
    # xb/yb (not X/y) so the batch tensors don't overwrite the module-level
    # X, y arrays from the data-generation cell.
    for xb, yb in train_loader:
        xb, yb = xb.to(device), yb.to(device)
        logits = model(xb)
        loss = loss_fn(logits, yb)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # CrossEntropyLoss returns the batch mean; weight it by batch size so
        # the epoch figure is a true per-sample mean (the last batch may be smaller).
        total_loss += loss.item() * xb.size(0)

    mean_loss = total_loss / max(len(train_set), 1)
    print(f"Epoch {epoch+1}, Loss: {mean_loss:.4f}")
