In [3]:
from importlib import import_module
import pickle
import sys
from data import ArgoTestDataset
import os
from utils import Logger, load_pretrain

import torch
from torch.utils.data import DataLoader, Sampler
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm

In [14]:
from collections import defaultdict
from typing import Dict, List, Optional

import matplotlib.lines as mlines
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.interpolate as interp

from argoverse.map_representation.map_api import ArgoverseMap

# Matplotlib z-order per OBJECT_TYPE value; viz_sequence passes these to
# ax.plot, so a higher number is drawn on top (AGENT over AV over OTHERS).
_ZORDER = {"AGENT": 15, "AV": 10, "OTHERS": 5}


def interpolate_polyline(polyline: np.ndarray, num_points: int) -> np.ndarray:
    """Resample a polyline with an interpolating spline.

    Rows that are (numerically) equal to the row directly before them are
    treated as duplicates and dropped before fitting.

    Args:
        polyline: Array of points, one point per row.
        num_points: Number of samples to evaluate along the fitted spline.

    Returns:
        An array with ``num_points`` rows sampled uniformly in the spline
        parameter over [0, 1]. If fewer than 4 rows remain after discarding
        consecutive duplicates, the input array is returned unchanged
        (duplicates included).
    """
    # Indices of rows that repeat their immediate predecessor.
    repeated_rows = [
        idx
        for idx in range(1, len(polyline))
        if np.allclose(polyline[idx], polyline[idx - 1])
    ]

    # Too few distinct points to fit the spline: hand the input back as-is.
    distinct_count = polyline.shape[0] - len(repeated_rows)
    if distinct_count < 4:
        return polyline

    cleaned = np.delete(polyline, repeated_rows, axis=0) if repeated_rows else polyline

    # s=0 requests an interpolating (non-smoothing) spline through the points.
    spline, _ = interp.splprep(cleaned.T, s=0)
    samples = np.linspace(0.0, 1.0, num_points)
    return np.column_stack(interp.splev(samples, spline))


