# Training Pipeline

# Imports
## Packages (third-party and local `src` modules)

In [1]:
import os
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

if not os.path.exists('data'):
    # The `data` directory is not visible from the current working directory; presumably the
    # notebook is being run from two levels below the repository root, so move up there before
    # importing the local `src` package. os.path.join keeps this portable: the previous
    # "..\\..\\" literal only resolved on Windows.
    new_directory_path = os.path.join("..", "..")
    os.chdir(new_directory_path)

from src.datasets import FreeViewInMemory, seq2seq_jagged_collate_fn, seq2seq_padded_collate_fn
from src.model import PathModel
import numpy as np

# Training

In [2]:
# TODO Implement Strategy and Builder pattern
datasetv2 = FreeViewInMemory(sample_size= 13,log = True, start_index=2)
model = PathModel(input_dim = 3,
                  output_dim = 3,
                  n_encoder = 2,
                  n_decoder = 2,
                  model_dim = 256,
                  total_dim = 256,
                  n_heads = 4,
                  ff_dim = 512,
                  max_pos_enc=15,
                  max_pos_dec=26)

dataloader = DataLoader(datasetv2, batch_size=128, shuffle=True, num_workers=0, collate_fn= seq2seq_padded_collate_fn)

Data loaded in memory


In [3]:
def compute_loss(reg_out, cls_out, y, attn_mask, fixation_len):
    """Compute the end-token classification loss and the fixation regression loss.

    Position layout used below: for sample ``b`` the output at index ``fixation_len[b]`` is the one
    expected to predict the end token (the start token shifts everything by one), so it gets a
    classification target of 1 and no regression target. Assuming ``attn_mask`` marks positions
    ``0 .. fixation_len[b]`` as valid, outputs ``0 .. fixation_len[b]-1`` are regressed against
    ``y[b]``, which is selected with ``attn_mask[:, 1:]``.

    Args:
        reg_out: regression-head output; assumed (batch, seq, features) — TODO confirm.
        cls_out: end-token logits, (batch, seq, 1) as printed elsewhere in this notebook.
        y: target fixations, indexed with ``attn_mask[:, 1:]`` so it must have one fewer position
            than ``attn_mask``. NOTE(review): the printed shapes (cls_out with 14 positions, y with
            12) suggest this may not hold for the current collate function — confirm.
        attn_mask: boolean (batch, seq) mask of valid decoder positions. Not modified.
        fixation_len: integer tensor (batch,), number of fixations per sample.

    Returns:
        ``(cls_loss, reg_loss)`` — classification first, regression second.
    """
    device = cls_out.device
    fixation_len = torch.as_tensor(fixation_len, device=device)
    batch_idx = torch.arange(cls_out.size(0), device=device)

    # The end token has no regression target: drop it from the regression mask (on a copy, so the
    # caller's mask is untouched).
    attn_mask_reg = attn_mask.clone()
    attn_mask_reg[batch_idx, fixation_len] = False

    # >>>>>> Classification loss
    # Balance the classes: every sample has fixation_len "not yet finished" positions but only one
    # end position, so the non-end positions share a total weight of 1 (1/fixation_len each) while
    # the end position keeps weight 1. Padded positions also get 1 but are masked out below.
    inv_len = (1.0 / fixation_len.to(torch.float32)).unsqueeze(1)
    weights = torch.where(attn_mask_reg, inv_len, torch.ones_like(inv_len))
    weights = weights.reshape(cls_out.shape)

    # The end token is the only positive; the start token shifts it to index fixation_len.
    cls_targets = torch.zeros_like(cls_out)
    cls_targets[batch_idx, fixation_len] = 1.0
    cls_loss = torch.nn.functional.binary_cross_entropy_with_logits(
        cls_out[attn_mask], cls_targets[attn_mask], weight=weights[attn_mask]
    )

    # >>>>>> Regression loss
    # Broadcast the masks over the feature dimension; y is aligned with positions 1.. of the mask.
    criterion_reg = torch.nn.MSELoss()
    reg_mask = attn_mask_reg.unsqueeze(-1).expand_as(reg_out)
    y_mask = attn_mask[:, 1:].unsqueeze(-1).expand_as(y)
    reg_loss = criterion_reg(reg_out[reg_mask], y[y_mask])
    return cls_loss, reg_loss

In [None]:
import torch.optim as optim

# Train by minimising a weighted sum of the fixation regression loss and the end-token
# classification loss from `compute_loss`. The loop deliberately leaves the last batch's tensors
# (x, y, y_mask, fixation_len, reg_out, cls_out) in scope for the inspection cells below.
cls_weigth = .5  # weight of the classification term; the regression term gets 1 - cls_weigth
num_epochs = 1  # You can change this to train for more epochs
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Batches are moved to `device` below, so the parameters must live there too. Move the model
# before building the optimizer, as recommended by the torch.optim docs.
model.to(device)

