In [None]:
from pathlib import Path
from src.datamodule.av2_dataset import Av2Dataset
from src.datamodule.av2_extractor_multiagent import Av2ExtractorMultiAgent
from src.datamodule.av2_dataset import collate_fn_cuda
import torch
from matplotlib import pyplot as plt
import numpy as np

# Root directory of the Argoverse 2 data -- placeholder, fill in before running.
data_root = Path("")

# Validation split, with features produced by the multi-agent extractor.
# NOTE(review): radius=150 -- units are not visible from this notebook; confirm in Av2ExtractorMultiAgent.
dataset = Av2Dataset(
    data_root=data_root,
    extractor=Av2ExtractorMultiAgent(radius=150),
    cached_split="trajectory-prediction/val",
)

In [None]:
from src.model.trainer_mrm import Trainer as Model

ckpt = ""  # placeholder: checkpoint path for trainer_mrm.Trainer -- fill in before running
model = Model.load_from_checkpoint(ckpt, pretrain_weights=None)
# The batch is built with collate_fn_cuda (presumably CUDA-resident), so place the model on
# CUDA explicitly -- matching the trainer cell -- instead of relying on whatever device
# load_from_checkpoint chose (for Lightning this differs between versions).
model = model.eval().to("cuda")

In [None]:
data = dataset[0]

# Inference only: no autograd graph, so the outputs can be handed straight to matplotlib.
with torch.no_grad():
    global_pos, predict_pos = model.predict(collate_fn_cuda([data]))

fig, ax = plt.subplots()

# NOTE(review): indexing assumes global_pos is (batch, agent, time, xy) and predict_pos is
# (batch, ?, time, xy) -- confirm against trainer_mrm.Trainer.predict.
# One call per agent so each agent gets its own colour from the cycle.
for agent_idx in range(global_pos.shape[1]):
    ax.plot(global_pos[0, agent_idx, :, 0], global_pos[0, agent_idx, :, 1], ".")

# The colour is fixed, so a single call draws exactly the same points as one call per point.
ax.plot(predict_pos[0, 0, :, 0], predict_pos[0, 0, :, 1], "r.")
# Equal scaling on both axes so spatial trajectories are not distorted.
ax.set_aspect("equal")

In [None]:
from src.model.trainer_mtm import Trainer as Model

ckpt = ""  # placeholder: checkpoint path for trainer_mtm.Trainer -- fill in before running
model = Model.load_from_checkpoint(ckpt, pretrain_weights=None)
# The batch is built with collate_fn_cuda (presumably CUDA-resident), so place the model on
# CUDA explicitly -- matching the trainer cell -- instead of relying on whatever device
# load_from_checkpoint chose (for Lightning this differs between versions).
model = model.eval().to("cuda")

In [None]:
data = dataset[0]

# Inference only: no autograd graph, so the outputs can be handed straight to matplotlib.
with torch.no_grad():
    global_pos, predict_pos = model.predict(collate_fn_cuda([data]))

fig, ax = plt.subplots()

# NOTE(review): indexing assumes global_pos is (batch, agent, time, xy) and predict_pos is
# (batch, ?, time, xy) -- confirm against trainer_mtm.Trainer.predict.
for agent_idx in range(global_pos.shape[1]):
    ax.scatter(global_pos[0, agent_idx, :, 0], global_pos[0, agent_idx, :, 1], c="red", s=5)

# Draw all predicted points in one call with a single colour. Plotting them one at a time
# without a colour made every point pick up the next colour in the property cycle.
ax.plot(predict_pos[0, 0, :, 0], predict_pos[0, 0, :, 1], "g.")
# Equal scaling on both axes so spatial trajectories are not distorted.
ax.set_aspect("equal")

In [None]:
from src.model.trainer import Trainer as Model

ckpt = ""  # placeholder: checkpoint path -- fill in before running
model = Model()

# Read the checkpoint onto the CPU: the weights are copied into the (CPU) model below and only
# then moved to CUDA, so this avoids a transient GPU copy and device-index mismatches.
# torch.load unpickles the file -- only load checkpoints from a trusted source.
state_dict = torch.load(ckpt, map_location="cpu")["state_dict"]

# strict=False tolerates key mismatches; report them so a checkpoint that does not match the
# architecture does not silently leave layers at their random initialisation.
incompatible_keys = model.load_state_dict(state_dict=state_dict, strict=False)
if incompatible_keys.missing_keys:
    print(f"missing keys ({len(incompatible_keys.missing_keys)}): {incompatible_keys.missing_keys}")
