In [None]:
import os

import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch
from torch import nn
from torch.utils.data import DataLoader
from pathlib import Path
from tempfile import gettempdir

from l5kit.data import LocalDataManager, ChunkedDataset
from l5kit.dataset import AgentDataset, EgoDataset
from l5kit.rasterization import build_rasterizer
from l5kit.evaluation import write_pred_csv, compute_metrics_csv, read_gt_csv, create_chopped_dataset
from l5kit.evaluation.chop_dataset import MIN_FUTURE_STEPS
from l5kit.visualization import PREDICTED_POINTS_COLOR, TARGET_POINTS_COLOR, draw_trajectory
from l5kit.geometry import transform_points
from l5kit.evaluation.metrics import neg_multi_log_likelihood, time_displace, average_displacement_error_mean, rmse

from src.utils import load_config, get_model_class
# from src.dataset import load_datasets
from src.trainer import Trainer

In [None]:
# ---- Experiment configuration ----
CONFIG_PATH = "models/configs/vit_deit_config.yaml"  # YAML consumed by load_config
MODEL_NAME = "ViTDeitModel"  # resolved via get_model_class (name of model file in models/)
EXP_NAME = "notebook_exp_vit_deit"  # artefacts are written under experiments/<EXP_NAME>
EXP_NUM = 0  # epoch index used in checkpoint / prediction / metric file names

cfg = load_config(CONFIG_PATH)

In [None]:
# Experiment artefact locations, keyed by experiment name and epoch index
exp_dir = os.path.join("experiments", EXP_NAME)
ckpt_path, pred_path = (
    os.path.join(exp_dir, f"epoch_{EXP_NUM}.pth"),
    os.path.join(exp_dir, f"predictions_{EXP_NUM}.csv"),
)
os.makedirs(exp_dir, exist_ok=True)

In [None]:
# Point l5kit at the local copy of the dataset, then build the training pipeline
os.environ["L5KIT_DATA_FOLDER"] = "../data/lyft-motion-prediction-autonomous-vehicles"
dm = LocalDataManager(None)
rasterizer = build_rasterizer(cfg, dm)

train_cfg = cfg["train_data_loader"]
train_zarr = ChunkedDataset(dm.require(train_cfg["key"])).open()
train_dataset = AgentDataset(cfg, train_zarr, rasterizer)
train_dataloader = DataLoader(
    train_dataset,
    shuffle=train_cfg["shuffle"],
    batch_size=train_cfg["batch_size"],
    num_workers=0,
)
print(train_dataset)

In [None]:
# Report the size of the training split in samples and in batches
for unit, count in (("samples", len(train_dataset)), ("batches", len(train_dataloader))):
    print(f"Number of training {unit}: {count}")

In [None]:
# Validation split. The rasterizer built for the training split (same cfg and data
# manager) is reused; the original rebuilt an identical one, repeating its setup cost.
eval_cfg = cfg["val_data_loader"]
eval_zarr = ChunkedDataset(dm.require(eval_cfg["key"])).open()
eval_dataset = AgentDataset(cfg, eval_zarr, rasterizer)
eval_dataloader = DataLoader(
    eval_dataset,
    shuffle=eval_cfg["shuffle"],
    batch_size=eval_cfg["batch_size"],
    num_workers=0,
)
print(eval_dataset)

In [None]:
# Create the model named by MODEL_NAME and place it on the selected device.
# Moving it here matters: later cells feed it tensors already sent to `device`
# (e.g. the visualisation cell), which fails if the model stayed on the CPU.
ModelClass = get_model_class(MODEL_NAME)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = ModelClass(cfg).to(device)
print(f"Using device: {device}")

In [None]:
# Architecture summary plus parameter counts. The raw parameter tensors are not
# printed: they flood the output and bloat the saved notebook.
print(model)
n_params = sum(p.numel() for p in model.parameters())
n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total parameters: {n_params:,}")
print(f"Trainable parameters: {n_trainable:,}")

In [None]:
# Train only if the checkpoint for EXP_NUM does not exist yet; otherwise load it.
# All loss containers are initialised up front so later cells that read them do
# not hit a NameError when training is skipped.
losses_train = []
losses_train_epoch = []
losses_val_epoch = []
if not os.path.exists(ckpt_path):
    print(f"Checkpoint not found at {ckpt_path}, training model...")
    trainer = Trainer(cfg, model, device, train_dataloader, exp_name=EXP_NAME)
    losses_train, losses_val_epoch, losses_train_epoch = trainer.train_and_validate(eval_dataloader, EXP_NUM)
else:
    print(f"Loading existing checkpoint: {ckpt_path}")
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine
    # (and vice versa); without it torch.load restores tensors to their saved device.
    model.load_state_dict(torch.load(ckpt_path, map_location=device))

