In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

## Hyperparameters

In [21]:
n_epochs = 100
batch_size = 256
n_features = 128        # passed to EEGFeatNet; exact meaning defined in models.py
projection_dim = 128    # size of the embedding returned by EEGFeatNet (see shape check below)
weight_decay = 1e-4
lr = 1e-3

n_channels = 5
seq_len = 256
n_classes = 10

vis_freq = 10           # evaluate Precision@1 every `vis_freq` epochs

# Seed every RNG the notebook touches so that data shuffling, weight init and
# triplet mining are reproducible under Restart & Run All.
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)  # also seeds all CUDA devices

device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

cpu


In [22]:
data_dir = "./data/"  # keep the trailing "/": dataset paths are formed by appending names to this string
# O1 O2 T5 T6
# NOTE(review): this lists 4 electrodes, but n_channels = 5 and the loaded data has
# 5 channels per time step -- confirm what the 5th channel is.
# bandpass 5-95 Hz

## Load Data

In [23]:
import datasets

# Both splits are read from directories under `data_dir` (as written by
# `Dataset.save_to_disk`); `data_dir` already ends with "/".
d_train = datasets.load_from_disk(f"{data_dir}dataset_train_preprocessed")
d_test = datasets.load_from_disk(f"{data_dir}dataset_test_preprocessed")

In [24]:
# Pull the EEG windows ("pixel_values") and class labels out of each split as tensors.
train_eegs, train_labels = (torch.tensor(d_train[column]) for column in ("pixel_values", "label"))
test_eegs, test_labels = (torch.tensor(d_test[column]) for column in ("pixel_values", "label"))

In [25]:
print(train_eegs.shape, train_eegs.dtype)
print(train_labels.shape, train_labels.dtype)

# Fail fast if the data disagrees with the hyperparameters the model is built from
# (seq_len / n_channels / n_classes are set by hand in the hyperparameter cell).
for split_eegs, split_labels in ((train_eegs, train_labels), (test_eegs, test_labels)):
    assert split_eegs.shape[1:] == (seq_len, n_channels), f"unexpected EEG window shape {tuple(split_eegs.shape)}"
    assert len(split_eegs) == len(split_labels), "EEG / label count mismatch"
    assert 0 <= int(split_labels.min()) and int(split_labels.max()) < n_classes, "label outside [0, n_classes)"

torch.Size([10436, 256, 5]) torch.float32
torch.Size([10436]) torch.int64


In [26]:
from torch.utils.data import Dataset

class EEGDataset(Dataset):
    """Map-style dataset pairing EEG windows with their class labels.

    Args:
        eegs: Indexable collection of EEG samples (in this notebook a tensor of
            shape (N, seq_len, n_channels)); ``len(eegs)`` is the dataset length.
        labels: Indexable collection of labels aligned index-by-index with ``eegs``.
        transform: Optional callable applied to each EEG sample on access;
            labels are returned unchanged.

    Raises:
        ValueError: If ``eegs`` and ``labels`` have different lengths.
    """

    def __init__(self, eegs, labels, transform=None):
        # Catch misaligned inputs up front instead of as an IndexError partway
        # through an epoch (or trailing labels that are silently never used).
        if len(eegs) != len(labels):
            raise ValueError(
                f"eegs and labels must have the same length, got {len(eegs)} and {len(labels)}"
            )
        self.eegs = eegs
        self.labels = labels
        self.transform = transform

    def __getitem__(self, idx):
        """Return ``(eeg, label)`` for position ``idx``, with ``transform`` applied to the EEG."""
        eeg = self.eegs[idx]
        label = self.labels[idx]

        # `is not None` so that any supplied callable is applied, even a falsy one.
        if self.transform is not None:
            eeg = self.transform(eeg)

        return eeg, label

    def __len__(self):
        return len(self.eegs)

# Wrap each split in the Dataset above (no per-sample transform).
train_data, test_data = (
    EEGDataset(train_eegs, train_labels),
    EEGDataset(test_eegs, test_labels),
)

In [27]:
# Inspect one training sample: the EEG window tensor and its class label.
eeg, label = train_data[0]
print(eeg.shape, label)

torch.Size([256, 5]) tensor(6)


## Define Network

In [28]:
from models import EEGFeatNet, ClassificationHead

In [29]:
# Smoke-test the network on a random batch shaped like the real data:
# (batch_size, seq_len, n_channels).
eeg = torch.randn((batch_size, seq_len, n_channels)).to(device)
print(eeg.shape)
model = EEGFeatNet(n_channels=n_channels, n_features=n_features, projection_dim=projection_dim).to(device)

