# Documentation
Example code for training a transformer on the Road Traffic Fine Management dataset.

# Install and import packages

In [None]:
import tqdm
import numpy as np
import random
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
import csv
import os

In [None]:
from create_model import Transformer
from train_evaluate import train, validate, EarlyStopper, init_weights_kaiming

In [None]:
# Report the accelerators and CPU cores visible to this session.
cuda_available = torch.cuda.is_available()
print("GPU available:", cuda_available)
if cuda_available:
    # Index 0 is the first visible CUDA device.
    print("GPU name:", torch.cuda.get_device_name(0))
    print("Number of GPUs:", torch.cuda.device_count())
print("Number of CPUs", os.cpu_count())

In [None]:
# Seed every random number generator used in this notebook with one value.
seed = 7
for seed_rng in (
    random.seed,                 # Python's built-in RNG
    np.random.seed,              # NumPy's global RNG
    torch.manual_seed,           # reproducibility on the CPU
    torch.cuda.manual_seed_all,  # reproducibility on all GPUs
):
    seed_rng(seed)

# Define parameters

In [None]:
# Locations of the pre-processed tensors. The leading '...' is a placeholder
# for the directory that holds the files and must be replaced before running.
train_trace_act_tensor_path = '.../train_trace_act.pt'
train_trace_time_tensor_path = '.../train_trace_time.pt'

val_trace_act_tensor_path = '.../val_trace_act.pt'
# The directory separator after the placeholder was missing on this path,
# unlike the three paths above.
val_trace_time_tensor_path = '.../val_trace_time.pt'

In [None]:
# define prefix length
# All three values are passed positionally to Transformer(...) further down;
# num_act is additionally passed to validate(...) inside trial().
prefix_len = 6
# NOTE(review): presumably the number of distinct activity indices, including
# special tokens such as padding (0) and EOC (3) — confirm against the data.
num_act = 13
# NOTE(review): presumably the number of time features per event in the
# *_trace_time tensors — confirm.
num_time_features = 2

In [None]:
# define model design hyperparameters
# All of these are forwarded positionally to Transformer(...) further down.
# NOTE(review): the meanings noted below are inferred from the names only —
# confirm against create_model.Transformer.
d_embed = 4         # presumably the embedding width
d_model = 16        # presumably the model (hidden) width
num_heads = 4       # presumably the number of attention heads
d_ff = d_model * 2  # tied to d_model, so it evaluates to 32 here
num_layers = 2
dropout = 0.1

In [None]:
# define model training hyperparameters
# mini-batch size used by both the training and the validation DataLoader
batch_size = 64

# learning rate handed to trial(), which passes it to AdamW
lr = 0.0003

# upper bound on the number of epochs; trial() may stop earlier
num_epochs = 200

# forwarded unchanged to train() and validate(); the accepted values are
# defined in train_evaluate (not visible here)
loss_mode = 'base'

# use the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load tensors

## Train dataloader

- `train_trace_act_tensor` has the shape `(num_samples, prefix_len + 1)`.  
- It contains the full activity label trace (from the first event to the EOC token), represented by indices and right-padded with zeros.
- Input and target sequences are derived from this tensor as follows:
    - Input sequence: obtained by removing the last element of the trace and replacing any EOC token (index 3) with 0 (padding).
    - Target sequence: obtained by removing the first element of the trace.
    - For example, given an illustrative trace of length 8, [4, 5, 7, 8, 3, 0, 0, 0], the input sequence becomes [4, 5, 7, 8, 0, 0, 0], and the target sequence becomes [5, 7, 8, 3, 0, 0, 0].

In [None]:
# Load the full activity traces of the training split.
train_trace_act_tensor = torch.load(train_trace_act_tensor_path)

# Input sequence: every position except the last, with the EOC token
# (index 3) replaced by padding (0). masked_fill returns a new tensor, so the
# loaded trace tensor itself is not modified.
train_prefix_act = train_trace_act_tensor[:, :-1]
train_prefix_act = train_prefix_act.masked_fill(train_prefix_act == 3, 0)

# Target sequence: every position except the first, as an independent copy.
train_tgt = train_trace_act_tensor[:, 1:].clone()

In [None]:
# Load the time tensor that accompanies the training activity traces.
train_trace_time_tensor = torch.load(train_trace_time_tensor_path)

# prepare input sequence
# Drop the last position along dim 1, mirroring the slice used for
# train_prefix_act. This is a view of the loaded tensor (no clone is taken).
# NOTE(review): assumes dim 1 is the sequence axis here as well — confirm.
train_prefix_time = train_trace_time_tensor[:, :-1]

In [None]:
# Each training sample is the triple
# (activity prefix, time features, target sequence).
train_dataset = TensorDataset(train_prefix_act, train_prefix_time, train_tgt)

# Shuffling is enabled for the training batches.
train_dataloader = DataLoader(
    train_dataset,
    shuffle=True,
    batch_size=batch_size,
)

## Validation dataloader

In [None]:
# Load the full activity traces of the validation split.
val_trace_act_tensor = torch.load(val_trace_act_tensor_path)

# Input sequence: every position except the last, with the EOC token
# (index 3) replaced by padding (0). masked_fill returns a new tensor, so the
# loaded trace tensor itself is not modified.
val_prefix_act = val_trace_act_tensor[:, :-1]
val_prefix_act = val_prefix_act.masked_fill(val_prefix_act == 3, 0)