def viz_sequence(
    df: pd.DataFrame,
    lane_centerlines: Optional[List[np.ndarray]] = None,
    show: bool = True,
    smoothen: bool = False,
    ax=None,
) -> "plt.Axes":
    """Draw lane centerlines and every track of one sequence onto an axes.

    Args:
        df: Sequence table. The columns read here are ``X``, ``Y``,
            ``TRACK_ID`` and ``OBJECT_TYPE``; ``CITY_NAME`` is read only when
            ``lane_centerlines`` is None. ``OBJECT_TYPE`` values must be one
            of ``"AGENT"``, ``"OTHERS"`` or ``"AV"`` (anything else raises
            ``KeyError`` on the colour lookup).
        lane_centerlines: Optional list of arrays whose columns 0 and 1 are
            plotted as x and y. When None, centerlines are taken from
            ``ArgoverseMap`` for the sequence's city, keeping only those whose
            bounding box overlaps the bounding box of the trajectories, and
            the axes limits are set to the trajectory bounding box.
        show: Accepted for interface compatibility; this function never calls
            ``plt.show()`` and ignores the flag.
        smoothen: If True, each track is resampled with
            ``interpolate_polyline`` to three times its number of points
            before plotting.
        ax: Matplotlib axes to draw on. When None, a new figure and axes are
            created.

    Returns:
        The axes that were drawn on.
    """
    # Previously ax=None crashed on the first ax.* call; create one on demand.
    if ax is None:
        _, ax = plt.subplots()

    if lane_centerlines is None:
        city_name = df["CITY_NAME"].values[0]

        # Bounding box of all trajectory points in the sequence.
        x_min, x_max = df["X"].min(), df["X"].max()
        y_min, y_max = df["Y"].min(), df["Y"].max()
        ax.axis(xmin=x_min, xmax=x_max, ymin=y_min, ymax=y_max)

        # Get API for Argo Dataset map
        avm = ArgoverseMap()
        seq_lane_props = avm.city_lane_centerlines_dict[city_name]

        # Keep only the centerlines whose extent overlaps the trajectory box.
        lane_centerlines = []
        for lane_props in seq_lane_props.values():
            lane_cl = lane_props.centerline
            if (
                np.min(lane_cl[:, 0]) < x_max
                and np.min(lane_cl[:, 1]) < y_max
                and np.max(lane_cl[:, 0]) > x_min
                and np.max(lane_cl[:, 1]) > y_min
            ):
                lane_centerlines.append(lane_cl)

    for lane_cl in lane_centerlines:
        ax.plot(
            lane_cl[:, 0],
            lane_cl[:, 1],
            "--",
            color="grey",
            alpha=1,
            linewidth=1,
            zorder=0,
        )

    ax.set_xlabel("Map X")
    ax.set_ylabel("Map Y")

    color_dict = {"AGENT": "#d33e4c", "OTHERS": "#d3e8ef", "AV": "#007672"}
    # Every object type uses the same end-point marker.
    marker_type = "o"
    marker_size = 7
    # Counts tracks seen per object type so only the first one gets a label.
    object_type_tracker: Dict[str, int] = defaultdict(int)

    # Plot each track as a line plus a marker at its last point.
    for _, group_data in df.groupby("TRACK_ID"):
        object_type = group_data["OBJECT_TYPE"].values[0]

        cor_x = group_data["X"].values
        cor_y = group_data["Y"].values

        if smoothen:
            polyline = np.column_stack((cor_x, cor_y))
            num_points = cor_x.shape[0] * 3
            smooth_polyline = interpolate_polyline(polyline, num_points)
            cor_x = smooth_polyline[:, 0]
            cor_y = smooth_polyline[:, 1]

        color = color_dict[object_type]
        label = object_type if not object_type_tracker[object_type] else ""

        ax.plot(
            cor_x,
            cor_y,
            "-",
            color=color,
            label=label,
            alpha=1,
            linewidth=1,
            zorder=_ZORDER[object_type],
        )

        ax.plot(
            cor_x[-1],
            cor_y[-1],
            marker_type,
            color=color,
            label=label,
            alpha=1,
            markersize=marker_size,
            zorder=_ZORDER[object_type],
        )

        object_type_tracker[object_type] += 1

    ax.grid()
    return ax

def fde(trajs, gt_traj):
    """Minimum final displacement error over a set of candidate trajectories.

    Args:
        trajs: Candidate trajectories, array-like indexed as
            ``[candidate, timestep, ...]``. Each timestep may be a scalar or a
            coordinate vector.
        gt_traj: Ground-truth trajectory indexed as ``[timestep, ...]``; only
            its last entry is used.

    Returns:
        The smallest distance between any candidate's last point and the last
        ground-truth point. For coordinate vectors this is the Euclidean
        distance; for scalar timesteps it is the absolute difference.
    """
    trajs = np.asarray(trajs)
    gt_traj = np.asarray(gt_traj)

    final_offsets = trajs[:, -1] - gt_traj[-1]
    if final_offsets.ndim == 1:
        # Scalar positions: one offset per candidate.
        distances = np.abs(final_offsets)
    else:
        # Coordinate vectors: reduce over the coordinate axis so x and y are
        # combined into a single distance per candidate instead of being
        # compared independently.
        distances = np.linalg.norm(final_offsets, axis=-1)
    return np.min(distances)


In [53]:
# Build the LaneGCN network and its helpers from the project module.
model = import_module("lanegcn")
config, _, collate_fn, net, loss, post_process, opt = model.get_model()

ckpt_path = "results/lanegcnWed Jul 20 17:34:12 2022/36.000.ckpt"
# map_location="cpu" lets the checkpoint be read on a machine without the GPU
# it was saved from; load_state_dict copies the values onto net's own device.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
ckpt = torch.load(ckpt_path, map_location="cpu")

from collections import OrderedDict

