# Inference and Evaluation of the MTP model

In [3]:
import argparse
import os
import pickle
from datetime import datetime
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from torch_geometric.loader import DataLoader
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from tqdm import tqdm

from dataset import CarDataset
from utils.config import DT, OBS_LEN, PRED_LEN
from model import GNN_mtl_gnn, GNN_mtl_mlp


  from .autonotebook import tqdm as notebook_tqdm


In [4]:
def rotation_matrix_back(yaw):
    """Return the 2x2 rotation matrix for the angle ``yaw - pi/2`` (radians).

    The matrix is built in numpy with dtype float32 and handed back as a
    ``torch.Tensor`` that shares the numpy buffer.
    """
    angle = yaw - np.pi / 2
    cos_a, sin_a = np.cos(angle), np.sin(angle)
    matrix = np.array(
        [
            [cos_a, -sin_a],
            [sin_a, cos_a],
        ],
        dtype=np.float32,
    )
    return torch.from_numpy(matrix)

def run_inference(weights_path, dataloader, mlp=False, device=None):
    """Run the trajectory model over ``dataloader`` and return predictions in global coordinates.

    Parameters
    ----------
    weights_path : str or path-like
        State-dict checkpoint loaded into a freshly built model.
    dataloader : iterable of batches
        Each batch must expose ``x`` (node features) and ``edge_index``.
        Feature columns used here: 0-1 agent position, 2 speed, 3 yaw,
        and [0, 1, 4, 5, 6] as model input.
    mlp : bool
        Build ``GNN_mtl_mlp`` instead of ``GNN_mtl_gnn``.
    device : str or torch.device, optional
        Defaults to ``"cuda"`` when available, else ``"cpu"``.

    Returns
    -------
    pandas.DataFrame
        One row per agent and future step (agent-major order) with columns
        TIMESTAMP, TRACK_ID, X, Y, yaw, speed.

        * TIMESTAMP is ``k * DT`` for k = 1..30; every prediction starts at t = 0.
        * TRACK_ID is a running node index, unique across all batches.
        * yaw and speed are the agent's current feature values, repeated for each of its steps.

        An empty loader yields an empty frame with these columns.
    """
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")

    model = (GNN_mtl_mlp(128) if mlp else GNN_mtl_gnn(128)).to(device)
    model.load_state_dict(torch.load(weights_path, map_location=device))
    model.eval()

    columns = ["TIMESTAMP", "TRACK_ID", "X", "Y", "yaw", "speed"]
    frames = []
    next_track_id = 0  # running offset so TRACK_IDs do not restart in every batch
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Inference"):
            batch = batch.to(device)
            out = model(batch.x[:, [0, 1, 4, 5, 6]], batch.edge_index)
            out = out.reshape(-1, 30, 2).permute(0, 2, 1)           # [N, 2, 30]

            # Rotate each agent's prediction back by its own yaw, then translate by its position.
            yaw = batch.x[:, 3].cpu().numpy()
            rots = torch.stack([rotation_matrix_back(y) for y in yaw]).to(out.device)
            out = torch.bmm(rots, out).permute(0, 2, 1)             # [N, 30, 2]
            out = out + batch.x[:, [0, 1]].unsqueeze(1)             # global coords

            # One device->host copy per batch (the old code synced once per agent).
            xy = out.cpu().numpy()
            num_agents, num_steps = xy.shape[0], xy.shape[1]
            speeds = batch.x[:, 2].cpu().numpy()

            fut_ts = np.arange(1, num_steps + 1) * DT               # same horizon for every agent
            ids = np.arange(next_track_id, next_track_id + num_agents)
            next_track_id += num_agents

            frames.append(pd.DataFrame({
                "TIMESTAMP": np.tile(fut_ts, num_agents),
                "TRACK_ID": np.repeat(ids, num_steps),
                "X": xy[:, :, 0].reshape(-1).astype(np.float64),
                "Y": xy[:, :, 1].reshape(-1).astype(np.float64),
                # Per-agent yaw; previously indexed by step % N, which mixed up agents.
                "yaw": np.repeat(yaw.astype(np.float64), num_steps),
                "speed": np.repeat(speeds.astype(np.float64), num_steps),
            }, columns=columns))

    if not frames:
        return pd.DataFrame(columns=columns)
    return pd.concat(frames, ignore_index=True)

