In [1]:
import warnings
warnings.filterwarnings('ignore')

In [2]:
import os
import sys
import numpy as np
import pandas as ps

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, RandomSampler


from l5kit.geometry import transform_points
from l5kit.data import LocalDataManager, ChunkedDataset
from l5kit.dataset import AgentDataset, EgoDataset
from l5kit.rasterization import build_rasterizer
from l5kit.evaluation import write_pred_csv

sys.path.append("..")

from src.batteries import t2d, load_checkpoint
from src.batteries.progress import tqdm
from src.models.genet import genet_normal
from src.models.resnets import resnet18, resnet50
from src.models.resnext import resnext18
from src.models import ModelWithConfidence

In [3]:
# Root directory of the dataset. It is also exported as L5KIT_DATA_FOLDER so l5kit
# can locate it (presumably read by `LocalDataManager(None)` — confirm).
DATA_DIR = "../data"
os.environ.update(L5KIT_DATA_FOLDER=DATA_DIR)

In [4]:
# Configuration passed to l5kit's rasterizer and AgentDataset in later cells.
cfg = {
    "format_version": 4,
    "model_params": {
        "history_num_frames": 10,
        "history_step_size": 1,
        "history_delta_time": 0.1,
        "future_num_frames": 50,
        "future_step_size": 1,
        "future_delta_time": 0.1,
    },
    "raster_params": {
        "raster_size": [384, 384],
        "pixel_size": [0.5, 0.5],
        "ego_center": [0.25, 0.5],
        "map_type": "py_semantic",
        "satellite_map_key": "aerial_map/aerial_map.png",
        "semantic_map_key": "semantic_map/semantic_map.pb",
        "dataset_meta_key": "meta.json",
        "filter_agents_threshold": 0.5,
    },
    "train_data_loader": {
        "key": "scenes/train.zarr",
        "batch_size": 12,
        "shuffle": True,
        "num_workers": 4,
    },
}

model_params = cfg["model_params"]
history_n_frames = model_params["history_num_frames"]
future_n_frames = model_params["future_num_frames"]
n_trajectories = 3  # number of predicted trajectories (modes) per agent

# Backbone input: 3 channels plus 2 per frame (current frame + history).
# NOTE(review): presumably the semantic map's 3 channels plus agent/ego layers per frame;
# confirm against the rasterizer's channel count.
backbone_in_channels = 3 + 2 * (history_n_frames + 1)
# Backbone output: per mode, 2 values per future frame plus 1 extra value
# (presumably the mode's confidence), i.e. 2*F*M + M.
backbone_out_features = n_trajectories * (2 * future_n_frames + 1)

backbone = resnet18(
    pretrained=False,
    in_channels=backbone_in_channels,
    num_classes=backbone_out_features,
)
model = ModelWithConfidence(
    backbone=backbone,
    future_num_frames=future_n_frames,
    num_trajectories=n_trajectories,
)

# Alternatives tried previously: a `genet_normal` backbone with the same
# in_channels/num_classes, and the checkpoint "../leo_checkpoints/best.pth".
checkpoint_path = "../logs/resnet18_bigerimages_continue4/stage_0/best.pth"
load_checkpoint(checkpoint_path, model)

<= Loaded model from '../logs/resnet18_bigerimages_continue4/stage_0/best.pth'
Stage: stage_0
Epoch: 1
Metrics:
{'train': {'loss': 12.515466623828052}, 'valid': {'loss': 12.515466623828052}}


In [5]:
# Build the l5kit data pipeline for the test split.
dm = LocalDataManager(None)
rasterizer = build_rasterizer(cfg, dm)

test_zarr = ChunkedDataset(dm.require("scenes/test.zarr")).open()
# For .npz files, `np.load` returns an NpzFile that keeps its file handle open.
# The context manager closes it once the mask array has been read into memory.
with np.load(f"{DATA_DIR}/scenes/mask.npz") as mask_npz:
    test_mask = mask_npz["arr_0"]

test_dataset = AgentDataset(
    cfg, test_zarr, rasterizer, agents_mask=test_mask
)

# Inference loader settings; shuffle=False keeps a deterministic iteration order.
TEST_BATCH_SIZE = 32
TEST_NUM_WORKERS = 20
test_dataloader = DataLoader(
    test_dataset,
    shuffle=False,
    batch_size=TEST_BATCH_SIZE,
    num_workers=TEST_NUM_WORKERS,
)

In [6]:
# Use the first GPU when one is available. Otherwise fall back to CPU: much slower,
# but the notebook still runs on a CUDA-less kernel instead of crashing here.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)

In [7]:
model.eval()

# Per-batch results, concatenated into the submission file after the loop.
future_coords_offsets_pd = []
timestamps = []
confidences_list = []
agent_ids = []

# `torch.no_grad()` disables autograd only for this inference loop. The previous
# `torch.set_grad_enabled(False)` switched it off for the rest of the kernel session.
with torch.no_grad(), tqdm(total=len(test_dataloader)) as progress:
    for data in test_dataloader:
        inputs = data["image"].to(device)
        preds, confidences = model(inputs)

        # Map every predicted point through the agent's `world_from_agent` matrix, then
        # subtract the agent centroid's first two components. Presumably this converts
        # agent-frame predictions into world-frame offsets from the centroid, which is
        # what the CSV stores — confirm.
        # NOTE(review): indexing assumes preds is (batch, modes, future_frames, 2).
        # `.copy()` gives a private buffer that is safe to overwrite in place.
        preds = preds.cpu().numpy().copy()
        world_from_agents = data["world_from_agent"].numpy()
        centroids = data["centroid"].numpy()
        for idx in range(len(preds)):
            # Iterate over the modes the model actually produced, so no mode is
            # silently left untransformed if it disagrees with `n_trajectories`.
            for mode in range(preds.shape[1]):
                preds[idx, mode] = (
                    transform_points(preds[idx, mode], world_from_agents[idx])
                    - centroids[idx][:2]
                )

        # `preds` is rebound to a fresh array every batch, so no second copy is needed.
        future_coords_offsets_pd.append(preds)
        # Copy out of the batch tensors so their storage is not kept alive by our lists.
        confidences_list.append(confidences.cpu().numpy().copy())
        timestamps.append(data["timestamp"].numpy().copy())
        agent_ids.append(data["track_id"].numpy().copy())

        progress.update(1)


predictions_file = "submission_resnet18_384x384_continue4_ofsets.csv"
write_pred_csv(
    predictions_file,
    timestamps=np.concatenate(timestamps),
    track_ids=np.concatenate(agent_ids),
    coords=np.concatenate(future_coords_offsets_pd),
    confs=np.concatenate(confidences_list),
)

100%|████████████████████| 2223/2223 [12:11<00:00,  3.04it/s]