# Define the optimizer
optimizer = optim.Adam(model.parameters())
reg_loss_list = []
cls_loss_list = []
for epoch in range(num_epochs):
    model.train()  # Set the model to training mode
    total_reg_loss = 0
    total_cls_loss = 0
    num_batches = 0

    for batch in tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        x, x_mask, y, y_mask, fixation_len = batch
        x = x.to(device=device)
        y = y.to(device=device)
        if x_mask is not None:
            x_mask = x_mask.to(device=device)
        if y_mask is not None:
            y_mask = y_mask.to(device=device)
        fixation_len = fixation_len.to(device=device)

        optimizer.zero_grad()  # Zero the gradients

        reg_out, cls_out = model(x, y, src_mask=x_mask, tgt_mask=y_mask)  # Forward pass
        # compute_loss returns (classification, regression): unpack in that order, otherwise the
        # two running averages below are reported under each other's names.
        cls_loss, reg_loss = compute_loss(reg_out, cls_out, y, y_mask, fixation_len)
        total_loss = (1 - cls_weigth) * reg_loss + cls_weigth * cls_loss
        total_loss.backward()
        optimizer.step()

        total_reg_loss += reg_loss.item()
        total_cls_loss += cls_loss.item()
        num_batches += 1

    if num_batches == 0:
        raise RuntimeError("The dataloader yielded no batches; nothing was trained.")
    avg_reg_loss = total_reg_loss / num_batches
    avg_cls_loss = total_cls_loss / num_batches
    reg_loss_list.append(avg_reg_loss)
    cls_loss_list.append(avg_cls_loss)
    print(f"Epoch {epoch+1}/{num_epochs}, Avg Regression Loss: {avg_reg_loss:.4f}, Avg Classification Loss: {avg_cls_loss:.4f}")

print("Training finished!")

In [12]:
# Sanity-check the last training batch left in scope by the training loop: the longest scanpath
# in the batch, and the shapes of the classification logits and the regression targets.
print(fixation_len.max())
print(cls_out.size())
print(y.size())

tensor(12)
torch.Size([128, 14, 1])
torch.Size([128, 12, 3])


In [23]:
def create_cls_targets(cls_out, fixation_len):
    """Build end-token classification targets shaped like ``cls_out``.

    Every entry is 0.0 except ``[b, fixation_len[b]]`` for each sample ``b``, which is 1.0 (the
    end-token position). Tensors are allocated on ``cls_out``'s device so the result can be
    compared directly with model outputs produced on GPU.

    Args:
        cls_out: classification logits; only its shape and device are used.
        fixation_len: per-sample end-token index (tensor or sequence of ints), one per batch row.

    Returns:
        float32 tensor with the same shape and device as ``cls_out``.
    """
    device = cls_out.device
    batch_idx = torch.arange(cls_out.size(0), device=device)
    end_idx = torch.as_tensor(fixation_len, device=device)
    cls_targets = torch.zeros(cls_out.size(), dtype=torch.float32, device=device)
    cls_targets[batch_idx, end_idx] = 1.0
    return cls_targets

def accuracy(cls_out, attn_mask, cls_targets):
    """Fraction of valid positions where the thresholded prediction matches the target.

    A position is predicted positive when ``sigmoid(cls_out) >= 0.5``. Only positions where
    ``attn_mask`` is True are counted; the mask gets a trailing axis so it lines up with a
    ``(..., 1)``-shaped ``cls_out``.

    Returns:
        Accuracy as a Python float, or 0.0 when the mask selects no positions (matching
        ``precision``/``recall``, instead of raising ZeroDivisionError).
    """
    cls_preds = torch.sigmoid(cls_out) >= 0.5
    valid = attn_mask.unsqueeze(-1)
    n_valid = valid.sum().item()
    if n_valid == 0:
        return 0.0
    n_correct = ((cls_preds == cls_targets) & valid).sum().item()
    return n_correct / n_valid

def precision(cls_out, attn_mask, cls_targets, cls = 1):
    """Precision for class ``cls`` over the positions selected by ``attn_mask``.

    Predictions are ``sigmoid(cls_out) >= 0.5``. Precision is the number of valid positions
    predicted as ``cls`` whose target is also ``cls``, divided by the number of valid positions
    predicted as ``cls``. Returns 0.0 when nothing valid was predicted as ``cls``.
    """
    valid = attn_mask[..., None]
    predicted_as_cls = (torch.sigmoid(cls_out) >= 0.5) == cls
    flagged = predicted_as_cls & valid
    hits = (flagged & (cls_targets == cls)).sum().item()
    n_flagged = flagged.sum().item()
    if n_flagged > 0:
        return hits / n_flagged
    return 0.0

def recall(cls_out, attn_mask, cls_targets, cls = 1):
    """Recall for class ``cls`` over the positions selected by ``attn_mask``.

    Predictions are ``sigmoid(cls_out) >= 0.5``. Recall is the number of valid positions whose
    target is ``cls`` and that were also predicted as ``cls``, divided by the number of valid
    positions whose target is ``cls``. Returns 0.0 when no valid position has target ``cls``.
    """
    valid = attn_mask[..., None]
    relevant = (cls_targets == cls) & valid
    hits = (relevant & ((torch.sigmoid(cls_out) >= 0.5) == cls)).sum().item()
    n_relevant = relevant.sum().item()
    return hits / n_relevant if n_relevant else 0.0
# End-token classification metrics on the LAST batch left in scope by the training loop.
# This is a training batch, so these are training-set numbers, not a held-out evaluation.
# Class 1 is the end token; class 0 is every other valid position.
cls_targets = create_cls_targets(cls_out, fixation_len)
print(f"accuracy:             {accuracy(cls_out, y_mask, cls_targets):.4f}")
print(f"precision (cls=1):    {precision(cls_out, y_mask, cls_targets):.4f}")
print(f"recall    (cls=1):    {recall(cls_out, y_mask, cls_targets):.4f}")
print(f"precision (cls=0):    {precision(cls_out, y_mask, cls_targets, cls = 0):.4f}")
print(f"recall    (cls=0):    {recall(cls_out, y_mask, cls_targets, cls = 0):.4f}")

0.8766859344894027
0.0
0.0
0.8766859344894027
1.0