if incompatible_keys.unexpected_keys:
    print(
        f"unexpected keys ({len(incompatible_keys.unexpected_keys)}): "
        f"{incompatible_keys.unexpected_keys}"
    )

model = model.eval().to("cuda")

In [None]:
def get_history_and_label(data):
    """Map one sample's history, future labels and lane points into the global frame.

    Every point set is rotated by the 2x2 matrix built from ``data["theta"]`` (applied to row
    vectors as ``p @ R`` with ``R = [[cos, sin], [-sin, cos]]``) and then shifted by
    ``data["origin"]``. Labels are first offset by each agent's last history position
    (``data["x_positions"][:, -1, :]``), i.e. ``data["y"]`` is treated as relative to it.
    All computation is done in float64.

    Args:
        data: sample dict with keys "x_scored", "x_positions", "lane_positions", "y",
            "origin" and "theta". NOTE(review): assumes "theta" holds a single angle
            (so the stacked matrix reshapes to (1, 2, 2)) and "origin" holds 2 values.

    Returns:
        Tuple ``(history, labels, lanes, x_scored)``: history and labels as numpy arrays
        keyed by ``x_scored`` (only the selected agents), all lanes as a numpy array, and
        the ``x_scored`` mask itself, unchanged.
    """
    scored_mask = data["x_scored"]

    angle = data["theta"].double()
    cos_a = torch.cos(angle)
    sin_a = torch.sin(angle)
    rotation = torch.stack([cos_a, sin_a, -sin_a, cos_a], dim=1).view(1, 2, 2)
    offset = data["origin"].view(1, 1, 2).double()
    # Each agent's last observed position, shaped to broadcast over the label time axis.
    anchor = data["x_positions"][:, -1, :].view(-1, 1, 2)

    def to_global(points):
        # Rotate row vectors, translate by the origin, and hand back a host numpy array.
        return (torch.matmul(points, rotation) + offset).cpu().numpy()

    with torch.no_grad():
        history_global = to_global(data["x_positions"][..., :2].double())
        label_global = to_global(data["y"][..., :2].double() + anchor)
        lane_global = to_global(data["lane_positions"][..., :2].double())

    return history_global[scored_mask], label_global[scored_mask], lane_global, scored_mask

In [None]:
data = dataset[0]
history_pos, label_pos, lane_pos, x_scored = get_history_and_label(data)

# Inference only: no autograd graph is needed for plotting.
with torch.no_grad():
    predict_pos = model.predict(collate_fn_cuda([data])).squeeze(0)[x_scored]
# Matplotlib cannot draw CUDA tensors; if predict returned a tensor, bring it to host numpy.
if torch.is_tensor(predict_pos):
    predict_pos = predict_pos.detach().cpu().numpy()

fig, ax = plt.subplots()

# Lanes. NOTE(review): presumably padded lane points all map to one repeated coordinate, so
# when a lane has fewer distinct x values than points, the last distinct value is treated as
# padding and dropped -- confirm against the extractor's padding scheme.
# Use the actual points-per-lane instead of a hard-coded 20, so a lane without padding never
# loses a real point when the lane length is not 20.
points_per_lane = lane_pos.shape[1]
for lane_xy in lane_pos:
    num_unique = np.unique(lane_xy[:, 0]).size
    num = points_per_lane if num_unique == points_per_lane else num_unique - 1
    ax.plot(lane_xy[:num, 0], lane_xy[:num, 1], color="grey", alpha=0.2)

# Predicted trajectories: one green line per (agent, mode).
# NOTE(review): assumes predict_pos is (agent, mode, time, xy) after squeeze/masking -- confirm.
for agent_preds in predict_pos:
    for trajectory in agent_preds:
        ax.plot(trajectory[:, 0], trajectory[:, 1], "g-", alpha=0.5)

# Observed history in blue.
for agent_history in history_pos:
    ax.plot(agent_history[:, 0], agent_history[:, 1], "b", alpha=0.5)

# Ground-truth future points in red.
for agent_label in label_pos:
    ax.scatter(agent_label[:, 0], agent_label[:, 1], c="red", s=5, alpha=0.5)

# Equal scaling on both axes so the road geometry is not distorted.
ax.set_aspect("equal")