In [1]:
import argparse
import os
# Clear the process umask so nothing is masked off the permission bits of
# files/directories created from here on.
# NOTE(review): this can leave outputs group/world-writable — confirm intended.
os.umask(0)
# Cap MKL / NumExpr / OpenMP at one thread each. These assignments sit before
# `import torch` below; presumably so the libraries read them at load time and
# DataLoader workers do not oversubscribe the CPU — TODO confirm.
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"
import pickle
import sys
from importlib import import_module

import torch
from torch.utils.data import DataLoader, Sampler
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm

# Project-local modules (resolved relative to the notebook's working directory).
from data import ArgoTestDataset
from utils import Logger, load_pretrain

In [3]:
# Choose the CUDA device used by every later `.cuda()` call.
# The index defaults to 2 (the value this notebook was written with) and can be
# overridden with the LANEGCN_GPU_ID environment variable, so the notebook also
# runs on hosts that expose fewer than three GPUs.
GPU_ID = int(os.environ.get("LANEGCN_GPU_ID", "2"))
_visible_gpus = torch.cuda.device_count()
if _visible_gpus == 0:
    # Nothing to select; the first `.cuda()` call will report the missing GPU.
    print("warning: no CUDA device is visible; skipping torch.cuda.set_device")
else:
    if GPU_ID >= _visible_gpus:
        # An out-of-range ordinal would raise inside set_device; fall back loudly.
        print(f"warning: GPU {GPU_ID} is not available ({_visible_gpus} visible); using GPU 0")
        GPU_ID = 0
    torch.cuda.set_device(GPU_ID)

In [4]:
# Build the model and its companions in one call; get_model() is unpacked into
# seven values here. Of these, only `config`, `collate_fn` and `net` are used by
# the cells visible in this notebook; the remaining names are bound but unused.
from lanegcn_old import get_model
config, Dataset, collate_fn, net, loss, post_process, opt = get_model()


In [5]:
# Restore trained weights into `net`.
# The checkpoint is deserialised onto the CPU regardless of the device it was
# saved from; the model is moved to the GPU in a later cell.
# NOTE(review): torch.load unpickles arbitrary objects — only load checkpoints
# from a trusted source.
ckpt_path = 'results/lanegcn_old/100.000.ckpt'
ckpt = torch.load(ckpt_path, map_location="cpu")
load_pretrain(net, ckpt["state_dict"])

In [6]:
# Put the network in evaluation mode, then move its parameters to the current
# CUDA device (selected earlier with torch.cuda.set_device).
# `net.cuda()` is the last expression of the cell, so the notebook also prints
# the module tree returned by it.
net.eval()
net.cuda()

Net(
  (actor_net): ActorNet(
    (groups): ModuleList(
      (0): Sequential(
        (0): Res1d(
          (conv1): Conv1d(3, 32, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)
          (conv2): Conv1d(32, 32, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)
          (relu): ReLU(inplace=True)
          (bn1): GroupNorm(1, 32, eps=1e-05, affine=True)
          (bn2): GroupNorm(1, 32, eps=1e-05, affine=True)
          (downsample): Sequential(
            (0): Conv1d(3, 32, kernel_size=(1,), stride=(1,), bias=False)
            (1): GroupNorm(1, 32, eps=1e-05, affine=True)
          )
        )
        (1): Res1d(
          (conv1): Conv1d(32, 32, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)
          (conv2): Conv1d(32, 32, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)
          (relu): ReLU(inplace=True)
          (bn1): GroupNorm(1, 32, eps=1e-05, affine=True)
          (bn2): GroupNorm(1, 32, eps=1e-05, affine=True)
        )
      )
    

In [7]:
# Build the evaluation dataset. Despite the class name (ArgoTestDataset), the
# split requested here is 'val'; how the split name maps to files is decided
# inside ArgoTestDataset/config — presumably via a validation path in `config`,
# TODO confirm.
dataset = ArgoTestDataset('val', config, train=False)

In [8]:
# Batch loader for evaluation.
# Batch size and worker count come from the model config; `collate_fn` is the
# one returned by get_model().
# shuffle is off: every sample is visited exactly once either way, and a fixed
# order makes runs (and the insertion order of the saved result dicts)
# reproducible.
data_loader = DataLoader(
    dataset,
    batch_size=config["val_batch_size"],
    num_workers=config["val_workers"],
    collate_fn=collate_fn,
    shuffle=False,
    pin_memory=True,
)

In [10]:
# Run the network over every batch and collect, keyed by data["argo_id"]:
#   preds  -> row 0 of that sample's output["reg"] entry (presumably the target
#             agent's multi-modal forecast — confirm against the dataset code),
#   gts    -> data["gt_preds"][i][0], or None when the batch has no "gt_preds",
#   cities -> data["city"][i].
preds = {}
gts = {}
cities = {}

# Wrap the loader itself (not enumerate(...)) so tqdm can read len(data_loader)
# and show a real progress bar with total and ETA.
for data in tqdm(data_loader):
    data = dict(data)
    # Same answer for every sample of the batch, so test it once.
    has_gt = "gt_preds" in data
    with torch.no_grad():
        output = net(data)
        # Index row 0 directly instead of slicing [0:1] and calling a bare
        # squeeze(): a bare squeeze() would also drop any other singleton axis
        # (e.g. a model configured to emit a single mode).
        results = [x[0].detach().cpu().numpy() for x in output["reg"]]
    for i, (argo_idx, pred_traj) in enumerate(zip(data["argo_id"], results)):
        preds[argo_idx] = pred_traj
        cities[argo_idx] = data["city"][i]
        gts[argo_idx] = data["gt_preds"][i][0] if has_gt else None

# Save for further visualization.
# Example from an earlier run: res['preds'][9].shape == (6, 30, 2),
# res['gts'] == {9: None, 30: None, ...}, res['cities'] == {9: 'MIA', ..., 77: 'PIT'}.
res = dict(
    preds=preds,
    gts=gts,
    cities=cities,
)

# Create the output directory if it is missing so the save cannot fail after
# the whole inference pass has already been paid for.
os.makedirs(config["save_dir"], exist_ok=True)
torch.save(res, os.path.join(config["save_dir"], "test_lanegcn_JAN9.pkl"))
 

15it [00:04,  3.17it/s]


In [11]:
# Print the Argoverse forecasting metrics for K = 6 and then K = 1 guesses.
# The remaining positional arguments are the same for both calls: 30 (reported
# as "Prediction Horizon" in the printed header) and 2 (presumably the miss
# threshold in metres — confirm against the argoverse-api signature).
# NOTE(review): the output stored with this cell shows MR of about 1.0 for both
# K; worth checking the checkpoint and coordinate frames before trusting it.
from argoverse.evaluation.eval_forecasting import compute_forecasting_metrics

for max_guesses in (6, 1):
    _ = compute_forecasting_metrics(preds, gts, cities, max_guesses, 30, 2)

------------------------------------------------
Prediction Horizon : 30, Max #guesses (K): 6
------------------------------------------------
{'minADE': 15.28020518365768, 'minFDE': 28.137537428379574, 'MR': 0.9958333333333333, 'DAC': 0.6774305555555554}
------------------------------------------------
------------------------------------------------
Prediction Horizon : 30, Max #guesses (K): 1
------------------------------------------------
{'minADE': 15.296492703145079, 'minFDE': 29.210810044875974, 'MR': 1.0, 'DAC': 0.6354166666666666}
------------------------------------------------
