TO-DO:
- add metrics to the evaluation loop (F1, accuracy)
- add a README explaining the choice of model based on the original paper
- add multi-GPU training support (DDP)

In [1]:
import os

import torch
import wandb

from data import Dataset
from vivit import ViViT

# Run configuration shared by the training and validation cells.
LR = 1e-3                       # learning rate given to the optimizer
EPOCHS = 10                     # passes over the training dataloader
SAVE_EVERY = 5                  # checkpoint interval, in epochs
SAVE_PATH = './checkpoints/'    # checkpoint location used by the training loop
TRAINING_BATCH = 1024           # batch size of the training DataLoader
VALIDATION_BATCH = 1024         # batch size of the validation DataLoader

# ViViT model used by the training and validation cells.
# Larger alternative kept for reference:
#   dim=1024, spatial_depth=6, temporal_depth=6, mlp_dim=2048
model = ViViT(
    # --- input geometry ---
    image_size=224,        # image size
    frames=60,             # number of frames
    # --- patching ---
    image_patch_size=16,   # image patch size
    frame_patch_size=2,    # frame patch size
    # --- transformer ---
    dim=512,               # dim of each patch's embedding
    spatial_depth=3,       # depth of the spatial transformer
    temporal_depth=3,      # depth of the temporal transformer
    heads=8,               # number of heads in multi-headed attention
    mlp_dim=512,           # hidden dim of the transformer's feed-forward layer
    # --- output ---
    num_classes=60,        # predictions made per input
    pool='mean',           # aggregate embeddings using mean or use CLS token's embedding
)

# Loss shared by the training and validation loops.
loss_fn = torch.nn.CrossEntropyLoss()

# Shuffled batches of TRAINING_BATCH samples drawn from Dataset().
dataloader = torch.utils.data.DataLoader(
    dataset=Dataset(),  # using training data
    shuffle=True,
    batch_size=TRAINING_BATCH,
)

# AdamW over every parameter of the model, with learning rate LR.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=LR)

# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA (a hard-coded 'cuda' makes every .to(device) raise).
device = 'cuda' if torch.cuda.is_available() else 'cpu'

def save(model, optimizer, path):
    """Write the model and optimizer state dicts to ``path`` as one checkpoint.

    The checkpoint is a dict with the keys ``'model'`` and ``'optimizer'``,
    which is the layout ``load`` reads back.

    Args:
        model: object exposing ``state_dict()``.
        optimizer: object exposing ``state_dict()``.
        path: destination handed to ``torch.save``. When it is a file-system
            path, missing parent directories are created first.
    """
    # torch.save does not create directories. Only touch the file system for
    # real paths so file-like targets (e.g. io.BytesIO) keep working.
    if isinstance(path, (str, os.PathLike)):
        parent = os.path.dirname(path)
        if parent:  # a bare file name has no directory to create
            os.makedirs(parent, exist_ok=True)
    torch.save({
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict()
    }, path)

def load(model, optimizer, path, map_location=None):
    """Restore model and optimizer state from a checkpoint written by ``save``.

    Args:
        model: object whose ``load_state_dict`` receives ``checkpoint['model']``.
        optimizer: object whose ``load_state_dict`` receives
            ``checkpoint['optimizer']``.
        path: checkpoint location handed to ``torch.load``.
        map_location: forwarded to ``torch.load``. Pass e.g. ``'cpu'`` to open
            a checkpoint that was saved from GPU tensors on a machine without
            CUDA. The default ``None`` leaves ``torch.load``'s own behaviour
            unchanged.

    Returns:
        The same ``(model, optimizer)`` objects that were passed in, after
        their state has been loaded.

    Raises:
        KeyError: if the checkpoint has no ``'model'`` or ``'optimizer'`` entry.
    """
    # NOTE: torch.load unpickles the file -- only open checkpoints that come
    # from a trusted source.
    checkpoint = torch.load(path, map_location=map_location)
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    return model, optimizer


In [None]:
# training
# Runs EPOCHS passes over `dataloader`, logs the loss of every batch to
# Weights & Biases and writes a checkpoint every SAVE_EVERY epochs.
model.train()
model = model.to(device)

# torch.save does not create directories, so make sure the checkpoint
# directory exists before the first checkpoint is written.
os.makedirs(SAVE_PATH, exist_ok=True)

wandb.init(project='seizure_detection', name='training_run')

for epoch in range(1, EPOCHS+1):
    for video, target in dataloader:
        video, target = video.to(device), target.to(device)
        prediction = model(video)
        loss = loss_fn(prediction, target)
        wandb.log({
            'loss': loss.item(),
            'epoch': epoch
            }, commit = True)
        loss.backward()
        optimizer.step()
        # clear the gradients so the next batch starts from zero
        optimizer.zero_grad()
    if epoch % SAVE_EVERY == 0:
        # SAVE_PATH is a directory: give every checkpoint its own file so a
        # later epoch does not overwrite an earlier one.
        save(model, optimizer, os.path.join(SAVE_PATH, f'epoch_{epoch}.pt'))

wandb.finish()

In [None]:
# validation
# Evaluates the model on `val_dataloader` without gradient tracking and logs
# the loss of every batch to its own Weights & Biases run.
model.eval()
model.to(device)

# NOTE(review): Dataset() is constructed exactly as it is for training, so it
# presumably yields the same samples -- confirm how the validation split is
# selected.
val_dataloader = torch.utils.data.DataLoader(
    Dataset(), # using validation data
    batch_size = VALIDATION_BATCH,
    shuffle = False  # sample order does not matter when only evaluating
)

wandb.init(project='seizure_detection', name='validation_run')

with torch.no_grad():
    for index, batch in enumerate(val_dataloader):
        video, target = batch
        video, target = video.to(device), target.to(device)
        prediction = model(video)
        loss = loss_fn(prediction, target)
        # Validation has no epoch counter of its own, so log the batch index
        # instead of relying on a variable left over from the training cell.
        wandb.log({
            'loss': loss.item(),
            'batch': index
        }, commit = True)

wandb.finish()