# Use the configured class count instead of a hard-coded 10.
# NOTE(review): the head is not moved to `device` and is unused below -- confirm intent.
classifier = ClassificationHead(projection_dim, n_classes)

# Shape check only: no autograd graph needed.
with torch.no_grad():
    proj = model(eeg)
print(proj.shape)

torch.Size([256, 256, 5])
torch.Size([256, 128])


In [30]:
f"All Parameters: {sum(map(torch.numel, model.parameters()))}"

'All Parameters: 85504'

## Dataloaders

In [31]:
from torch.utils.data import DataLoader

train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
# The test set is only used for evaluation, where sample order is irrelevant to
# Precision@1; a fixed order keeps embeddings/labels aligned across evaluations.
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

In [32]:
# Peek at one batch to confirm the loader yields (batch_size, seq_len, n_channels).
eeg, label = next(iter(train_loader))

print(eeg.shape)

torch.Size([256, 256, 5])


In [33]:
# Loader batches live on the CPU; move to the model's device (required when
# device == "cuda"). No gradients are needed for a shape check.
eeg = eeg.to(device)
with torch.no_grad():
    proj = model(eeg)
print(proj.shape)

torch.Size([256, 128])


In [34]:
# Same shape check, written the way the training loop consumes batches.
# Gradients are disabled since nothing is backpropagated here.
with torch.no_grad():
    for eeg, label in train_loader:
        eeg = eeg.to(device)
        proj = model(eeg)
        print(proj.shape)
        break

torch.Size([256, 128])


## Train Loop

In [35]:
def train(epoch, model, optimizer, loss_fn, miner, train_dataloader, test_dataloader, accuracy_calculator):
    """Run one epoch of metric-learning training, then evaluate every ``vis_freq`` epochs.

    For each batch: embed with ``model``, mine triplets with ``miner``, compute
    ``loss_fn`` on the mined indices, and take one optimizer step. When
    ``epoch % vis_freq == 0`` (module-level setting), Precision@1 is computed via
    ``calc_accuracy`` and printed.

    Args:
        epoch: Epoch index, used for the progress bar and the evaluation schedule.
        model: Embedding network being trained.
        optimizer: Optimizer over ``model``'s parameters.
        loss_fn: Callable ``loss_fn(embeddings, labels, indices_tuple)``.
        miner: Callable ``miner(embeddings, labels)`` returning mined indices.
        train_dataloader: Iterable of ``(eeg, label)`` training batches.
        test_dataloader: Test batches, used only for evaluation.
        accuracy_calculator: Passed through to ``calc_accuracy``.
    """
    # Evaluation helpers may leave the model in eval mode; make sure layers such
    # as dropout / batch-norm (if any) use training behaviour for this epoch.
    model.train()

    tq = tqdm(train_dataloader)
    for eeg, label in tq:
        eeg = eeg.to(device)
        label = label.to(device)

        optimizer.zero_grad()

        x_proj = model(eeg)

        # Indices of the mined triplets. With a semihard TripletMarginMiner, a loss
        # that stays at 0.000 usually means no qualifying triplets were found.
        indices_tuple = miner(x_proj, label)
        loss = loss_fn(x_proj, label, indices_tuple)

        loss.backward()
        optimizer.step()

        tq.set_description('Train:[{}, {:0.3f}]'.format(epoch, loss.item()))

    if (epoch % vis_freq) == 0:
        acc = calc_accuracy(model, train_dataloader, test_dataloader, accuracy_calculator)
        print("[Epoch: {}, Precision@1: {}]".format(epoch, acc))

def calc_accuracy(model, train_dataloader, test_dataloader, accuracy_calculator):
    """Precision@1 of test embeddings retrieved against the training embeddings.

    Args:
        model: Embedding network.
        train_dataloader: Batches whose embeddings form the reference set.
        test_dataloader: Batches whose embeddings are the queries.
        accuracy_calculator: Object exposing ``get_accuracy`` and returning a dict
            that contains ``"precision_at_1"``.

    Returns:
        The ``"precision_at_1"`` entry of the calculator's result.
    """
    X_embeds, Y = get_embeddings_over_dataset(model, train_dataloader)
    X_embeds_test, Y_test = get_embeddings_over_dataset(model, test_dataloader)
    # Pass by keyword: the positional order of get_accuracy has changed across
    # pytorch-metric-learning major versions (older releases took
    # query, reference, query_labels, reference_labels, ...), so naming the
    # arguments prevents tensors from being silently mis-bound.
    accuracies = accuracy_calculator.get_accuracy(
        query=X_embeds_test,
        query_labels=Y_test,
        reference=X_embeds,
        reference_labels=Y,
        # The test queries are not contained in the training reference set.
        ref_includes_query=False,
    )

    return accuracies["precision_at_1"]

