# Train loop and auxiliary functions for NSD project

In [1]:
%load_ext autoreload
%autoreload 2

import os
import torch
from dataset_pytorch import NSDDataset
from net import Net
from loss import ContrastiveLoss

### Hyperparameters and network parameters

In [2]:
# --- Optimisation schedule -------------------------------------------------
n_epochs = 10          # epoch budget checked by the train loop's stop condition
current_epoch = 0      # epoch counter the train loop starts from
batch_size = 5         # samples requested from the dataset per iteration
learning_rate = 5e-3   # step size handed to the optimizer
num_classes = 11       # forwarded to the network constructor

# --- Network architecture --------------------------------------------------
# Channels to be used in each ResBlock.
channels = [72, 128, 128, 2]
# Reduction to be used in each ResBlock; allowed values: 'max', 'mean', None.
reductions = ['mean', 'mean', None]

# Prefer the GPU when one is visible to PyTorch, otherwise stay on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

### Train setup

In [3]:
# WARNING: the dataset takes a lot of RAM (around 6 GB); loading it on a
# machine without enough memory can crash the kernel.
dataset = NSDDataset('saved_data.pickle')

# Selects both the network's train_type and the loss used below.
# Supported values: 'clustering', 'classification'.
training_type = 'clustering'

model = Net(channels=channels, reductions=reductions, num_classes=num_classes, train_type=training_type).to(device)

if training_type == 'clustering':
    loss_function = ContrastiveLoss(radius=2.0)
elif training_type == 'classification':
    loss_function = torch.nn.CrossEntropyLoss()
else:
    # Fail here with a clear message instead of hitting a NameError on
    # `loss_function` later inside the train loop.
    raise ValueError(
        f"Unknown training_type {training_type!r}; "
        "expected 'clustering' or 'classification'"
    )

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

### Helper function for results validation

In [None]:
# TODO: Function to calculate accuracy on the dataset

### Train loop

In [5]:
# torch.save does not create missing parent directories, so make sure the
# checkpoint folder exists before the first epoch finishes.
os.makedirs('models', exist_ok=True)

# Approximate number of batches per epoch; used only for the progress display.
iters_per_epoch = int(len(dataset) / batch_size) + 1

i = 0  # batches processed so far in the current epoch
# Strict "<" so that exactly n_epochs epochs are trained: current_epoch starts
# at 0 and is synced to dataset.epoch at each epoch boundary (assumes
# dataset.epoch starts at 0 and grows by one per completed pass - TODO confirm).
# The previous "<=" condition ran one extra epoch.
while current_epoch < n_epochs:
    # Load data
    targets, inputs = dataset.get_batch(batch_size=batch_size)
    i += 1
    targets = targets.to(device)
    inputs = inputs.to(device)

    # Reset gradients accumulated by the previous iteration
    optimizer.zero_grad()

    # Forward pass
    net_output = model(inputs)

    # Backward pass; clip the global gradient norm to 1.0 before the update
    loss = loss_function(net_output, targets)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()

    # Progress line, overwritten in place (sanity check). loss.item() yields a
    # plain Python float, so no torch.no_grad() context is needed here.
    print(f"\rEpoch: {current_epoch + 1} Iter: [{i} / {iters_per_epoch}]| Loss: {round(loss.item(), 3)}", end='')

    # The dataset reports an epoch change: save a checkpoint named after the
    # new epoch number and restart the per-epoch batch counter.
    if current_epoch != dataset.epoch:
        current_epoch = dataset.epoch
        torch.save(model.state_dict(), f'models/model_{current_epoch}.pth')
        i = 0



Epoch: 1 Iter: [57 / 119]| Loss: 0.924