In [None]:
# Per-step training loss curve (only populated when training ran in this session)
if losses_train:
    fig, ax = plt.subplots()
    ax.plot(np.arange(len(losses_train)), losses_train, label="train loss")
    ax.set_xlabel("Training Steps")
    ax.set_ylabel("Loss")
    ax.set_title("Training Loss Over Time")
    ax.legend()
    plt.show()

In [None]:
# Per-epoch train vs. validation loss. The epoch-level lists only hold values when
# training ran in this session, so losses_train (always set by the training cell)
# is checked first; short-circuiting avoids touching the epoch lists otherwise.
if losses_train and losses_train_epoch:
    plt.plot(np.arange(len(losses_train_epoch)), losses_train_epoch, label="train loss (epoch)")
    # Use the validation list's own length for its x-axis so a length mismatch
    # between the two lists cannot raise a shape error.
    plt.plot(np.arange(len(losses_val_epoch)), losses_val_epoch, label="val loss (epoch)")
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.title("Training and Validation Loss Over Epochs")
    plt.legend()
    plt.show()

In [None]:
# Persist the loss curves as .npy files, but only when training actually ran in
# this session. If the checkpoint was loaded instead, losses_train is empty, and
# saving would overwrite previously stored curves with empty arrays.
if losses_train:
    losses_train_path = os.path.join(exp_dir, f"losses_train_{EXP_NUM}.npy")
    losses_train_epoch_path = os.path.join(exp_dir, f"losses_train_epoch_{EXP_NUM}.npy")
    losses_val_epoch_path = os.path.join(exp_dir, f"losses_val_{EXP_NUM}.npy")

    np.save(losses_train_path, losses_train)
    np.save(losses_train_epoch_path, losses_train_epoch)
    np.save(losses_val_epoch_path, losses_val_epoch)
else:
    print("No training run in this session; leaving existing loss files untouched.")

In [None]:
# ===== GENERATE AND LOAD CHOPPED TEST DATASET
# create_chopped_dataset returns a directory; the next cell expects it to contain
# the chopped zarr, `mask.npz` and `gt.csv`.
num_frames_to_chop = 100
test_cfg = cfg["test_data_loader"]
test_source_zarr = dm.require(test_cfg["key"])
test_base_path = create_chopped_dataset(
    test_source_zarr,
    cfg["raster_params"]["filter_agents_threshold"],
    num_frames_to_chop,
    cfg["model_params"]["future_num_frames"],
    MIN_FUTURE_STEPS,
)

In [None]:
# Locate the chopped test artefacts inside the directory returned above
test_zarr_path = str(Path(test_base_path) / Path(dm.require(test_cfg["key"])).name)
test_mask_path = str(Path(test_base_path) / "mask.npz")
test_gt_path = str(Path(test_base_path) / "gt.csv")

test_zarr = ChunkedDataset(test_zarr_path).open()
# np.load on an .npz returns an NpzFile that keeps the file handle open; use it as
# a context manager so the handle is released once the mask array has been read.
with np.load(test_mask_path) as test_mask_npz:
    test_mask = test_mask_npz["arr_0"]

# ===== INIT TEST DATASET WITH THE AGENT MASK
test_dataset = AgentDataset(cfg, test_zarr, rasterizer, agents_mask=test_mask)
test_dataloader = DataLoader(
    test_dataset,
    shuffle=test_cfg["shuffle"],
    batch_size=test_cfg["batch_size"],
    num_workers=0,
)
print(test_dataset)

In [None]:
# Evaluate
# ==== EVAL LOOP
def evaluate_model(model, test_dataloader, device, pred_path):
    """Run inference over ``test_dataloader`` and write predictions to ``pred_path``.

    The model's outputs are transformed with each sample's ``world_from_agent``
    matrix. They are then expressed as offsets from the sample's centroid before
    being written with l5kit's ``write_pred_csv``.

    Args:
        model: module exposing ``forward_pass(data, device, criterion)`` returning a
            2-tuple whose second element is the predicted trajectory tensor.
        test_dataloader: iterable of batches with keys ``world_from_agent``,
            ``centroid``, ``timestamp`` and ``track_id``.
        device: forwarded unchanged to ``model.forward_pass``.
        pred_path: destination path of the predictions CSV.

    Returns:
        ``pred_path``, for convenience.

    Raises:
        ValueError: if ``test_dataloader`` yields no batches.
    """
    model.eval()
    # forward_pass requires a criterion; the loss it returns is discarded here.
    criterion = nn.MSELoss(reduction="none")

    # Per-batch results, concatenated once after the loop
    future_coords_offsets_pd = []
    timestamps = []
    agent_ids = []

    with torch.no_grad():
        for data in tqdm(test_dataloader):
            _, outputs = model.forward_pass(data, device, criterion)

            # Convert agent-frame coordinates into world offsets from the centroid
            agents_coords = outputs.cpu().numpy()
            world_from_agents = data["world_from_agent"].numpy()
            centroids = data["centroid"].numpy()
            coords_offset = transform_points(agents_coords, world_from_agents) - centroids[:, None, :2]

            future_coords_offsets_pd.append(coords_offset)
            timestamps.append(data["timestamp"].numpy().copy())
            agent_ids.append(data["track_id"].numpy().copy())

    if not timestamps:
        raise ValueError("test_dataloader yielded no batches; no predictions to write")

    write_pred_csv(
        pred_path,
        timestamps=np.concatenate(timestamps),
        track_ids=np.concatenate(agent_ids),
        coords=np.concatenate(future_coords_offsets_pd),
    )
    return pred_path

