# Training Pipeline

# Imports
## Pip Packages

In [1]:
import os
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

# The project imports below (`src.*`) are resolved relative to the current
# working directory. When no `data` folder is visible from here, step two
# directories up (presumably to the repository root -- confirm against the
# project layout). The path is assembled with os.path.join/os.pardir so it is
# valid on every OS; the previous hard-coded "..\\..\\" only worked on Windows.
if not os.path.exists('data'):
    new_directory_path = os.path.join(os.pardir, os.pardir)
    os.chdir(new_directory_path)

from src.datasets import FreeViewInMemory, seq2seq_jagged_collate_fn, seq2seq_padded_collate_fn
from src.model import PathModel
import numpy as np

# Training

In [2]:
# TODO Implement Strategy and Builder pattern
datasetv2 = FreeViewInMemory(sample_size=13, log=True, start_index=2)
model = PathModel(
    input_dim=3,
    output_dim=3,
    n_encoder=2,
    n_decoder=2,
    model_dim=256,
    total_dim=256,
    n_heads=4,
    ff_dim=512,
    max_pos_enc=15,
    max_pos_dec=26,
)

dataloader = DataLoader(
    datasetv2,
    batch_size=128,
    shuffle=True,
    num_workers=0,
    collate_fn=seq2seq_padded_collate_fn,
)

# Pull a single batch for a smoke test. next(iter(...)) raises StopIteration
# immediately when the loader is empty, instead of leaving x/y undefined and
# failing later with a NameError.
batch = next(iter(dataloader))
x, x_mask, y, y_mask, fixation_len = batch

# Optional: model = torch.compile(model)
# Smoke-test forward pass; the outputs are reused by the loss check below.
reg_out, cls_out = model(x, y, x_mask, y_mask)

Data loaded in memory
input mask
target mask


In [36]:
def compute_loss(reg_out, cls_out, y, attn_mask, fixation_len, balance_cls=True):
    """Compute the end-of-sequence classification loss and the regression loss.

    Indexing conventions used below (assumed from how the tensors are indexed;
    confirm against the model / collate function):
      * ``attn_mask`` is a boolean ``(batch, L)`` mask, True on valid decoder
        positions, i.e. positions ``0 .. fixation_len`` of every sequence.
      * position ``fixation_len[b]`` of sequence ``b`` is the end token
        (the start token shifts everything by one).
      * ``cls_out`` is ``(batch, L, 1)`` logits, ``reg_out`` is
        ``(batch, L, D)`` and ``y`` is ``(batch, L - 1, D)``.

    Args:
        reg_out: regression output of the model.
        cls_out: end-token logits of the model.
        y: regression targets (no start token).
        attn_mask: validity mask of the decoder positions; not modified.
        fixation_len: integer tensor ``(batch,)`` with the number of fixations
            per sequence.
        balance_cls: when True (default) every non-end position is weighted by
            ``1 / fixation_len`` so the many "not the end" targets of a
            sequence together weigh as much as its single end token. Pass
            False for the plain, unweighted BCE.

    Returns:
        ``(cls_loss, reg_loss)`` -- note the order: classification first.
    """
    functional = torch.nn.functional
    # Build every helper tensor on the device of the model output so the
    # function works unchanged on CPU and GPU.
    device = cls_out.device
    valid = attn_mask.to(device=device, dtype=torch.bool)
    end_idx = fixation_len.to(device=device, dtype=torch.long)
    batch_idx = torch.arange(cls_out.size(0), device=device)

    # the end token should not have a regression
    reg_mask = valid.clone()  # clone: the caller's mask must stay untouched
    reg_mask[batch_idx, end_idx] = False

    # >>>>>> Classification loss
    # the end token must be 1, because of the start token the number of
    # fixations points to the end
    cls_targets = torch.zeros_like(cls_out)
    cls_targets[batch_idx, end_idx] = 1.0

    # balance the classification loss
    weights = torch.ones_like(cls_out)
    if balance_cls:
        per_sequence = (1.0 / end_idx).to(cls_out.dtype).view(-1, 1, 1)
        # torch.where only picks per_sequence on non-end valid positions, so a
        # zero-length sequence (1/0 = inf) never leaks into the weights.
        weights = torch.where(reg_mask.unsqueeze(-1), per_sequence, weights)
    cls_loss = functional.binary_cross_entropy_with_logits(
        cls_out[valid], cls_targets[valid], weight=weights[valid]
    )

    # >>>>>> Regression loss
    # A 2-D boolean mask selects whole rows, so no expansion to the (formerly
    # hard-coded) coordinate width is needed. The target mask drops the start
    # token column because y carries no start token.
    reg_loss = functional.mse_loss(reg_out[reg_mask], y[valid[:, 1:]])
    return cls_loss, reg_loss