def visualise(df, bg_img_path, out_mp4="inference.mp4",
              interval_ms=100, max_frames=None):
    """Draw agents as oriented rectangles with speed arrows over a map image and save a video.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain columns TIMESTAMP, TRACK_ID, X, Y, yaw, speed (as returned
        by ``run_inference``). One frame is rendered per unique TIMESTAMP, in
        ascending order.
    bg_img_path : str or path-like
        Background image. Its pixel grid is used as the world extent
        (1 pixel = 1 world unit is assumed).
    out_mp4 : str or path-like
        Output path, written with matplotlib's ``ffmpeg`` writer.
    interval_ms : int
        Frame interval; the saved frame rate is ``1000 // interval_ms``, clamped to at least 1.
    max_frames : int, optional
        Render only the first ``max_frames`` timestamps.

    Raises
    ------
    ValueError
        If there is no timestamp to render (empty ``df`` or ``max_frames=0``).
    """
    # Background image. Decide world-coordinate extent; here we assume 1 pixel = 1 unit.
    bg_img = plt.imread(bg_img_path)
    ypixels, xpixels = bg_img.shape[:2]
    extent = [0, xpixels, 0, ypixels]   # (xmin, xmax, ymin, ymax)

    # One stable colour per track (tab20 cycles after 20 tracks).
    track_ids = df['TRACK_ID'].unique()
    colours = {tid: plt.cm.tab20(i % 20) for i, tid in enumerate(track_ids)}

    timestamps = np.sort(df['TIMESTAMP'].unique())
    if max_frames is not None:
        timestamps = timestamps[:max_frames]
    if len(timestamps) == 0:
        raise ValueError("visualise: nothing to render (empty DataFrame or max_frames=0)")
    grouped = df.groupby('TIMESTAMP')

    fig, ax = plt.subplots(figsize=(8, 8))

    def _draw_background():
        # Reset the axes to only the map image with fixed limits and equal aspect.
        ax.clear()
        ax.imshow(bg_img, extent=extent, cmap='gray', origin='lower', zorder=0)
        ax.set_xlim(extent[0], extent[1])
        ax.set_ylim(extent[2], extent[3])
        ax.set_aspect('equal')

    def init():
        _draw_background()
        return []

    def update(frame_idx):
        ts = timestamps[frame_idx]
        _draw_background()
        ax.set_title(f"t = {ts:.2f}s")
        ax.axis('off')

        if ts not in grouped.groups:
            return []

        frame_df = grouped.get_group(ts)
        # NOTE(review): yaw is treated as degrees here, but run_inference takes it from
        # feature column 3, which rotation_matrix_back feeds to np.cos/np.sin as radians.
        # Confirm the unit.
        for x, y, tid, yaw_deg, speed in zip(frame_df['X'], frame_df['Y'],
                                             frame_df['TRACK_ID'], frame_df['yaw'],
                                             frame_df['speed']):
            x, y = float(x), float(y)
            yaw_deg, speed = float(yaw_deg), float(speed)
            colour = colours.get(tid, 'black')

            # Rectangle (car), centred on (x, y) and rotated around its centre.
            length, width = 4, 2
            dx, dy = -length / 2, -width / 2
            rect = plt.Rectangle((x + dx, y + dy), length, width,
                                 color=colour, alpha=0.8, zorder=2)
            rot = np.deg2rad(-(yaw_deg + 90))          # negative for x-axis right-hand
            transf = (plt.matplotlib.transforms.Affine2D()
                      .rotate_around(x, y, rot) + ax.transData)
            rect.set_transform(transf)
            ax.add_patch(rect)

            # Velocity arrow, capped at 10 units. (-vx, vy) is parallel to the
            # rectangle's long axis after its -(yaw+90) degree rotation.
            arrow_len = min(speed * 0.4, 10)
            vx = arrow_len * np.cos(np.deg2rad(yaw_deg + 90))
            vy = arrow_len * np.sin(np.deg2rad(yaw_deg + 90))
            ax.arrow(x, y, -vx, vy,
                     head_width=0.7, head_length=1.2,
                     fc='k', ec='k', zorder=3)

        return []

    ani = animation.FuncAnimation(fig, update, frames=len(timestamps),
                                  init_func=init, interval=interval_ms, blit=False)
    try:
        # Clamp: 1000 // interval_ms is 0 for intervals above one second.
        ani.save(out_mp4, writer='ffmpeg', fps=max(1, 1000 // interval_ms))
    finally:
        # Release the figure; otherwise it leaks and shows up as an extra static plot.
        plt.close(fig)
    print(f"Saved → {out_mp4}")

In [None]:
# Configuration — everything a reader might tune for this run.
batch_size = 8000           # graphs per batch for the PyG DataLoader
mlp = False                 # True -> GNN_mtl_mlp, False -> GNN_mtl_gnn; dataset is built to match
collision_penalty = False   # NOTE(review): not used by any cell in this notebook section
train_folder = "csv/train_pre_1k_simple_separate_10m"  # NOTE(review): unused here (inference only)
val_folder = "csv/val_pre_1k_simple_separate_10m"
model_path = "trained_params_archive/sumo_with_mpc_online_control/model_rot_gnn_mtl_np_sumo_0911_e3_1930.pth"
inter_map = "simple_separate_10m"
bg_img_path = f"map_binary_images/{inter_map}_binary.png"
out_mp4 = f"inference_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{inter_map}.mp4"

# Validation data; `mlp` is passed through so the dataset and the model variant cannot disagree.
val_dataset = CarDataset(preprocess_folder=val_folder, mlp=mlp, mpc_aug=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0, drop_last=False)

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Architecture summary only; run_inference builds its own model and loads model_path.
# Parenthesised so both variants are moved to `device` (previously only the MLP one was).
model = (GNN_mtl_mlp(hidden_channels=128) if mlp else GNN_mtl_gnn(hidden_channels=128)).to(device)
print(model)

df_pred = run_inference(model_path, val_loader, mlp=mlp, device=device)
visualise(df_pred, bg_img_path=bg_img_path, out_mp4=out_mp4,
          interval_ms=100, max_frames=None)

Using device: cuda
GNN_mtl_gnn(
  (conv1): GraphConv(128, 128)
  (conv2): GraphConv(128, 128)
  (linear1): Linear(in_features=5, out_features=64, bias=True)
  (linear2): Linear(in_features=64, out_features=128, bias=True)
  (linear3): Linear(in_features=128, out_features=128, bias=True)
  (linear4): Linear(in_features=128, out_features=128, bias=True)
  (linear5): Linear(in_features=128, out_features=60, bias=True)
)


Inference: 100%|██████████| 3/3 [00:16<00:00,  5.61s/it]
