In [None]:
%load_ext autoreload
%autoreload 2

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np

from dataset import SynapseDataset
from model import HybridNeuroMagos

# Select the compute device: use the GPU whenever PyTorch can see one,
# otherwise fall back to the CPU. The choice is echoed for the reader.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

In [None]:
# config
bs = 128          # batch size for both loaders
lr = 1e-3         # Adam learning rate
epochs = 30
seed = 42         # fixes the train/val split and the torch global RNG
num_workers = 2   # 2 suits Colab; shared by both loaders for consistency

# Seed torch's global RNG before anything random happens downstream
# (DataLoader shuffling, and e.g. torch.nn weight init in later cells).
torch.manual_seed(seed)

dataset = SynapseDataset('Synapse_Dataset')

# split 80/20 with a dedicated, seeded generator so the partition is identical
# on every run -- otherwise a re-run validates on different samples and the
# accuracy behind best_model.pth is not comparable between runs.
train_len = int(0.8 * len(dataset))
val_len = len(dataset) - train_len
assert train_len > 0 and val_len > 0, "dataset too small for an 80/20 split"
train_ds, val_ds = random_split(
    dataset, [train_len, val_len], generator=torch.Generator().manual_seed(seed)
)

train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True, num_workers=num_workers)
val_dl = DataLoader(val_ds, batch_size=bs, shuffle=False, num_workers=num_workers)

len(train_ds), len(val_ds)

In [None]:
# Objective: cross-entropy over the model's outputs and integer class labels.
crit = nn.CrossEntropyLoss()

# Build the network and move its parameters to the selected device
# (nn.Module.to returns the module itself, so rebinding is equivalent).
model = HybridNeuroMagos()
model = model.to(device)

# Adam over every trainable parameter, using the configured learning rate.
opt = optim.Adam(params=model.parameters(), lr=lr)

In [None]:
# Train for `epochs` epochs, evaluating on the validation split after each one
# and checkpointing the weights whenever validation accuracy improves.


def _train_one_epoch(model, loader, opt, crit, device, epoch):
    """Run one optimisation pass over `loader` and return the mean batch loss.

    Args:
        model: network to train (switched to train mode here).
        loader: iterable of (x, y) batches.
        opt: optimizer stepping `model`'s parameters.
        crit: loss function, called as crit(model(x), y).
        device: device each batch is moved to.
        epoch: zero-based epoch index, used only for the progress-bar label.

    Returns:
        Mean of the per-batch loss values, or NaN if `loader` yielded nothing.
    """
    model.train()
    loss_sum = 0.0
    n_batches = 0
    pbar = tqdm(loader)
    for x, y in pbar:
        x, y = x.to(device), y.to(device)

        opt.zero_grad()
        loss = crit(model(x), y)
        loss.backward()
        opt.step()

        # Running sum instead of np.mean over a growing list: O(1) per batch
        # rather than O(batches so far).
        loss_sum += loss.item()
        n_batches += 1
        pbar.set_description(f"E {epoch+1} | Loss: {loss_sum / n_batches:.4f}")
    return loss_sum / n_batches if n_batches else float('nan')


def _evaluate(model, loader, device):
    """Return top-1 accuracy of `model` over `loader`.

    The prediction for each sample is the argmax over dim 1 of the model output.

    Raises:
        ValueError: if `loader` yields no samples (accuracy would be 0/0).
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            pred = model(x).argmax(dim=1)
            correct += (pred == y).sum().item()
            total += y.size(0)
    if total == 0:
        raise ValueError("validation loader yielded no samples; cannot compute accuracy")
    return correct / total


history = []             # validation accuracy per epoch (index 0 = epoch 1)
train_loss_history = []  # mean training loss per epoch
# Start below any achievable accuracy so the first epoch is always
# checkpointed, even if its accuracy is 0.0 (with 0 nothing would be saved).
best_acc = -1.0

for epoch in range(epochs):
    train_loss = _train_one_epoch(model, train_dl, opt, crit, device, epoch)
    train_loss_history.append(train_loss)

    acc = _evaluate(model, val_dl, device)
    print(f"Val Acc: {acc:.4f}")

    if acc > best_acc:
        torch.save(model.state_dict(), 'best_model.pth')
        best_acc = acc
        print("Saved.")

    history.append(acc)

In [None]:
# Validation accuracy per epoch. The x-axis starts at 1 to match the "E {n}"
# labels printed during training (plt.plot(history) alone would start at 0).
fig, ax = plt.subplots()
ax.plot(range(1, len(history) + 1), history, marker='o')
ax.set_title('Val Acc')
ax.set_xlabel('Epoch')
ax.set_ylabel('Validation accuracy')
plt.show()