# Train loop and auxilary functions for NSD project

In [1]:
%load_ext autoreload
%autoreload 2

import os
import torch
import torch.nn.functional as F
from dataset_pytorch import NSDDataset
from net import Net
from loss import ContrastiveLoss, TripletLoss
from visualization import plot_clusters

### Hyperparameters and network parameters

In [2]:
n_epochs = 10
n_iters = int(1e6)
current_epoch = 0
batch_size = 5
learning_rate = 1e-4
num_classes = 11

# Network parameters
training_type = 'clustering'
channels = [72, 128, 128, 128, 2] # Channels to be used in each ResBlock
reductions = ['mean', 'max', 'max', None] # Reduction to be used in ResBlock ('max', 'mean', None)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

### Train setup

In [3]:
dataset = NSDDataset('saved_data.pickle') # WARNING - dataset takes a lot of RAM (around 6 GB), loading onto insufficient device could cause crashing

model = Net(channels=channels, reductions=reductions, num_classes=num_classes, output_dimension=2, train_type=training_type).to(device)

if training_type == 'clustering':
    loss_function = torch.nn.TripletMarginLoss()
elif training_type == 'classification':
    loss_function = torch.nn.CrossEntropyLoss()

optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.5)

In [None]:
print(model)

### Helper function for results validation

In [None]:
# TODO: Function to calculate accuracy on the dataset

### Train loop - classification

In [None]:
assert training_type == 'classification'
i = 0
while current_epoch <= n_epochs:
    # Load data
    targets, inputs = dataset.get_batch(batch_size=batch_size)
    i += 1
    targets = targets.to(device)
    inputs = inputs.to(device)

    # Reset gradient after each iteration
    optimizer.zero_grad()

    # Forward pass
    net_output = model(inputs)
    
    # Backward pass
    loss = loss_function(net_output, targets)
    loss.backward()
    torch.nn.utils.clip_grad.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    
    # Print during epoch (sanity check)
    with torch.no_grad():
        print(f"\rEpoch: {current_epoch + 1} Iter: [{i} / {int(len(dataset) / batch_size) + 1}]| Loss: {round(loss.item(), 3)}", end='')
    
    # Print after each epoch to get current results
    if current_epoch != dataset.epoch:
        current_epoch = dataset.epoch
        torch.save(model.state_dict(), f'models/model_{current_epoch}.pth')
        i = 0



## Training loop - clustering

In [None]:
print(model)

In [4]:
assert training_type == 'clustering'

print_frequency = 200
batch_size = 3


for i in range(n_iters):
    optimizer.zero_grad()
    l_all = torch.zeros((batch_size)).to(device)
    for b in range(batch_size):
        a, p, n = dataset.get_triplet()
        data_input = torch.stack([a, p, n]).to(device)
        data_output = F.sigmoid(model(data_input))
        af, pf, nf = torch.split(data_output, split_size_or_sections=1, dim=0)

        loss = loss_function(af, pf, nf)
        l_all[b] = loss

        torch.nn.utils.clip_grad.clip_grad_norm_(model.parameters(), max_norm=1.0)

    l_all = l_all.sum()
    l_all.backward()
    optimizer.step()

    with torch.no_grad():
        print(loss.item())

    if i % print_frequency == 0:
        model.eval()
        with torch.no_grad():
            outputs = torch.zeros((2, len(dataset))).to(device)

1.0435736179351807
0.999917209148407
1.0000665187835693
0.9997407793998718
0.999994158744812
1.0002386569976807
1.0001848936080933


KeyboardInterrupt: 