In [None]:
# Reuse cached predictions when the CSV already exists; otherwise run inference
if os.path.exists(pred_path):
    print(f"Loading existing predictions: {pred_path}")
else:
    print(f"Predictions not found at {pred_path}, evaluating model...")
    evaluate_model(model, test_dataloader, device, pred_path)

In [None]:
# Score the predictions against the chopped-dataset ground truth
metrics = compute_metrics_csv(
    test_gt_path,
    pred_path,
    [neg_multi_log_likelihood, time_displace, average_displacement_error_mean, rmse],
)
for metric_name, metric_mean in metrics.items():
    print(metric_name, metric_mean)

# Save metrics as CSV without pandas. A metric value may be a numpy array rather
# than a scalar (e.g. per-timestep values). str() of an array wraps onto several
# lines once it exceeds numpy's print line width, which would corrupt the CSV.
# Every value is therefore flattened into a single space-separated field.
metrics_csv_path = os.path.join(exp_dir, f"metrics_{EXP_NUM}.csv")
with open(metrics_csv_path, "w") as f:
    f.write("metric,mean\n")
    for metric_name, metric_mean in metrics.items():
        value_field = " ".join(str(float(v)) for v in np.ravel(metric_mean))
        f.write(f"{metric_name},{value_field}\n")

In [None]:
# Draw predicted (PREDICTED_POINTS_COLOR) and ground-truth (TARGET_POINTS_COLOR)
# trajectories for every agent in the last frame of each chopped test scene,
# rendered on the AV's raster.
model.eval()
# Run inference on the device that holds the model's parameters, so this cell
# works whether or not the model was moved to `device` earlier.
vis_device = next(model.parameters()).device

# (track_id, timestamp) -> GT future coordinates. A tuple key avoids the
# collisions that plain string concatenation allows ("1" + "23" == "12" + "3").
gt_rows = {}
for row in read_gt_csv(test_gt_path):
    gt_rows[(row["track_id"], row["timestamp"])] = row["coord"]

test_ego_dataset = EgoDataset(cfg, test_dataset.dataset, rasterizer)

# Each chopped scene holds `num_frames_to_chop` frames, so the last frame of
# scene k is at index k * num_frames_to_chop + (num_frames_to_chop - 1).
# no_grad is scoped to this loop instead of disabling autograd kernel-wide.
with torch.no_grad():
    for frame_number in range(num_frames_to_chop - 1, len(test_zarr.frames), num_frames_to_chop):
        agent_indices = test_dataset.get_frame_indices(frame_number)
        if not len(agent_indices):
            continue

        # AV point-of-view raster for this frame
        data_ego = test_ego_dataset[frame_number]
        im_ego = rasterizer.to_rgb(data_ego["image"].transpose(1, 2, 0))

        predicted_positions = []
        target_positions = []
        for v_index in agent_indices:
            data_agent = test_dataset[v_index]

            agent_image = torch.from_numpy(data_agent["image"]).unsqueeze(0).to(vis_device)
            out_pos = model(agent_image)[0].reshape(-1, 2).cpu().numpy()
            # Store predictions as absolute world coordinates
            predicted_positions.append(transform_points(out_pos, data_agent["world_from_agent"]))

            # GT offsets are relative to the centroid; add it back for absolute coordinates
            gt_key = (str(data_agent["track_id"]), str(data_agent["timestamp"]))
            target_positions.append(gt_rows[gt_key] + data_agent["centroid"][:2])

        # Project world coordinates into the AV raster so they can be drawn
        predicted_positions = transform_points(np.concatenate(predicted_positions), data_ego["raster_from_world"])
        target_positions = transform_points(np.concatenate(target_positions), data_ego["raster_from_world"])

        draw_trajectory(im_ego, predicted_positions, PREDICTED_POINTS_COLOR)
        draw_trajectory(im_ego, target_positions, TARGET_POINTS_COLOR)

        fig, ax = plt.subplots()
        ax.imshow(im_ego)
        ax.set_title(f"Frame {frame_number}")
        plt.show()
        # Release each figure; many frames are plotted in this loop
        plt.close(fig)