# Training Pipeline

# Imports
## Pip Packages

In [None]:
import os
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

# If there is no `data` directory in the current working directory, move the
# working directory two levels up before the project-local imports run
# (presumably the repository root, where `data` and `src` live — verify).
if not os.path.exists('data'):
    # Built from os.pardir so the path resolves on POSIX as well as Windows;
    # a literal "..\\..\\" is only understood with Windows path separators.
    new_directory_path = os.path.join(os.pardir, os.pardir)
    os.chdir(new_directory_path)

from src.datasets import FreeViewInMemory, seq2seq_jagged_collate_fn, seq2seq_padded_collate_fn
from src.model import PathModel
from src.cls_metrics import precision,recall,create_cls_targets, accuracy, compute_loss
import numpy as np

# Training

## Preparing Data

In [2]:
from torch.utils.data import random_split
# TODO Separate validation and test sets
datasetv2 = FreeViewInMemory(sample_size= 13,log = True, start_index=2)

# 80 / 10 / 10 split; the test set takes whatever remains after the integer
# truncation so the three sizes always sum to the dataset length.
total_size = len(datasetv2)
train_size = int(0.8 * total_size)
val_size = int(0.1 * total_size)
test_size = total_size - train_size - val_size

# Seed the split so that train/val/test membership is identical on every
# Restart & Run All; an unseeded random_split yields a different partition
# each run, which makes reported metrics incomparable between runs.
SPLIT_SEED = 42
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_set, val_set, test_set = random_split(
    datasetv2,
    [train_size, val_size, test_size],
    generator=split_generator,
)

Data loaded in memory


In [None]:
# TODO Implement Strategy and Builder pattern

# Constructor arguments for PathModel, gathered in one mapping so they can be
# inspected or tweaked in a single place.
model_kwargs = dict(
    input_dim=3,
    output_dim=3,
    n_encoder=2,
    n_decoder=2,
    model_dim=256,
    total_dim=256,
    n_heads=4,
    ff_dim=512,
    max_pos_enc=15,
    max_pos_dec=26,
)
model = PathModel(**model_kwargs)

# Shuffled batches of 128 training samples, loaded in the main process and
# assembled by the padded seq2seq collate function.
dataloader = DataLoader(
    train_set,
    batch_size=128,
    shuffle=True,
    num_workers=0,
    collate_fn=seq2seq_padded_collate_fn,
)

Data loaded in memory


In [None]:
import torch.optim as optim

# Weight of the classification loss in the combined objective; the regression
# loss gets the complementary weight (1 - cls_weigth).
cls_weigth = .5
# Training loop (example with one epoch)
num_epochs = 1 # You can change this to train for more epochs
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Every batch is moved to `device` below, so the model's parameters must be on
# the same device; without this the forward pass raises a device-mismatch
# error whenever CUDA is available.
model.to(device)

# Define the optimizer (built after the model has been moved to its device)
optimizer = optim.Adam(model.parameters())

# Per-epoch average regression / classification losses
reg_loss_list = []
cls_loss_list = []
for epoch in range(num_epochs):
    model.train()  # Set the model to training mode
    total_reg_loss = 0
    total_cls_loss = 0
    num_batches = 0

    for batch in tqdm(dataloader):
        x,x_mask,y, y_mask, fixation_len = batch
        x = x.to(device=device)
        y = y.to(device=device)
        # Either mask may be None; only move the ones that are present
        if x_mask is not None:
            x_mask = x_mask.to(device = device)
        if y_mask is not None:
            y_mask = y_mask.to(device = device)
        fixation_len = fixation_len.to(device = device)

        optimizer.zero_grad()  # Zero the gradients

        reg_out, cls_out = model(x,y, src_mask = x_mask, tgt_mask = y_mask)  # Forward pass
        reg_loss, cls_loss = compute_loss(reg_out,cls_out, y, y_mask, fixation_len) # Compute loss
        # Convex combination of the two losses, controlled by cls_weigth
        total_loss = (1-cls_weigth)*reg_loss + cls_weigth*cls_loss
        total_loss.backward()
        optimizer.step()

        total_reg_loss += reg_loss.item()
        total_cls_loss += cls_loss.item()
        num_batches += 1

    # An empty dataloader would otherwise divide by zero; report 0.0 instead.
    batch_count = max(num_batches, 1)
    avg_reg_loss = total_reg_loss / batch_count
    avg_cls_loss = total_cls_loss / batch_count
    reg_loss_list.append(avg_reg_loss)
    cls_loss_list.append(avg_cls_loss)
    print(f"Epoch {epoch+1}/{num_epochs}, Avg Regression Loss: {avg_reg_loss:.4f}, Avg Classification Loss: {avg_cls_loss:.4f}")

print("Training finished!")

In [15]:

# Classification metrics for whatever cls_out, y_mask and fixation_len
# currently hold in the kernel (in notebook order: the final training batch).
# Printed one value per line: accuracy, precision and recall with their
# default arguments, then precision and recall with cls = 0.
cls_targets = create_cls_targets(cls_out, fixation_len)
metric_calls = (
    (accuracy, {}),
    (precision, {}),
    (recall, {}),
    (precision, {'cls': 0}),
    (recall, {'cls': 0}),
)
for metric_fn, extra_kwargs in metric_calls:
    print(metric_fn(cls_out, y_mask, cls_targets, **extra_kwargs))