# Sanity check: evaluate both loss terms on the outputs of the smoke-test
# forward pass. The target mask (y_mask) is what compute_loss calls attn_mask.
cls_loss, reg_loss = compute_loss(
    reg_out=reg_out,
    cls_out=cls_out,
    y=y,
    attn_mask=y_mask,
    fixation_len=fixation_len,
)

In [None]:
import torch.optim as optim

# Share of the classification term in the combined loss; the regression term
# gets the remaining 1 - cls_weight.
cls_weight = .5
cls_weigth = cls_weight  # misspelled original name, kept as an alias for any cell still using it
# Training loop (example with one epoch)
num_epochs = 1 # You can change this to train for more epochs
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# The batches are moved to `device` below, so the model has to live there too;
# otherwise the forward pass fails with a device mismatch on a GPU machine.
model.to(device)

# Define the optimizer
optimizer = optim.Adam(model.parameters())
reg_loss_list = []
cls_loss_list = []
for epoch in range(num_epochs):
    model.train()  # Set the model to training mode
    total_reg_loss = 0
    total_cls_loss = 0
    num_batches = 0

    for batch in tqdm(dataloader):
        x, x_mask, y, y_mask, fixation_len = batch
        x = x.to(device=device)
        y = y.to(device=device)
        if x_mask is not None:
            x_mask = x_mask.to(device=device)
        if y_mask is not None:
            y_mask = y_mask.to(device=device)
        fixation_len = fixation_len.to(device=device)

        optimizer.zero_grad()  # Zero the gradients

        # Forward pass. The masks are passed exactly like in the smoke-test
        # call earlier in the notebook; without them the padded positions of
        # the batch are not masked out.
        reg_out, cls_out = model(x, y, x_mask, y_mask)

        # compute_loss returns (cls_loss, reg_loss) -- classification first.
        cls_loss, reg_loss = compute_loss(reg_out, cls_out, y, y_mask, fixation_len)

        total_loss = (1 - cls_weight) * reg_loss + cls_weight * cls_loss
        total_loss.backward()
        optimizer.step()

        total_reg_loss += reg_loss.item()
        total_cls_loss += cls_loss.item()
        num_batches += 1

    avg_reg_loss = total_reg_loss / num_batches
    avg_cls_loss = total_cls_loss / num_batches
    reg_loss_list.append(avg_reg_loss)
    cls_loss_list.append(avg_cls_loss)
    print(f"Epoch {epoch+1}/{num_epochs}, Avg Regression Loss: {avg_reg_loss:.4f}, Avg Classification Loss: {avg_cls_loss:.4f}")

print("Training finished!")

  0%|          | 0/274 [00:00<?, ?it/s]

input mask
target mask


  0%|          | 1/274 [00:00<01:17,  3.51it/s]

input mask
target mask


  1%|          | 2/274 [00:00<01:12,  3.73it/s]

input mask
target mask


  1%|          | 3/274 [00:00<01:09,  3.89it/s]

input mask
target mask


  1%|▏         | 4/274 [00:01<01:06,  4.06it/s]

input mask
target mask


  2%|▏         | 5/274 [00:01<01:07,  4.01it/s]

input mask
target mask


  2%|▏         | 6/274 [00:01<01:08,  3.92it/s]

