In [1]:
from typing import Dict

from tempfile import gettempdir
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision.models.resnet import resnet50
from tqdm import tqdm

from l5kit.configs import load_config_data
from l5kit.data import LocalDataManager, ChunkedDataset
from l5kit.dataset import AgentDataset, EgoDataset
from l5kit.rasterization import build_rasterizer
from l5kit.evaluation import write_pred_csv, compute_metrics_csv, read_gt_csv, create_chopped_dataset
from l5kit.evaluation.chop_dataset import MIN_FUTURE_STEPS
from l5kit.evaluation.metrics import neg_multi_log_likelihood, time_displace
from l5kit.geometry import transform_points
from l5kit.visualization import PREDICTED_POINTS_COLOR, TARGET_POINTS_COLOR, draw_trajectory
from prettytable import PrettyTable
from pathlib import Path
import os
import pandas as pd
import argparse
import pickle

In [2]:
exp_id = "efficientB0_frame_10"

In [14]:
cfg = {
    'format_version': 4,
    'data_path': "/home/axot/lyft/data",
    'model_params': {
        'model_architecture': 'efficientnet-b1',
        'history_num_frames': 10,
        'future_num_frames': 50,
        'lr': 1e-4,
        'history_step_size': 1,
        'history_delta_time': 0.1,
        'future_step_size': 1,
        'future_delta_time': 0.1,
    },
    'raster_params': {
        'raster_size': [224, 224],
        'pixel_size': [0.5, 0.5],
        'ego_center': [0.25, 0.5],
        'map_type': 'py_semantic',
        'satellite_map_key': 'aerial_map/aerial_map.png',
        'semantic_map_key': 'semantic_map/semantic_map.pb',
        'dataset_meta_key': 'meta.json',
        'filter_agents_threshold': 0.5,
        'disable_traffic_light_faces': False
    },
    'train_data_loader': {
        'key': 'scenes/train.zarr',
        'batch_size': 8,
        'shuffle': True,
        'num_workers': 4
    },
    'test_data_loader': {
        'key': 'scenes/test.zarr',
        'batch_size': 16,
        'shuffle': False,
        'num_workers': 4
    }
}

In [15]:
def get_dm():
    # set env variable for data
    DIR_INPUT = cfg["data_path"]
    os.environ["L5KIT_DATA_FOLDER"] = DIR_INPUT
    dm = LocalDataManager(None)
    return dm

In [16]:
def load_test_data():
    dm = get_dm()
    test_cfg = cfg["test_data_loader"]
    _cfg = cfg
    _cfg["model_params"]['history_num_frames'] = 10
    rasterizer = build_rasterizer(_cfg, dm)
    test_zarr = ChunkedDataset(dm.require(test_cfg["key"])).open()
    test_mask = np.load(f'{_cfg["data_path"]}/scenes/mask.npz')["arr_0"]
    test_dataset = AgentDataset(_cfg, test_zarr, rasterizer, agents_mask=test_mask)
    print(_cfg)
    return test_dataset

In [17]:
test_dataset = load_test_data()

{'format_version': 4, 'data_path': '/home/axot/lyft/data', 'model_params': {'model_architecture': 'efficientnet-b1', 'history_num_frames': 10, 'future_num_frames': 50, 'lr': 0.0001, 'history_step_size': 1, 'history_delta_time': 0.1, 'future_step_size': 1, 'future_delta_time': 0.1}, 'raster_params': {'raster_size': [224, 224], 'pixel_size': [0.5, 0.5], 'ego_center': [0.25, 0.5], 'map_type': 'py_semantic', 'satellite_map_key': 'aerial_map/aerial_map.png', 'semantic_map_key': 'semantic_map/semantic_map.pb', 'dataset_meta_key': 'meta.json', 'filter_agents_threshold': 0.5, 'disable_traffic_light_faces': False}, 'train_data_loader': {'key': 'scenes/train.zarr', 'batch_size': 8, 'shuffle': True, 'num_workers': 4}, 'test_data_loader': {'key': 'scenes/test.zarr', 'batch_size': 16, 'shuffle': False, 'num_workers': 4}}


  test_dataset = AgentDataset(_cfg, test_zarr, rasterizer, agents_mask=test_mask)