def get_embeddings_over_dataset(
    model: nn.Module,
    loader
    ):
    """Embed every batch yielded by ``loader`` and concatenate the results.

    The model is put in eval mode and gradients are disabled for the pass, so
    layers such as dropout / batch-norm (if the model has any) behave
    deterministically and no autograd graph is built. The model's previous
    train/eval mode is restored afterwards, so this is safe to call mid-training.

    Args:
        model: Embedding network, forwarded to ``get_embeddings``.
        loader: Iterable of ``(x, y)`` batches, e.g. a ``DataLoader`` over ``EEGDataset``.

    Returns:
        ``(X_embeds, Y)``: per-batch embeddings from ``get_embeddings`` and the
        matching labels, each concatenated along dim 0.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    X_embeds, Y = [], []
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for x, y in tqdm(loader):
                X_embeds.append(get_embeddings(x, model))
                Y.append(y)
    finally:
        # Hand the model back in whatever mode the caller had it.
        model.train(was_training)

    if not X_embeds:
        # torch.cat([]) would fail with a much less helpful RuntimeError.
        raise ValueError("loader yielded no batches; cannot compute embeddings")

    return torch.cat(X_embeds, dim=0), torch.cat(Y, dim=0)
        
def get_embeddings(
    x: torch.Tensor, 
    model: nn.Module, 
    ) -> torch.Tensor:
    """Embed one batch of EEG windows and L2-normalise the result.

    Args:
        x: Input batch as yielded by the data loader (typically on the CPU).
        model: Embedding network. ``x`` is moved to the device of the model's
            first parameter before the forward pass (left as-is if the model
            has no parameters).

    Returns:
        Detached CPU tensor of embeddings, L2-normalised along dim 1.
    """
    # DataLoader batches arrive on the CPU; without this the forward pass fails
    # whenever the model lives on CUDA.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        x = x.to(first_param.device)

    # Inference only: don't build an autograd graph (saves memory and keeps the
    # returned embeddings detached).
    with torch.no_grad():
        x_embeds = model(x)

    x_embeds = x_embeds.cpu()
    x_embeds = torch.nn.functional.normalize(x_embeds, dim=1)
    return x_embeds


In [36]:
from pytorch_metric_learning import miners, losses, distances
from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator

# Fresh model + optimizer for the real run (the model built earlier was only a smoke test).
model = EEGFeatNet(n_channels=n_channels, n_features=n_features, projection_dim=projection_dim).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

# Triplet loss over an Lp distance (p=2 by default). The miner keeps only
# "semihard" triplets: the negative is farther than the positive, but by less
# than `margin`.
# NOTE(review): the recorded run shows the loss stuck at 0.000 and Precision@1
# around 0.10 (chance for 10 classes), i.e. almost no semihard triplets are being
# mined. Consider type_of_triplets="all"/"hard" or check for embedding collapse.
margin = 0.2
distance = distances.LpDistance()
loss_fn = losses.TripletMarginLoss(margin=margin, distance=distance)
miner = miners.TripletMarginMiner(margin=margin, type_of_triplets="semihard", distance=distance)

# Retrieval metric: precision of each query's single nearest neighbour.
accuracy_calculator = AccuracyCalculator(include=("precision_at_1",), k=1)

In [37]:
# Release cached-but-unused GPU memory (e.g. left over from the smoke tests above)
# before the training run; a no-op path on CPU.
if device == "cuda":
    torch.cuda.empty_cache()

In [38]:
for epoch in range(n_epochs):
    train(epoch, model, optimizer, loss_fn, miner, train_loader, test_loader, accuracy_calculator)

# `train` only evaluates on epochs divisible by vis_freq, so the last epoch
# (n_epochs - 1) would otherwise never be reported; evaluate the final model too.
if n_epochs > 0 and (n_epochs - 1) % vis_freq != 0:
    final_acc = calc_accuracy(model, train_loader, test_loader, accuracy_calculator)
    print("[Epoch: {}, Precision@1: {}]".format(n_epochs - 1, final_acc))

Train:[0, 0.000]: 100%|██████████| 41/41 [00:12<00:00,  3.28it/s]
100%|██████████| 41/41 [00:08<00:00,  5.07it/s]
100%|██████████| 11/11 [00:02<00:00,  5.41it/s]


[Epoch: 0, Precision@1: 0.09946442234123948]


Train:[1, 0.135]: 100%|██████████| 41/41 [00:12<00:00,  3.23it/s]
Train:[2, 0.000]: 100%|██████████| 41/41 [00:12<00:00,  3.16it/s]
Train:[3, 0.000]: 100%|██████████| 41/41 [00:12<00:00,  3.22it/s]
Train:[4, 0.000]: 100%|██████████| 41/41 [00:12<00:00,  3.20it/s]
Train:[5, 0.000]: 100%|██████████| 41/41 [00:12<00:00,  3.23it/s]
Train:[6, 0.000]: 100%|██████████| 41/41 [00:12<00:00,  3.22it/s]
Train:[7, 0.000]: 100%|██████████| 41/41 [00:12<00:00,  3.24it/s]
Train:[8, 0.000]: 100%|██████████| 41/41 [00:13<00:00,  3.08it/s]
Train:[9, 0.000]: 100%|██████████| 41/41 [00:13<00:00,  3.14it/s]
Train:[10, 0.000]: 100%|██████████| 41/41 [00:12<00:00,  3.16it/s]
100%|██████████| 41/41 [00:08<00:00,  5.12it/s]
100%|██████████| 11/11 [00:02<00:00,  5.33it/s]


[Epoch: 10, Precision@1: 0.1009946442234124]


Train:[11, 0.000]: 100%|██████████| 41/41 [00:12<00:00,  3.24it/s]
Train:[12, 0.000]: 100%|██████████| 41/41 [00:12<00:00,  3.23it/s]
Train:[13, 0.000]: 100%|██████████| 41/41 [00:12<00:00,  3.22it/s]
Train:[14, 0.000]: 100%|██████████| 41/41 [00:12<00:00,  3.24it/s]
Train:[15, 0.000]: 100%|██████████| 41/41 [00:12<00:00,  3.28it/s]
Train:[16, 0.000]: 100%|██████████| 41/41 [00:12<00:00,  3.27it/s]
Train:[17, 0.000]: 100%|██████████| 41/41 [00:12<00:00,  3.21it/s]
Train:[18, 0.000]: 100%|██████████| 41/41 [00:12<00:00,  3.25it/s]
Train:[19, 0.000]: 100%|██████████| 41/41 [00:12<00:00,  3.26it/s]
Train:[20, 0.000]: 100%|██████████| 41/41 [00:12<00:00,  3.24it/s]
100%|██████████| 41/41 [00:07<00:00,  5.21it/s]
100%|██████████| 11/11 [00:02<00:00,  5.31it/s]


[Epoch: 20, Precision@1: 0.09908186687069626]


Train:[21, 0.000]: 100%|██████████| 41/41 [00:12<00:00,  3.26it/s]
Train:[22, 0.000]: 100%|██████████| 41/41 [00:12<00:00,  3.26it/s]
Train:[23, 0.000]: 100%|██████████| 41/41 [00:12<00:00,  3.24it/s]
Train:[24, 0.000]: 100%|██████████| 41/41 [00:12<00:00,  3.27it/s]
Train:[25, 0.000]: 100%|██████████| 41/41 [00:12<00:00,  3.28it/s]
Train:[26, 0.000]: 100%|██████████| 41/41 [00:12<00:00,  3.25it/s]
Train:[27, 0.000]: 100%|██████████| 41/41 [00:12<00:00,  3.26it/s]
Train:[28, 0.000]: 100%|██████████| 41/41 [00:12<00:00,  3.26it/s]
Train:[29, 0.000]: 100%|██████████| 41/41 [00:13<00:00,  3.04it/s]
Train:[30, 0.000]: 100%|██████████| 41/41 [00:12<00:00,  3.19it/s]
100%|██████████| 41/41 [00:07<00:00,  5.18it/s]
100%|██████████| 11/11 [00:02<00:00,  5.37it/s]


[Epoch: 30, Precision@1: 0.09678653404743688]


Train:[31, 0.000]: 100%|██████████| 41/41 [00:12<00:00,  3.23it/s]
Train:[32, 0.000]:  24%|██▍       | 10/41 [00:03<00:10,  2.89it/s]


KeyboardInterrupt: 