input mask
target mask


  3%|▎         | 7/274 [00:01<01:05,  4.06it/s]

input mask
target mask


  3%|▎         | 8/274 [00:01<01:03,  4.19it/s]

input mask
target mask


  3%|▎         | 9/274 [00:02<01:04,  4.13it/s]

input mask
target mask


  4%|▎         | 10/274 [00:02<01:02,  4.20it/s]

input mask
target mask


  4%|▍         | 11/274 [00:02<01:03,  4.13it/s]

input mask
target mask


  4%|▍         | 12/274 [00:02<01:04,  4.08it/s]

input mask
target mask


  5%|▍         | 13/274 [00:03<01:02,  4.16it/s]

input mask
target mask


  5%|▌         | 14/274 [00:03<01:05,  3.97it/s]

input mask
target mask


  5%|▌         | 15/274 [00:03<01:01,  4.18it/s]

input mask
target mask


  6%|▌         | 16/274 [00:03<01:03,  4.07it/s]

input mask
target mask


  6%|▌         | 17/274 [00:04<01:01,  4.18it/s]

input mask
target mask


  7%|▋         | 18/274 [00:04<01:00,  4.21it/s]

input mask
target mask


  7%|▋         | 20/274 [00:04<00:57,  4.45it/s]

input mask
target mask
input mask
target mask


  8%|▊         | 21/274 [00:05<00:57,  4.41it/s]

input mask
target mask


  8%|▊         | 22/274 [00:05<00:56,  4.45it/s]

input mask
target mask


  8%|▊         | 23/274 [00:05<00:57,  4.36it/s]

input mask
target mask


  9%|▉         | 24/274 [00:05<00:58,  4.27it/s]

input mask
target mask


  9%|▉         | 25/274 [00:06<00:59,  4.17it/s]

input mask
target mask


  9%|▉         | 26/274 [00:06<00:59,  4.18it/s]

input mask
target mask


 10%|▉         | 27/274 [00:06<00:58,  4.19it/s]

input mask
target mask


 10%|█         | 28/274 [00:06<00:57,  4.28it/s]

input mask
target mask


 11%|█         | 29/274 [00:06<00:57,  4.23it/s]

input mask
target mask


 11%|█         | 30/274 [00:07<00:56,  4.28it/s]

input mask
target mask


 11%|█▏        | 31/274 [00:07<01:00,  4.05it/s]

input mask
target mask


 12%|█▏        | 32/274 [00:07<01:02,  3.87it/s]

input mask
target mask


 12%|█▏        | 33/274 [00:07<00:59,  4.04it/s]

input mask
target mask


 12%|█▏        | 34/274 [00:08<00:58,  4.14it/s]

input mask
target mask


 13%|█▎        | 35/274 [00:08<00:56,  4.22it/s]

input mask
target mask


 13%|█▎        | 36/274 [00:08<00:55,  4.32it/s]

input mask
target mask


 14%|█▎        | 37/274 [00:08<00:56,  4.22it/s]

input mask
target mask


 14%|█▍        | 38/274 [00:09<00:55,  4.28it/s]

input mask
target mask


 14%|█▍        | 39/274 [00:09<00:55,  4.23it/s]

input mask
target mask


 15%|█▍        | 40/274 [00:09<00:53,  4.34it/s]

input mask
target mask


 15%|█▍        | 41/274 [00:09<00:53,  4.33it/s]

input mask
target mask


 15%|█▌        | 42/274 [00:10<00:53,  4.37it/s]

input mask
target mask


 16%|█▌        | 43/274 [00:10<00:52,  4.38it/s]

input mask
target mask


 16%|█▌        | 44/274 [00:10<00:53,  4.32it/s]

input mask
target mask


 16%|█▋        | 45/274 [00:10<00:51,  4.41it/s]

input mask
target mask


 17%|█▋        | 46/274 [00:10<00:54,  4.21it/s]

input mask
target mask


 17%|█▋        | 47/274 [00:11<00:53,  4.23it/s]