In [7]:
len(test_dataset)

71122

In [15]:
# stop_indexes = []
# stop_positions = []
# for i, data_agent in tqdm(enumerate(test_dataset)):
#     h_pos = data_agent["history_positions"]
#     history_move = np.sum(np.sum((h_pos[:-1] - h_pos[1:])**2, axis=1) ** 0.5 )
#     if history_move < 1.45:
#         out_pos = np.tile([data_agent["history_positions"][0]], (150, 1) )
#         stop_indexes.append(i)
#         stop_positions.append(out_pos)

71122it [45:37, 25.98it/s]


In [52]:
np.set_printoptions(precision=4, suppress=True)
data_agent = test_dataset[stop_indexes[8]]

h_pos = data_agent["history_positions"]
avail = data_agent['history_availabilities']
history_move = np.sum(np.sum((h_pos[:-1] - h_pos[1:])**2, axis=1) ** 0.5 )
print(avail)
print(h_pos)
print(history_move)

[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
[[ 0.      0.    ]
 [-0.112   0.0116]
 [-0.2388 -0.0178]
 [-0.3407 -0.0144]
 [-0.4404 -0.0432]
 [-0.5406 -0.0565]
 [-0.6422 -0.0383]
 [-0.713  -0.0128]
 [-0.7885 -0.0389]
 [-0.8704 -0.0617]
 [-0.9226 -0.0388]]
0.94989896


In [None]:
print(len(stop_indexes))
# print(len(stop_positions))

In [23]:
# file = open('./refine_data.pkl', 'rb')
stop_indexes = pickle.load(open('./refine_data.pkl', 'rb'))['indexes']

In [21]:
stop_indexes.shape

(46177,)

In [24]:
stop_positions = np.zeros((46177, 150, 2))

In [28]:
data_agent.keys()

dict_keys(['image', 'target_positions', 'target_yaws', 'target_availabilities', 'history_positions', 'history_yaws', 'history_availabilities', 'world_to_image', 'raster_from_world', 'raster_from_agent', 'agent_from_world', 'world_from_agent', 'track_id', 'timestamp', 'centroid', 'yaw', 'extent'])

In [35]:
data_agent = test_dataset[stop_indexes[10000]]
h_pos = data_agent["history_positions"]
centroid = data_agent["centroid"]
world_from_agent = data_agent['world_from_agent']

t_pos = transform_points(np.zeros((150, 2)), world_from_agent) - centroid

print(world_from_agent)
print(centroid)
print(t_pos)
print()

[[ 6.81133625e-01  7.32159125e-01 -5.36391602e+02]
 [-7.32159125e-01  6.81133625e-01  9.01450439e+02]
 [ 0.00000000e+00  0.00000000e+00  1.00000000e+00]]
[-536.39160156  901.45043945]
[[0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0

In [None]:
for i in stop_indexes:
    data_agent = test_dataset[i]
    
    h_pos = data_agent["history_positions"]
    centroid = data_agent["centroid"]
    world_from_agent = data_agent['world_from_agents']
    
    history_move = np.sum(np.sum((h_pos[:-1] - h_pos[1:])**2, axis=1) ** 0.5 )

#     stop_positions[idx, mode, :, :] = transform_points(preds[idx, mode, :, :], world_from_agents[idx]) - centroids[:2]

In [20]:
stop_indexes = np.array(stop_indexes)
stop_positions = np.array(stop_positions)

In [40]:
stop_indexes[:5]

array([ 4,  7, 11, 12, 13])

In [42]:
np.where(stop_positions>1e-3)

(array([], dtype=int64), array([], dtype=int64), array([], dtype=int64))

In [None]:
# def get_history_moves(h_pos):
    
# np.sum(np.sum((h_pos[:-1] - h_pos[1:])**2, axis=1) ** 0.5 )

In [44]:
test_dataset[11]['history_positions']

array([[ 2.2737368e-13,  0.0000000e+00],
       [ 1.1562992e-02, -2.5197869e-02],
       [ 1.6639808e-02, -7.4840821e-02],
       [ 2.7471296e-02, -6.6522360e-02],
       [ 2.1586115e-02, -7.4350737e-02],
       [-2.8913293e-02, -6.6845998e-02],
       [-2.6989944e-02, -6.6144802e-02],
       [-4.0805716e-02, -6.4046197e-02],
       [-4.2343091e-02, -5.1509514e-02],
       [-5.7316624e-02, -5.0937016e-02],
       [-3.6376778e-02, -5.1554024e-02]], dtype=float32)

In [21]:
import pickle
save_data = {
    'indexes': stop_indexes,
    'positions': stop_positions
}
pickle.dump(save_data, open('refine_data.pkl', 'wb'))

In [45]:
stop_positions = np.zeros((46177, 300))

In [59]:
import pandas as pd

def get_df(exp_id):
    df = pd.read_csv(f'./experiment/{exp_id}/submission.csv')
    return df

def save_refined_df(df, exp_id):
    df.to_csv(f'./experiment/{exp_id}/refined_submission.csv', index=None)

    
def refine_df(df):
    df.iloc[stop_indexes, 5:] = stop_positions
    return df

In [60]:
df = get_df(exp_id)

In [62]:
r_df = refine_df(df.copy())
r_df

Unnamed: 0,timestamp,track_id,conf_0,conf_1,conf_2,coord_x00,coord_y00,coord_x01,coord_y01,coord_x02,...,coord_x245,coord_y245,coord_x246,coord_y246,coord_x247,coord_y247,coord_x248,coord_y248,coord_x249,coord_y249
0,1578606007801600134,2,0.624708,0.163838,0.211454,-0.00600,0.06882,-0.04595,0.10022,-0.09296,...,7.31227,15.02104,7.73638,15.38230,8.17710,15.82778,8.61344,16.20237,9.12179,16.57242
1,1578606032802467516,4,0.126553,0.458201,0.415246,-0.52732,-0.82176,-1.03953,-1.67309,-1.56786,...,-25.03226,-46.11573,-25.55175,-47.14069,-25.98598,-48.05925,-26.50943,-49.09786,-26.99852,-50.07749
2,1578606032802467516,5,0.513072,0.285902,0.201026,0.18800,0.25058,0.35100,0.49881,0.51114,...,14.98896,19.71581,15.39329,20.25641,15.74718,20.74393,16.16434,21.29050,16.56127,21.79739
3,1578606032802467516,81,0.519961,0.283230,0.196809,-0.18827,-0.24451,-0.35145,-0.49561,-0.51428,...,-14.78880,-19.57369,-15.17888,-20.10843,-15.52205,-20.58697,-15.93133,-21.11447,-16.32040,-21.61153
4,1578606032802467516,130,0.915495,0.055080,0.029425,0.00000,0.00000,0.00000,0.00000,0.00000,...,0.00000,0.00000,0.00000,0.00000,0.00000,0.00000,0.00000,0.00000,0.00000,0.00000
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
71117,1583863581802681726,253,0.412214,0.369233,0.218554,0.40373,0.47590,0.78691,0.95914,1.18235,...,20.40412,22.99176,20.80225,23.46953,21.14073,23.87872,21.56821,24.37174,21.96772,24.79383
71118,1583863606802928836,1,0.225667,0.300430,0.473904,-0.13805,-0.67612,-0.26631,-1.27793,-0.40811,...,-14.21443,-31.86234,-14.58696,-32.56089,-14.97966,-33.20431,-15.36167,-33.96953,-15.80678,-34.58956
71119,1583863606802928836,6,0.686283,0.200259,0.113459,-0.05951,-0.08092,-0.09834,-0.17434,-0.12532,...,-8.39365,-17.60982,-8.64910,-18.15834,-8.86565,-18.63879,-9.12738,-19.15734,-9.36416,-19.68040
71120,1583863606802928836,213,0.624229,0.233732,0.142039,-0.04944,0.09936,-0.12714,0.18009,-0.20455,...,-8.62832,17.12661,-8.84900,17.61344,-9.04315,18.08752,-9.22708,18.58856,-9.40896,19.07559


In [63]:
r_df.to_csv('refine_df.csv', index=None)

In [None]:
# r_df.to_csv('refine_df.csv', index=None)