# Strip the "module." prefix from parameter names, but only where it is
# actually present, so a checkpoint saved without the prefix loads too.
prefix = "module."
new_state_dict = OrderedDict(
    (k[len(prefix):] if k.startswith(prefix) else k, v)
    for k, v in ckpt["state_dict"].items()
)
net.load_state_dict(new_state_dict, strict=True)

<All keys matched successfully>

In [5]:



# Test-split dataset built with the configuration returned by get_model().
dataset = ArgoTestDataset("test", config, train=False)

# Loader settings gathered in one place; the batch size is fixed at 128 here
# instead of config["val_batch_size"], and batches are drawn in shuffled order.
loader_options = dict(
    batch_size=128,
    num_workers=config["val_workers"],
    collate_fn=collate_fn,
    shuffle=True,
    pin_memory=True,
)
test_loader = DataLoader(dataset, **loader_options)



In [None]:
# Move the network to the GPU and switch it to evaluation mode; both calls
# return the module, so the cell still displays the network as its value.
net.cuda().eval()

In [7]:
from argoverse.data_loading.argoverse_forecasting_loader import ArgoverseForecastingLoader

# Directory holding the per-sequence CSV files. It can be overridden with the
# ARGO_TEST_OBS_DIR environment variable so the notebook is not tied to one
# user's filesystem; the default is the original location.
root_dir = os.environ.get(
    "ARGO_TEST_OBS_DIR",
    "/mnt/lustre/tangxiaqiang/Code/LaneGCN/dataset/test_obs/data",
)
afl = ArgoverseForecastingLoader(root_dir)

In [37]:
# Pull a single batch from the loader; the temporary iterator is discarded
# immediately afterwards.
batch=next(iter(test_loader))

In [57]:
# Inference only: skip autograd bookkeeping so no graph is kept for the
# forward pass.
with torch.no_grad():
    output = net(batch)

# For every sample keep only the first row (slice 0:1) of the regression and
# classification outputs, as NumPy arrays on the CPU.
# NOTE(review): row 0 is presumably the target agent - confirm against the
# dataset / collate code.
trajs_list = [x[0:1].detach().cpu().numpy() for x in output["reg"]]
probs_list = [x[0:1].detach().cpu().numpy() for x in output["cls"]]

In [None]:


# One figure per sample: the observed sequence with the predicted modes drawn
# on top, each annotated at its end point with its classification value.
for trajs, probs, key in zip(trajs_list, probs_list, batch["argo_id"]):

    seq_path = f"{root_dir}/{key!s}.csv"

    fig, ax = plt.subplots(figsize=(16, 14), dpi=100)
    ax = viz_sequence(afl.get(seq_path).seq_df, ax=ax)

    # Flatten the leading axes explicitly instead of squeeze(), which would
    # also drop the mode axis whenever there is only a single mode.
    modes = trajs.reshape(-1, *trajs.shape[-2:])
    scores = probs.reshape(-1)
    for traj, prob in zip(modes, scores):
        ax.plot(traj[:, 0], traj[:, 1])
        ax.text(traj[-1, 0], traj[-1, 1], str(round(prob, 2)))

    # plt.savefig("pic1/"+str(key)+".jpg")

    # Render and release each figure before opening the next one; otherwise
    # every figure of the batch stays open in memory at the same time.
    plt.show()
    plt.close(fig)
    

In [46]:
# Number of sequence ids in the fetched batch.
len(batch["argo_id"])

128

In [43]:
# One entry per sample; should equal the number of ids in the batch.
len(trajs_list)

128

In [59]:
# Shape of the classification output kept for the first sample.
probs_list[0].shape

(1, 6)

In [60]:
# Built-in sum adds up the rows along the first axis; with a single row this
# just returns that row, not a scalar total.
# NOTE(review): the recorded values are negative, so these look like raw
# scores rather than normalised probabilities - confirm before labelling
# them as probabilities.
sum(probs_list[0])


array([-2.9008076, -2.9183996, -2.9690464, -3.0066352, -3.1625133,
       -3.2031138], dtype=float32)

In [61]:
# Names of the entries in the network's output dictionary.
output.keys()

dict_keys(['cls', 'reg'])