In [None]:
#!unzip data.zip
#!unzip models.zip


In [None]:
import torch
import numpy
import torch.nn as nn
import math
import torch.nn.functional as F
from torch.utils.data.dataloader import default_collate

import data.data_generation
from data.data_generation import trun_poly_dataset
from models.modules import probeable_decoder_model
from data.data_utils import prepare_input, accuracy


In [None]:
#Create (and persist) a reproducible 80/20 training/validation split.
# NOTE(review): the file appears to hold a pickled dataset object (it exposes
# `seq_len` and is passed to random_split), not a plain tensor/state dict.
# On recent PyTorch releases, where torch.load defaults to weights_only=True,
# loading it may require passing weights_only=False -- confirm for your version.
data = torch.load('data/datasets/first_400_primes_2000000data.pt')

SPLIT_SEED = 0    # fixes the partition so the saved split files are reproducible across runs
BATCH_SIZE = 128

# Derive the split sizes from the dataset instead of hard-coding them;
# for the 2,000,000-sample file this is still 1,600,000 / 400,000.
n_train = len(data) * 4 // 5
n_val = len(data) - n_train
train_data, val_data = torch.utils.data.random_split(
    data,
    [n_train, n_val],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

torch.save(train_data, 'first_400_2000000train.pt')
torch.save(val_data, 'first_400_2000000val.pt')

device = 'cuda' if torch.cuda.is_available() else 'cpu'


def collate_to_device(samples):
    """Collate a list of samples with default_collate and move every resulting tensor to `device`."""
    return tuple(t.to(device) for t in default_collate(samples))


train_loader = torch.utils.data.DataLoader(train_data, batch_size=BATCH_SIZE, collate_fn=collate_to_device, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_data, batch_size=BATCH_SIZE, collate_fn=collate_to_device, shuffle=True)

In [None]:
#Model setup
n_heads = 8
emb_dim = 256
dim_feedforward = 1024
num_layers = 1
vocab_size = 5000
seq_len = data.seq_len

# Move the model to the target device *before* the optimizer is constructed,
# so the optimizer is built over the parameters in their final location.
model = probeable_decoder_model(emb_dim, n_heads, vocab_size, seq_len, num_layers, dim_feedforward).to(device)
# Prefix used by train() for the periodic checkpoints.
model_path = f'probeable_decoder_model_{n_heads}heads_{emb_dim}embdim_{dim_feedforward}mlp'
#model.load_state_dict(torch.load('1lyr_8hd_256embdim_epoch_30 (acc ~82)'))

#Training setup
epochs = 50
lr = .001
optim = torch.optim.Adam(model.parameters(), lr=lr)

In [None]:
#Training Loop

def train(model, epochs, optim, train_iterator, val_iterator):
    """Train `model` for `epochs` epochs, checkpointing along the way.

    Relies on the notebook-level globals `device` (where the model is placed)
    and `model_path` (file-name prefix for the periodic checkpoints).

    Args:
        model: the network to optimise; called as ``model(re, decoder_input)``.
        epochs: number of full passes over `train_iterator`.
        optim: optimizer already constructed over ``model.parameters()``.
        train_iterator: iterable yielding ``(re, K_r)`` batches for training.
        val_iterator: iterable passed to ``accuracy`` for validation.

    Returns:
        (train_loss, val_acc_list): per-epoch mean training loss per batch and
        per-epoch validation accuracy.

    Side effects:
        Saves the state dict to ``f'{model_path}_epoch{n}'`` every 10th epoch
        and to ``"best_performing_model"`` whenever validation accuracy matches
        or beats the best seen so far (starting from a floor of 0.2).
    """
    val_acc_list = []
    train_loss = []
    best_acc = .2 #initial bound for saving the best performing model based on validation accuracy.

    # Loop-invariant: place the model once instead of on every batch.
    model = model.to(device)

    for i in range(epochs):
        model.train()
        tot_loss = 0.0
        n_batches = 0
        for re, K_r in train_iterator:
            optim.zero_grad()

            #Add start token of zero to target sequence
            decoder_input = prepare_input(K_r)

            pred = model(re, decoder_input)  #size (batch, seq_len, num_classes)
            pred = torch.permute(pred, (0,2,1)) #cross_entropy expects the class dimension second

            loss = F.cross_entropy(pred, K_r, reduction = 'mean')
            loss.backward()
            optim.step()

            tot_loss += loss.item()
            n_batches += 1

        # Average over the batches actually seen rather than a hard-coded
        # constant, so the value stays meaningful if the dataset or batch
        # size changes. max() guards against an empty iterator.
        tot_loss /= max(n_batches, 1)
        train_loss.append(tot_loss)

        # Evaluate both splits in eval mode and without autograd, so that
        # train-mode layers (e.g. dropout, if the model has any) do not distort
        # the training accuracy and no graph is built during evaluation.
        model.eval()
        with torch.no_grad():
            train_acc = accuracy(model, train_iterator)
            val_acc = accuracy(model, val_iterator)

        print("Epoch {} training loss = {:.6f}".format(i+1, tot_loss))
        print("Epoch {} training accuracy = {:.6f}".format(i+1, train_acc))

        # Periodic checkpoint every 10 epochs.
        if (i+1)%10 == 0:
            torch.save(model.state_dict(), f'{model_path}_epoch{i+1}')

        # Keep the weights with the best validation accuracy so far.
        if val_acc >= best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), "best_performing_model")

        val_acc_list.append(val_acc)

        print("Epoch {} validation accuracy = {:.6f}".format(i+1, val_acc))

    return train_loss, val_acc_list

In [None]:
# train() returns the per-epoch training loss and the per-epoch validation *accuracy*.
train_loss, val_acc = train(model, epochs, optim, train_loader, val_loader)
val_loss = val_acc  # backward-compatible alias for the old (misleading) name; holds accuracies, not losses

Epoch 1 training loss = 0.014256
Epoch 1 training accuracy = 0.755042
Epoch 1 validation accuracy = 0.781245
Epoch 2 training loss = 0.006345
Epoch 2 training accuracy = 0.824248
Epoch 2 validation accuracy = 0.844537
Epoch 3 training loss = 0.004469
Epoch 3 training accuracy = 0.864689
Epoch 3 validation accuracy = 0.875095
Epoch 4 training loss = 0.003620
Epoch 4 training accuracy = 0.884354
Epoch 4 validation accuracy = 0.892655
Epoch 5 training loss = 0.003162
Epoch 5 training accuracy = 0.895814
Epoch 5 validation accuracy = 0.900300
Epoch 6 training loss = 0.002854
Epoch 6 training accuracy = 0.901369
Epoch 6 validation accuracy = 0.905265
Epoch 7 training loss = 0.002630
Epoch 7 training accuracy = 0.908942
Epoch 7 validation accuracy = 0.910705
Epoch 8 training loss = 0.002436
Epoch 8 training accuracy = 0.914504
Epoch 8 validation accuracy = 0.915142
Epoch 9 training loss = 0.002300
Epoch 9 training accuracy = 0.918181
Epoch 9 validation accuracy = 0.917595
Epoch 10 training l

In [None]:
# Save the final weights. The file name is derived from the configuration so it
# cannot drift from the model that was actually trained; with the current
# settings it resolves to '1lyr_8hd_256embdim_1024mlp_epoch50'.
final_model_path = f'{num_layers}lyr_{n_heads}hd_{emb_dim}embdim_{dim_feedforward}mlp_epoch{epochs}'
torch.save(model.state_dict(), final_model_path)