# Train model

In [None]:
import math
import numpy as np
import os
from pathlib import Path

from lapsim.encoder.partition import Partition

from lapsim.normalisation import TransformNormalisation

BATCH_SIZE = 1024
EPOCHS = 50
FORESIGHT = 120
SAMPLING = 4
NORMALISATION_BOUNDS_PATH = "bounds.json"

DATA_PATH = Path(r"../dataset/encoded/train")

# This code currently assumes a single validation partition, but it should be
# easy to support more.
VAL_DATA_PATH = Path(r"../dataset/encoded/val")


def _list_partitions(directory):
    """Return the non-hidden entries of *directory*, sorted by name.

    ``os.listdir`` returns entries in arbitrary order; sorting keeps the
    partition order (and therefore the training order) reproducible across
    machines and filesystems.
    """
    return sorted(name for name in os.listdir(directory) if not name.startswith('.'))


TRAINING_PARTITIONS = _list_partitions(DATA_PATH)
VALIDATION_PARTITIONS = _list_partitions(VAL_DATA_PATH)


In [None]:
# Load cached normalisation bounds if present; otherwise compute them from
# every training partition and cache them for the next run.
if os.path.exists(NORMALISATION_BOUNDS_PATH):
    bounds = TransformNormalisation.load(NORMALISATION_BOUNDS_PATH)
    print("Existing bounds loaded.")

else:
    print("No bounds file found, calculating new ones.")
    bounds = TransformNormalisation()

    # Accumulate the bounds partition by partition.
    for p in TRAINING_PARTITIONS:
        partition = Partition.load(DATA_PATH / p)
        bounds.extend(partition)

    bounds.save(NORMALISATION_BOUNDS_PATH)
    print("Finished calculating bounds.")

# Apply this notebook's windowing parameters to the transform, whether the
# bounds were loaded from disk or freshly computed.
bounds.transform.foresight = FORESIGHT
bounds.transform.sampling = SAMPLING


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import HuberLoss
from torch.optim import NAdam


def hard_sigmoid(x):
    """Piecewise-linear approximation of the sigmoid.

    Maps ``x`` linearly from [-2.5, 2.5] onto [0, 1] and saturates to 0 / 1
    outside that interval.
    """
    scaled = torch.div(x + 2.5, 5)
    return torch.clamp(scaled, 0, 1)


class LapSimModel(nn.Module):
    """Feed-forward network producing position and velocity predictions.

    The input is the concatenation of the vehicle features and the window
    features (vehicles first). Three sigmoid-activated hidden layers feed two
    separate linear heads, whose outputs are squashed into [0, 1] by
    ``hard_sigmoid``.

    The model also owns its loss (``HuberLoss``) and optimiser (``NAdam``) so
    the training loop can use ``model.loss`` / ``model.optimiser`` directly.

    Args:
        n_inputs: Width of the concatenated ``(vehicles, windows)`` input.
            Defaults to 739, the width the original architecture hard-coded.
            Presumably this depends on FORESIGHT/SAMPLING; adjust it if those
            change (TODO confirm).
        n_outputs: Width of each output head (default 9).
    """

    def __init__(self, n_inputs=739, n_outputs=9):
        super().__init__()

        # Layer attribute names d1..d5 are kept stable so previously saved
        # state_dicts (e.g. "ls1.pt") still load.
        self.d1 = nn.Linear(n_inputs, 450)
        self.d2 = nn.Linear(450, 200)
        self.d3 = nn.Linear(200, 200)
        self.d4 = nn.Linear(200, n_outputs)  # position head
        self.d5 = nn.Linear(200, n_outputs)  # velocity head

        self.loss = HuberLoss()
        # Created after every layer is registered so self.parameters()
        # includes all weights.
        self.optimiser = NAdam(self.parameters())

    def forward(self, windows, vehicles):
        """Return ``(pos, vel)`` predictions, each clamped to [0, 1].

        ``vehicles`` and ``windows`` are concatenated along dim 1, so both are
        expected to be 2-D ``(batch, features)`` tensors (TODO confirm).
        """
        x = torch.cat((vehicles, windows), dim=1)

        x = torch.sigmoid(self.d1(x))
        x = torch.sigmoid(self.d2(x))
        x = torch.sigmoid(self.d3(x))

        pos = hard_sigmoid(self.d4(x))
        vel = hard_sigmoid(self.d5(x))

        return pos, vel