0.8739495798319328
0.3157894736842105
0.046875
0.8840304182509505
0.9862142099681867


In [7]:
# Read the saved checkpoint, mapping all stored tensors onto the CPU.
# NOTE(review): torch.load unpickles the file — only load checkpoints from a
# trusted source (consider weights_only=True if the installed torch supports it).
checkpoint_path = 'trained_model.pth'
data = torch.load(checkpoint_path, map_location='cpu')


In [9]:
# Strip the leading '_orig_mod.' from every checkpoint key that carries it
# (other keys pass through untouched), then load the result into the model.
prefix = '_orig_mod.'
state_dict = data['model_state_dict']

new_state_dict = {
    (name[len(prefix):] if name.startswith(prefix) else name): weights
    for name, weights in state_dict.items()
}

model.load_state_dict(new_state_dict)

<All keys matched successfully>

In [10]:
# The training loop leaves the model in train mode; switch to eval mode so any
# train-only layers (e.g. dropout, if PathModel uses it) are disabled and the
# metrics below are deterministic for the loaded weights.
model.eval()
with torch.no_grad():
    # Masks are passed by keyword, matching the forward call in the training loop.
    reg_out,cls_out = model(x,y, src_mask = x_mask, tgt_mask = y_mask)
    cls_targets = create_cls_targets(cls_out, fixation_len)
    # Accuracy, then precision/recall with default arguments, then for cls = 0
    print(accuracy(cls_out, y_mask, cls_targets))
    print(precision(cls_out, y_mask, cls_targets))
    print(recall(cls_out, y_mask, cls_targets))
    print(precision(cls_out, y_mask, cls_targets, cls = 0))
    print(recall(cls_out, y_mask, cls_targets, cls = 0))
   

0.8739495798319328
0.3157894736842105
0.046875
0.8840304182509505
0.9862142099681867


In [11]:
# For every sample in the batch, stack its classification outputs with the
# matching target-mask entries along a new last axis and print the result.
for sample_idx in range(y.size(0)):
    sample_logits = cls_out[sample_idx]
    mask_column = y_mask[sample_idx].unsqueeze(-1)
    paired = torch.stack([sample_logits, mask_column], dim=-1)
    print(paired)
    print('---')

tensor([[[-6.8844,  1.0000]],

        [[-3.3327,  1.0000]],

        [[-3.0625,  1.0000]],

        [[-2.9924,  1.0000]],

        [[-3.3855,  1.0000]],

        [[-1.7940,  1.0000]],

        [[-2.2510,  1.0000]],

        [[-2.9439,  0.0000]],

        [[-2.8354,  0.0000]],

        [[-3.0042,  0.0000]],

        [[-2.4354,  0.0000]],

        [[-2.8288,  0.0000]],

        [[-2.9764,  0.0000]],

        [[-2.4336,  0.0000]],

        [[-2.8274,  0.0000]],

        [[-2.7162,  0.0000]],

        [[-2.5281,  0.0000]]])
---
tensor([[[-7.1378,  1.0000]],

        [[-2.8036,  1.0000]],

        [[-2.6193,  1.0000]],

        [[-1.8819,  1.0000]],

        [[-1.9031,  1.0000]],

        [[-1.0296,  1.0000]],

        [[-0.3728,  1.0000]],

        [[-1.8460,  1.0000]],

        [[-2.4529,  0.0000]],

        [[-2.5738,  0.0000]],

        [[-2.4708,  0.0000]],

        [[-2.7568,  0.0000]],

        [[-2.9144,  0.0000]],

        [[-2.5655,  0.0000]],

        [[-2.9675,  0.0000]],

    

In [12]:
# Per sample: print the input sequence, the target sequence, and the
# regression output with its last step dropped, followed by a separator.
for sample_idx in range(y.size(0)):
    shown = (x[sample_idx], y[sample_idx], reg_out[sample_idx, :-1, :])
    for tensor in shown:
        print(tensor)
    print('---')

tensor([[ 248.3200,   84.8000,    0.0000],
        [ 350.7200,  321.6000,  200.0000],
        [ 335.3600,  302.4000,  400.0000],
        [ 248.3200,  238.4000,  600.0000],
        [ 279.0400,  321.6000,  800.0000],
        [ 360.9600,  321.6000, 1000.0000],
        [ 442.8800,  321.6000, 1200.0000],
        [ 407.0400,  321.6000, 1400.0000],
        [ 391.6800,  273.6000, 1600.0000],
        [ 227.8400,  148.8000, 1800.0001],
        [  58.8800,  321.6000, 2000.0001],
        [ 376.3200,  296.0000, 2200.0000],
        [ 386.5600,  321.6000, 2400.0000]])
tensor([[310.9486, 248.2286, 296.0000],
        [266.9105, 268.5257, 504.0000],
        [428.2819, 302.1105, 415.0000],
        [392.7467, 266.6362, 297.0000],
        [150.7048, 200.2590, 354.0000],
        [354.4076, 214.9181,  35.0000],
        [  0.0000,   0.0000,   0.0000],
        [  0.0000,   0.0000,   0.0000],
        [  0.0000,   0.0000,   0.0000],
        [  0.0000,   0.0000,   0.0000],
        [  0.0000,   0.0000,   0.0000],