input mask
target mask


 18%|█▊        | 49/274 [00:11<00:51,  4.34it/s]

input mask
target mask
input mask
target mask


 19%|█▊        | 51/274 [00:12<00:49,  4.51it/s]

input mask
target mask
input mask
target mask


 19%|█▉        | 52/274 [00:12<00:48,  4.56it/s]

input mask
target mask


 20%|█▉        | 54/274 [00:12<00:47,  4.59it/s]

input mask
target mask
input mask
target mask


 20%|██        | 55/274 [00:12<00:47,  4.58it/s]

input mask
target mask


 20%|██        | 56/274 [00:13<00:48,  4.47it/s]

input mask
target mask


 21%|██        | 57/274 [00:13<00:47,  4.56it/s]

input mask
target mask


 21%|██        | 58/274 [00:13<00:46,  4.64it/s]

input mask
target mask


 22%|██▏       | 59/274 [00:13<00:47,  4.52it/s]

input mask
target mask


 22%|██▏       | 60/274 [00:14<00:46,  4.56it/s]

input mask
target mask


 22%|██▏       | 61/274 [00:14<00:46,  4.59it/s]

input mask
target mask


 23%|██▎       | 62/274 [00:14<00:46,  4.60it/s]

input mask
target mask


 23%|██▎       | 63/274 [00:14<00:47,  4.45it/s]

input mask
target mask


 23%|██▎       | 64/274 [00:15<00:49,  4.22it/s]

input mask
target mask


 24%|██▎       | 65/274 [00:15<00:50,  4.18it/s]

input mask
target mask


 24%|██▍       | 66/274 [00:15<00:49,  4.19it/s]

input mask
target mask


 24%|██▍       | 67/274 [00:15<00:47,  4.32it/s]

input mask
target mask


 25%|██▌       | 69/274 [00:16<00:46,  4.43it/s]

input mask
target mask
input mask
target mask


 26%|██▌       | 70/274 [00:16<00:45,  4.45it/s]

input mask
target mask


 26%|██▌       | 71/274 [00:16<00:46,  4.36it/s]

input mask
target mask


 26%|██▋       | 72/274 [00:16<00:47,  4.22it/s]

input mask
target mask


 27%|██▋       | 73/274 [00:17<00:49,  4.10it/s]

input mask
target mask


 27%|██▋       | 74/274 [00:17<00:48,  4.09it/s]

input mask
target mask


 27%|██▋       | 75/274 [00:17<00:48,  4.14it/s]

input mask
target mask


 28%|██▊       | 76/274 [00:17<00:50,  3.93it/s]

input mask
target mask


 28%|██▊       | 77/274 [00:18<00:49,  3.99it/s]

input mask
target mask


 28%|██▊       | 78/274 [00:18<00:48,  4.06it/s]

input mask
target mask


 29%|██▉       | 79/274 [00:18<00:49,  3.95it/s]

input mask
target mask


 29%|██▉       | 80/274 [00:18<00:48,  3.99it/s]

input mask
target mask


 30%|██▉       | 81/274 [00:19<00:47,  4.07it/s]

input mask
target mask


 30%|██▉       | 82/274 [00:19<00:46,  4.10it/s]

input mask
target mask


 30%|███       | 83/274 [00:19<00:47,  4.06it/s]

input mask
target mask


 31%|███       | 84/274 [00:19<00:48,  3.90it/s]

input mask
target mask


 31%|███       | 85/274 [00:20<00:46,  4.05it/s]

input mask
target mask


 31%|███▏      | 86/274 [00:20<00:47,  3.98it/s]

input mask
target mask


 32%|███▏      | 87/274 [00:20<00:46,  4.05it/s]

input mask
target mask


 32%|███▏      | 88/274 [00:20<00:46,  3.99it/s]

input mask
target mask


 32%|███▏      | 89/274 [00:21<00:45,  4.06it/s]

input mask
target mask


 33%|███▎      | 90/274 [00:21<00:45,  4.03it/s]

