In [2]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Use seaborn's dark-grid theme for every figure drawn in this notebook.
sns.set_theme(style="darkgrid")

# Transformer with Intent

In [None]:
from parksim.trajectory_predict.intent_transformer.network import TrajectoryPredictorWithIntent
from parksim.trajectory_predict.intent_transformer.dataset import IntentTransformerDataset

In [None]:
# Dataset identifier(s) handed to IntentTransformerDataset.
dataset_nums = ['../data/DJI_0022']

# Images are converted to tensors with torchvision's ToTensor transform.
dataset = IntentTransformerDataset(dataset_nums,
                                   img_transform=transforms.ToTensor())
dataloader = DataLoader(
    dataset,
    batch_size=32,
    num_workers=12,
)

# reduction='none' keeps one loss value per element, so the error can later
# be broken down per sample, per step and per feature.
loss_fn = nn.L1Loss(reduction='none')


In [None]:
# Checkpoints to compare. Each key is the label attached to that model's
# rows in the error table (the "Epoch" column).
model_paths = {
    "150": "models/Intent_Transformer_03-19-2022_19-23-15.pth",
    "350": "models/Intent_Transformer_03-19-2022_19-47-58.pth",
    "500": "models/Intent_Transformer_03-19-2022_20-06-34.pth",
    "all(600)": "models/Intent_Transformer_all_data_03-19-2022_20-19-01.pth",
}


In [None]:
def get_error_vs_time(model, dataloader, loss_fn, feature_size=3, steps=10):
    """Run `model` over `dataloader` and collect per-sample errors for each predicted step.

    Args:
        model: Network called as ``model(img, X, intent, y_in, tgt_mask)``. It must
            expose ``model.transformer.generate_square_subsequent_mask``.
        dataloader: Iterable yielding ``(img, X, intent, y_in, y_label)`` batches.
        loss_fn: Element-wise loss (no reduction) whose output is indexed as
            ``[sample, step, feature]``.
        feature_size: Unused. Kept so existing callers that pass it keep working.
        steps: Number of predicted steps. Only used to shape the result when the
            dataloader yields no batches.

    Returns:
        Tuple ``(pos_error, ang_error)`` of NumPy arrays with one row per
        evaluated sample and one column per predicted step. ``pos_error`` is the
        Euclidean norm of loss features 0 and 1; ``ang_error`` is loss feature 2
        (presumably x/y position and heading -- confirm against the dataset).
    """
    model.eval()

    # Send inputs to wherever the model actually lives; fall back to the
    # notebook-wide DEVICE for a model without parameters.
    try:
        device = next(model.parameters()).device
    except StopIteration:
        device = torch.device(DEVICE)

    # Collect per-batch results in lists and concatenate once at the end.
    # Seeding the accumulator with torch.empty() would prepend rows of
    # uninitialized memory to the returned errors.
    pos_chunks = []
    ang_chunks = []

    with torch.no_grad():
        for batch in dataloader:
            img, X, intent, y_in, y_label = (
                tensor.to(device).float() for tensor in batch)
            tgt_mask = model.transformer.generate_square_subsequent_mask(
                y_in.shape[1]).to(device).float()
            pred = model(img, X, intent, y_in, tgt_mask)
            loss = loss_fn(pred, y_label)

            # Move results to the CPU right away so they can be concatenated
            # and converted to NumPy regardless of the compute device.
            pos_chunks.append(
                torch.sqrt(loss[:, :, 0]**2 + loss[:, :, 1]**2).cpu())
            ang_chunks.append(loss[:, :, 2].cpu())

    if not pos_chunks:
        # Nothing was evaluated: return well-formed empty arrays.
        return torch.empty((0, steps)).numpy(), torch.empty((0, steps)).numpy()

    return torch.cat(pos_chunks).numpy(), torch.cat(ang_chunks).numpy()


In [None]:
# Rows of [model label, time, error type, error value], one per sample and step.
all_error = []
# Time increment between consecutive predicted steps, used for the plot's
# x-axis (presumably seconds -- confirm against the dataset).
dt = 0.4
# Rows whose absolute error reaches this value are treated as outliers and dropped.
OUTLIER_THRESHOLD = 100

for name, path in model_paths.items():
    print(f'Getting statistics for model {name}')

    model = TrajectoryPredictorWithIntent()
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    model_state = torch.load(path, map_location=DEVICE)
    model.load_state_dict(model_state)
    model.eval().to(DEVICE)

    pos_error, ang_error = get_error_vs_time(model, dataloader, loss_fn, steps=10)

    # Column i of the error arrays corresponds to time (i + 1) * dt.
    timesteps = dt*np.arange(1, pos_error.shape[1]+1)
    for i, t in enumerate(timesteps):
        all_error.extend(
            [name, t, 'Positional', err] for err in pos_error[:, i])
        all_error.extend(
            [name, t, 'Angular', err] for err in ang_error[:, i])

error_df = pd.DataFrame(
    all_error, columns=['Epoch', 'Timestep', 'Type', 'Error'])
# Filter and cast in a single chain so the result is a fresh frame; assigning
# columns on a boolean-mask slice triggers pandas' SettingWithCopyWarning.
error_df = (
    error_df
    .loc[error_df['Error'].abs() < OUTLIER_THRESHOLD]  # Outlier removal
    .astype({'Epoch': 'category', 'Type': 'category'})
)

In [None]:
# One panel per error type, side by side.
fig, axes = plt.subplots(1, 2, figsize=(14, 7))

# (value in the "Type" column, panel title), in left-to-right order.
panels = [
    ("Positional", 'Positional Error (m)'),
    ("Angular", 'Angular Error (rad)'),
]

for ax, (error_type, title) in zip(axes, panels):
    subset = error_df[error_df["Type"] == error_type]
    # One line per model, with a 95% confidence band around it.
    sns.lineplot(data=subset, x="Timestep", y="Error",
                 hue="Epoch", style="Epoch", ci=95,
                 markers=True, dashes=False, ax=ax)
    ax.set_title(title)