def get_device():
    """Pick the preferred available torch device: CUDA, then MPS, then CPU."""
    backend = "cpu"
    if torch.cuda.is_available():
        backend = "cuda"
    elif torch.backends.mps.is_available():
        backend = "mps"
    return torch.device(backend)


def tensor(x):
    """Convert array-like ``x`` to a float32 tensor on the global ``device``.

    ``device`` is looked up at call time, so it only has to be defined before
    the first call.
    """
    as_float = torch.tensor(x, dtype=torch.float32)
    return as_float.to(device)


# Choose the compute backend once; the model and every tensor() call use it.
device = get_device()

model = LapSimModel()
model.to(device)  # nn.Module.to moves parameters in place and returns self

In [None]:
from toolkit.utils import Logger


def batchify(count, n):
    """Shuffle the indices ``0..count-1`` and split them into batches of ``n``.

    Every batch holds exactly ``n`` indices except possibly the last, which
    holds the remainder; empty batches are never returned. Shuffling uses
    NumPy's global random state.
    """
    order = np.arange(count)
    np.random.shuffle(order)

    n_batches = -(-count // n)  # ceiling division
    return [order[k * n:(k + 1) * n] for k in range(n_batches)]


best_loss = math.inf  # NOTE(review): not read anywhere in this cell
best_model_perf = None  # lowest mean validation loss seen so far

logger = Logger(labels=['Position Loss', 'Velocity Loss'], n_partitions=len(TRAINING_PARTITIONS))

# Load the first training partition (and the single validation partition)
# synchronously; later training partitions are preloaded in the background.
x, (y_pos, y_vel), vehicles = bounds.normalise_and_transform(Partition.load(DATA_PATH / TRAINING_PARTITIONS[0]), cores=4)
val_x, (val_y_pos, val_y_vel), val_vehicles = bounds.normalise_and_transform(Partition.load(VAL_DATA_PATH / VALIDATION_PARTITIONS[0]), cores=4)

# While training on TRAINING_PARTITIONS[i] we preload TRAINING_PARTITIONS[i + 1],
# wrapping around to the first partition, which the following epoch starts with.
next_partitions = TRAINING_PARTITIONS[1:] + TRAINING_PARTITIONS[:1]


for epoch in range(EPOCHS):
    # Train model
    model.train()
    for next_partition in next_partitions:
        # x/y hold the partition currently being trained on. Start loading the
        # *next* one in the background while the model trains.
        loader = bounds.async_load_and_normalise_partition(
            DATA_PATH / next_partition,
            cores=4)

        batches = batchify(x.shape[0], BATCH_SIZE)
        for batch_idx, batch in enumerate(batches):
            model.optimiser.zero_grad()

            pred_pos, pred_vel = model(tensor(x[batch]), tensor(vehicles[batch]))

            pos_loss = model.loss(pred_pos, tensor(y_pos[batch]))
            vel_loss = model.loss(pred_vel, tensor(y_vel[batch]))
            total_loss = (pos_loss + vel_loss) / 2
            total_loss.backward()
            model.optimiser.step()

            logger.write(epoch, batch=batch_idx, n_batches=len(batches), losses=[pos_loss.item(), vel_loss.item()])

        # Wait for the preload to finish, then swap it in as the data for the
        # next iteration.
        loader.join()
        x, (y_pos, y_vel), vehicles = loader.normalisation

    # Validation
    model.eval()
    val_loss_total = 0.0
    val_samples = 0
    with torch.no_grad():
        batches = batchify(val_x.shape[0], BATCH_SIZE)
        for batch_idx, batch in enumerate(batches):
            val_pred_pos, val_pred_vel = model(tensor(val_x[batch]), tensor(val_vehicles[batch]))
            pos_loss = model.loss(val_pred_pos, tensor(val_y_pos[batch]))
            vel_loss = model.loss(val_pred_vel, tensor(val_y_vel[batch]))
            # Same objective as training ((pos + vel) / 2), weighted by batch
            # size so the shorter final batch does not skew the epoch mean.
            val_loss_total += ((pos_loss + vel_loss) / 2).item() * len(batch)
            val_samples += len(batch)

            logger.write_val(epoch, batch=batch_idx, n_batches=len(batches), losses=[pos_loss.item(), vel_loss.item()])

    val_loss = val_loss_total / val_samples if val_samples else math.nan

    # Checkpoint whenever validation improves on the best epoch so far.
    if best_model_perf is None or val_loss < best_model_perf:
        torch.save(model.state_dict(), "ls1.pt")
        best_model_perf = val_loss

    logger.flush(epoch)