# Target sequence: every position except the first, as an independent copy.
val_tgt = val_trace_act_tensor[:, 1:].clone()

In [None]:
# Load the time tensor that accompanies the validation activity traces.
val_trace_time_tensor = torch.load(val_trace_time_tensor_path)

# prepare input sequence
# Drop the last position along dim 1, mirroring the slice used for
# val_prefix_act. This is a view of the loaded tensor (no clone is taken).
# NOTE(review): assumes dim 1 is the sequence axis here as well — confirm.
val_prefix_time = val_trace_time_tensor[:, :-1]

In [None]:
# Each validation sample is the triple
# (activity prefix, time features, target sequence).
val_dataset = TensorDataset(val_prefix_act, val_prefix_time, val_tgt)

# Validation batches keep the dataset order (no shuffling).
val_dataloader = DataLoader(
    val_dataset,
    shuffle=False,
    batch_size=batch_size,
)

# Define trial function

In [None]:
def trial(model,
          lr,
          model_state_path,
          best_val_loss=float("inf"),
          early_stopper=None,
          *,
          min_delta=0.001,
          patience=10):
    """Train ``model`` for up to ``num_epochs`` epochs, validating after each.

    The function reads these module-level names, which must be defined before
    it is called: ``num_epochs``, ``train_dataloader``, ``val_dataloader``,
    ``device``, ``num_act`` and ``loss_mode``.

    Args:
        model: Model to optimise; its parameters are handed to AdamW.
        lr: Learning rate for the AdamW optimizer.
        model_state_path: Path the model's ``state_dict`` is saved to each
            time the validation loss improves.
        best_val_loss: Validation loss to beat before the first checkpoint is
            written. Defaults to infinity, so the first epoch always saves.
        early_stopper: Object whose ``early_stop(val_loss)`` method returns a
            truthy value when training should stop. When ``None``, an
            ``EarlyStopper(patience=patience, delta=min_delta)`` is created.
        min_delta: Minimum decrease in validation loss that counts as an
            improvement for checkpointing. Also used as ``delta`` of the
            default early stopper. Has no effect on a caller-supplied
            ``early_stopper``.
        patience: ``patience`` of the default early stopper. Ignored when
            ``early_stopper`` is supplied.

    Returns:
        A list with one dict per completed epoch, holding the keys ``epoch``
        (1-based), ``train_loss``, ``val_loss``, ``val_accuracy``,
        ``val_precision``, ``val_recall`` and ``val_f1_score``.
    """
    if early_stopper is None:
        early_stopper = EarlyStopper(patience=patience, delta=min_delta)

    optimizer = optim.AdamW(model.parameters(), lr=lr)
    results = []

    for epoch in tqdm.tqdm(range(num_epochs)):

        train_loss = train(model,
                           train_dataloader,
                           optimizer,
                           device,
                           loss_mode)

        val_loss, accuracy, precision_macro, recall_macro, f1_macro = validate(
            model,
            val_dataloader,
            device,
            num_act,
            loss_mode)

        print(f"\tTrain Loss: {train_loss:7.3f} | Val Loss: {val_loss:7.3f} | Val Accuracy: {accuracy:7.3f}")
        print(f"\tVal Precision: {precision_macro:7.3f}| Val Recall: {recall_macro:7.3f}| Val macro F1: {f1_macro:7.3f}")

        # Checkpoint only on a meaningful improvement, so the saved weights
        # are not rewritten for differences smaller than min_delta.
        if val_loss < (best_val_loss - min_delta):
            best_val_loss = val_loss
            torch.save(model.state_dict(), model_state_path)

        # Store metrics in the results list as a dictionary
        results.append({
            'epoch': epoch + 1,
            'train_loss': train_loss,
            'val_loss': val_loss,
            'val_accuracy': accuracy,
            'val_precision': precision_macro,
            'val_recall': recall_macro,
            'val_f1_score': f1_macro
            })

        # early stopping
        if early_stopper.early_stop(val_loss):
            print(f"Early stopping triggered at epoch {epoch + 1}")
            break

    return results

# Trial

In [None]:
# instantiate transformer
# All arguments are positional, so their order has to match the signature of
# Transformer in create_model.
model = Transformer(
    prefix_len,
    num_act,
    num_time_features,
    d_embed,
    d_model,
    num_heads,
    d_ff,
    dropout,
    num_layers,
)
model = model.to(device)

# apply weight initialization
model.apply(init_weights_kaiming)

In [None]:
# Run one training trial. trial() saves the model's state_dict to the given
# path whenever the validation loss improves, and returns the per-epoch
# metrics as a list of dicts.
results = trial(model, lr,
            ".../experiment1_1_parameters.pt")

In [None]:
# After the loop, save the results to a CSV file
csv_file = '.../experiment1_1_loss.csv'
csv_columns = [
    'epoch',
    'train_loss',
    'val_loss',
    'val_accuracy',
    'val_precision',
    'val_recall',
    'val_f1_score',
]

try:
    # newline='' lets the csv module control the line endings itself.
    with open(csv_file, mode='w', newline='') as file:
        writer = csv.DictWriter(file, fieldnames=csv_columns)
        writer.writeheader()
        for row in results:
            writer.writerow(row)
    print(f"Metrics saved to {csv_file}")
except OSError as e:  # IOError is an alias of OSError in Python 3
    print("I/O error", e)