In [1]:
import warnings
warnings.filterwarnings('ignore');

In [2]:
import os
import sys
import numpy as np
import pandas as ps

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, RandomSampler


from l5kit.data import LocalDataManager, ChunkedDataset
from l5kit.dataset import AgentDataset, EgoDataset
from l5kit.rasterization import build_rasterizer
from l5kit.evaluation import write_pred_csv

# Put the parent directory on the import path so the project-local `src`
# package imported on the following lines can be resolved; this must stay
# above those `from src...` imports.
sys.path.append("..")

from src.batteries import t2d, load_checkpoint
from src.batteries.progress import tqdm
from src.models.genet import genet_normal

In [3]:
# Dataset root. It is exported as L5KIT_DATA_FOLDER so that l5kit's
# LocalDataManager(None), constructed later, can locate the data files.
DATA_DIR = "../data"
os.environ.update(L5KIT_DATA_FOLDER=DATA_DIR)

In [4]:
# l5kit configuration used for inference. The history/future horizons and the
# raster settings presumably have to mirror the config the checkpoint was
# trained with — TODO confirm against the training run.
# `train_data_loader` is not read anywhere in this notebook (the test loader
# is built with explicit arguments); it is presumably kept to mirror training.
cfg = dict(
    format_version=4,
    model_params=dict(
        history_num_frames=10,
        history_step_size=1,
        history_delta_time=0.1,
        future_num_frames=50,
        future_step_size=1,
        future_delta_time=0.1,
    ),
    raster_params=dict(
        raster_size=[224, 224],
        pixel_size=[0.5, 0.5],
        ego_center=[0.25, 0.5],
        map_type="py_semantic",
        satellite_map_key="aerial_map/aerial_map.png",
        semantic_map_key="semantic_map/semantic_map.pb",
        dataset_meta_key="meta.json",
        filter_agents_threshold=0.5,
    ),
    train_data_loader=dict(
        key="scenes/train.zarr",
        batch_size=12,
        shuffle=True,
        num_workers=4,
    ),
)

# Rebuild the GENet regressor with the input/output widths derived from cfg,
# then restore its weights from the best stage_0 checkpoint.
model = genet_normal(
    # presumably 3 semantic-map channels plus 2 per frame (agents + ego) for the
    # history_num_frames past frames and the current one — TODO confirm against
    # rasterizer.num_channels()
    in_channels=(cfg["model_params"]["history_num_frames"] + 1) * 2 + 3,
    # two values per predicted future frame; the inference loop reshapes the
    # flat output to the shape of target_positions
    num_classes=cfg["model_params"]["future_num_frames"] * 2,
)
load_checkpoint("../logs/genet_normal/stage_0/best.pth", model)

<= Loaded model from '../logs/genet_normal/stage_0/best.pth'
Stage: stage_0
Epoch: 41
Metrics:
{'train': {'loss': 0.8472687861264248}, 'valid': {'loss': 0.8384244911799765}}


In [5]:
# Build the test-set pipeline: rasterizer, test zarr, the agents mask passed to
# AgentDataset, and a non-shuffled DataLoader over the resulting dataset.
dm = LocalDataManager(None)
rasterizer = build_rasterizer(cfg, dm)

# Resolve both test files through the data manager so they come from the same
# data root (L5KIT_DATA_FOLDER).
test_zarr = ChunkedDataset(dm.require("scenes/test.zarr")).open()
# np.load on an .npz returns an NpzFile that keeps the archive open; the
# context manager closes it once the mask array has been read into memory.
with np.load(dm.require("scenes/mask.npz")) as mask_npz:
    test_mask = mask_npz["arr_0"]

test_dataset = AgentDataset(
    cfg, test_zarr, rasterizer, agents_mask=test_mask
)
test_dataloader = DataLoader(
    test_dataset,
    shuffle=False,
    batch_size=64,
    num_workers=16,
)

In [6]:
# Use the first GPU when CUDA is available; otherwise fall back to CPU so the
# notebook still runs (slowly) on machines without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Move the model's parameters and buffers to `device`. Assuming genet_normal
# returns a torch.nn.Module, .to() moves them in place and returns the same
# module, so the reassignment is only for readability.
model = model.to(device)

In [7]:
# Predict future coordinates for every agent in the masked test set, then write
# them together with their timestamps and track ids to the submission CSV.
model.eval()

future_coords_offsets_pd = []
timestamps = []
agent_ids = []

with torch.no_grad(), tqdm(total=len(test_dataloader)) as progress:
    for data in test_dataloader:
        inputs = data["image"].to(device)
        # Only the *shape* of target_positions is needed to un-flatten the model
        # output, so read it from the CPU tensor instead of copying the tensor
        # to the GPU. (target_availabilities was previously moved to the GPU
        # and never used; it is no longer loaded.)
        target_shape = data["target_positions"].shape

        outputs = model(inputs).reshape(target_shape)

        future_coords_offsets_pd.append(outputs.cpu().numpy().copy())
        timestamps.append(data["timestamp"].numpy().copy())
        agent_ids.append(data["track_id"].numpy().copy())

        progress.update(1)


# Create the submission file for Kaggle.
all_timestamps = np.concatenate(timestamps)
all_track_ids = np.concatenate(agent_ids)
all_coords = np.concatenate(future_coords_offsets_pd)
# Every dataset item must contribute exactly one row to each column.
assert len(all_coords) == len(all_timestamps) == len(all_track_ids) == len(test_dataset), (
    len(all_coords), len(all_timestamps), len(all_track_ids), len(test_dataset)
)

predictions_file = "submission_mse.csv"
write_pred_csv(
    predictions_file,
    timestamps=all_timestamps,
    track_ids=all_track_ids,
    coords=all_coords,
)

100%|████████████████████| 1112/1112 [04:15<00:00,  4.35it/s]