input mask
target mask


 33%|███▎      | 91/274 [00:21<00:44,  4.10it/s]

input mask
target mask


 34%|███▎      | 92/274 [00:21<00:45,  4.00it/s]

input mask
target mask


 34%|███▍      | 93/274 [00:22<00:44,  4.09it/s]

input mask
target mask


 34%|███▍      | 94/274 [00:22<00:43,  4.16it/s]

input mask
target mask


 35%|███▍      | 95/274 [00:22<00:43,  4.07it/s]

input mask
target mask


 35%|███▌      | 96/274 [00:22<00:44,  4.03it/s]

input mask
target mask


 35%|███▌      | 97/274 [00:23<00:42,  4.12it/s]

input mask
target mask


 36%|███▌      | 98/274 [00:23<00:43,  4.01it/s]

input mask
target mask


 36%|███▌      | 99/274 [00:23<00:44,  3.92it/s]

input mask
target mask


 36%|███▋      | 100/274 [00:23<00:43,  4.03it/s]

input mask
target mask


 37%|███▋      | 101/274 [00:24<00:43,  3.99it/s]

input mask
target mask


 37%|███▋      | 102/274 [00:24<00:43,  3.96it/s]

input mask
target mask


 38%|███▊      | 103/274 [00:24<00:42,  4.00it/s]

input mask
target mask


 38%|███▊      | 104/274 [00:24<00:42,  4.02it/s]

input mask
target mask


 38%|███▊      | 105/274 [00:25<00:42,  4.00it/s]

input mask
target mask


 39%|███▊      | 106/274 [00:25<00:44,  3.76it/s]

input mask
target mask


 39%|███▉      | 107/274 [00:25<00:43,  3.83it/s]

input mask
target mask


 39%|███▉      | 108/274 [00:25<00:44,  3.76it/s]

input mask
target mask


 40%|███▉      | 109/274 [00:26<00:45,  3.66it/s]

input mask
target mask


 40%|████      | 110/274 [00:26<00:43,  3.75it/s]

input mask
target mask


 41%|████      | 111/274 [00:26<00:43,  3.77it/s]

input mask
target mask


 41%|████      | 112/274 [00:26<00:43,  3.72it/s]

input mask
target mask


 41%|████      | 113/274 [00:27<00:41,  3.86it/s]

input mask
target mask


 42%|████▏     | 114/274 [00:27<00:40,  3.94it/s]

input mask
target mask


 42%|████▏     | 115/274 [00:27<00:40,  3.92it/s]

input mask
target mask


 42%|████▏     | 116/274 [00:27<00:40,  3.92it/s]

input mask
target mask


 43%|████▎     | 117/274 [00:28<00:39,  3.98it/s]

input mask
target mask


 43%|████▎     | 118/274 [00:28<00:40,  3.84it/s]

input mask
target mask


 43%|████▎     | 119/274 [00:28<00:39,  3.88it/s]

input mask
target mask


 44%|████▍     | 120/274 [00:29<00:39,  3.94it/s]

input mask
target mask


 44%|████▍     | 121/274 [00:29<00:39,  3.91it/s]

input mask
target mask


 45%|████▍     | 122/274 [00:29<00:38,  3.93it/s]

input mask
target mask


 45%|████▍     | 123/274 [00:29<00:38,  3.93it/s]

input mask
target mask


 45%|████▌     | 124/274 [00:30<00:37,  3.98it/s]

input mask
target mask


 46%|████▌     | 125/274 [00:30<00:36,  4.08it/s]

input mask
target mask


 46%|████▌     | 126/274 [00:30<00:37,  3.93it/s]

input mask
target mask


 46%|████▋     | 127/274 [00:30<00:36,  4.02it/s]

input mask
target mask


 47%|████▋     | 128/274 [00:30<00:35,  4.11it/s]

input mask
target mask


 47%|████▋     | 129/274 [00:31<00:35,  4.11it/s]

input mask
target mask


 47%|████▋     | 130/274 [00:31<00:35,  4.07it/s]

input mask
target mask


 48%|████▊     | 131/274 [00:31<00:34,  4.13it/s]

input mask
target mask


 48%|████▊     | 132/274 [00:31<00:35,  3.95it/s]

input mask
target mask


 49%|████▊     | 133/274 [00:32<00:37,  3.77it/s]

input mask
target mask


 49%|████▉     | 134/274 [00:32<00:36,  3.83it/s]

input mask
target mask


 49%|████▉     | 135/274 [00:32<00:36,  3.78it/s]

input mask
target mask


 50%|████▉     | 136/274 [00:33<00:35,  3.87it/s]

input mask
target mask


 50%|█████     | 137/274 [00:33<00:35,  3.90it/s]

input mask
target mask


 50%|█████     | 138/274 [00:33<00:34,  3.92it/s]

input mask
target mask


 51%|█████     | 139/274 [00:33<00:35,  3.85it/s]

input mask
target mask


 51%|█████     | 140/274 [00:34<00:35,  3.76it/s]

input mask
target mask


 51%|█████▏    | 141/274 [00:34<00:34,  3.82it/s]

input mask
target mask


 52%|█████▏    | 142/274 [00:34<00:34,  3.88it/s]

input mask
target mask


 52%|█████▏    | 143/274 [00:34<00:33,  3.90it/s]

input mask
target mask


 53%|█████▎    | 144/274 [00:35<00:32,  3.95it/s]

input mask
target mask


 53%|█████▎    | 145/274 [00:35<00:32,  4.00it/s]

input mask
target mask


 53%|█████▎    | 146/274 [00:35<00:33,  3.88it/s]

input mask
target mask


 54%|█████▎    | 147/274 [00:35<00:33,  3.80it/s]

input mask
target mask


 54%|█████▍    | 148/274 [00:36<00:32,  3.88it/s]

input mask
target mask


 54%|█████▍    | 149/274 [00:36<00:31,  4.03it/s]

input mask
target mask


 55%|█████▍    | 150/274 [00:36<00:30,  4.03it/s]

input mask
target mask


 55%|█████▌    | 151/274 [00:36<00:29,  4.15it/s]

input mask
target mask


 55%|█████▌    | 152/274 [00:37<00:29,  4.13it/s]

input mask
target mask


 56%|█████▌    | 153/274 [00:37<00:29,  4.10it/s]

input mask
target mask


 56%|█████▌    | 154/274 [00:37<00:28,  4.22it/s]

input mask
target mask


 57%|█████▋    | 155/274 [00:37<00:28,  4.12it/s]

input mask
target mask


 57%|█████▋    | 156/274 [00:38<00:28,  4.14it/s]

input mask
target mask


 57%|█████▋    | 157/274 [00:38<00:28,  4.14it/s]

input mask
target mask


 58%|█████▊    | 158/274 [00:38<00:28,  4.01it/s]

input mask
target mask


 58%|█████▊    | 159/274 [00:38<00:28,  3.99it/s]

input mask
target mask


 58%|█████▊    | 160/274 [00:39<00:29,  3.92it/s]

input mask
target mask


 59%|█████▉    | 161/274 [00:39<00:28,  3.98it/s]

input mask
target mask


 59%|█████▉    | 162/274 [00:39<00:27,  4.03it/s]

input mask
target mask


 59%|█████▉    | 163/274 [00:39<00:28,  3.90it/s]

input mask
target mask


 60%|█████▉    | 164/274 [00:40<00:30,  3.63it/s]

input mask
target mask


 60%|██████    | 165/274 [00:40<00:29,  3.68it/s]

input mask
target mask


 61%|██████    | 166/274 [00:40<00:28,  3.79it/s]

input mask
target mask


 61%|██████    | 167/274 [00:40<00:28,  3.71it/s]

input mask
target mask


 61%|██████▏   | 168/274 [00:41<00:27,  3.83it/s]

input mask
target mask


 62%|██████▏   | 169/274 [00:41<00:25,  4.07it/s]


input mask
target mask


KeyboardInterrupt: 