In [None]:
%matplotlib inline
import sys
sys.path.append('../../trajectron')
import os
import numpy as np
import torch
import dill
import json
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import matplotlib.patheffects as pe
from helper import *
import visualization
import statistics
from collections import Counter

from copy import deepcopy

from sklearn.metrics import mean_squared_error

## Load nuScenes SDK and data

In [None]:
nuScenes_data_path = 'v1.0-trainval'    # Data Path to nuScenes data set
nuScenes_devkit_path = './devkit/python-sdk/'
# Guard the append so re-running this cell does not keep growing sys.path
# with duplicate entries.
if nuScenes_devkit_path not in sys.path:
    sys.path.append(nuScenes_devkit_path)
from nuscenes.map_expansion.map_api import NuScenesMap
# Vector map for the 'boston-seaport' location, read from the data root above.
nusc_map = NuScenesMap(dataroot=nuScenes_data_path, map_name='boston-seaport')

## Load the processed nuScenes test environment

In [None]:
# Unpickle the preprocessed nuScenes test environment; its scenes are exposed as eval_scenes.
# NOTE: dill.load can execute arbitrary code -- only point this at trusted, locally produced files.
file_path = '../processed/nuScenes_test_full.pkl'
with open(file_path, mode='rb') as f:
    eval_env = dill.load(f, encoding='latin1')
eval_scenes = eval_env.scenes


In [None]:
# Prediction horizon passed to predict(), and the directory the trained models are loaded from.
ph, log_dir = 6, './models'

In [None]:
# Define ROI in nuScenes Map
# Region of interest in nuScenes map coordinates, given as (min, max) per axis.
x_min, x_max = 773.0, 1100.0
y_min, y_max = 1231.0, 1510.0

In [None]:
# nuScenes map layer names used by this notebook, in rendering order.
layers = [
    'drivable_area', 'road_segment', 'lane',
    'ped_crossing', 'walkway',
    'stop_line', 'road_divider', 'lane_divider',
]

## Prediction under input perturbations (velocity-output model)

In [None]:
# (dx, dy) offsets; the prediction cell adds them to data columns 0 and 1 of every
# non-ego node. Key names presumably encode the statistic each offset was derived
# from (quartiles, fences, outliers, extremes) -- TODO confirm provenance.
PERTURBATIONS = dict([
    ('x_q1',   (-0.0398,  0.0071)),
    ('x_q3',   ( 0.0008, -0.0213)),
    ('y_q1',   (-0.0519, -0.0247)),
    ('y_q3',   (-0.0613,  0.0152)),
    ('x_uf',   ( 0.0617, -0.0283)),
    ('x_lf',   (-0.1018, -0.0426)),
    ('y_uf',   (-0.0471,  0.0752)),
    ('y_lf',   (-0.0250, -0.0852)),
    ('x_out',  (-0.3349, -0.5756)),
    ('y_out',  (-0.1250, -0.1358)),
    ('x_min',  (-5.8995, -3.1643)),
    ('xy_max', ( 0.5272,  0.4045)),
    ('y_min',  (-5.3667, -3.2429)),
])

In [None]:
# Tally the non-ego agents of the evaluation set by node type.
counter = [
    str(node.type)
    for scene in eval_env.scenes
    for node in scene.nodes
    if str(node.id) != 'ego'
]
print(f'Obstacle count {Counter(counter)}')

In [None]:
model_dir = os.path.join(log_dir, 'vel_ee')
eval_stg_vel, hyp = load_model(model_dir, eval_env, ts=12)

print(f'Number of scenes = {len(eval_scenes)}')
scenes = eval_scenes
ph = 6  # prediction horizon handed to predict(); repeated here so this cell is self-contained
SAVE_EVERY = 20  # write a checkpoint JSON after every SAVE_EVERY scenes


def _perturb_scene(scene, offset):
    """Return a deep copy of `scene` with `offset` added to data columns 0 and 1 of each non-ego node.

    The input scene is left untouched.
    """
    # NOTE(review): columns 0/1 of node.data.data are presumably the x/y positions -- confirm.
    scene_perturb = deepcopy(scene)
    for node in scene_perturb.nodes:
        if str(node.id) == 'ego':
            continue
        rows = node.data.data
        for di, _ in enumerate(rows):
            rows[di][0] += offset[0]
            rows[di][1] += offset[1]
    return scene_perturb


def _paired_predictions(model, scene, scene_perturb, timestep, ph):
    """Predict on the clean and the perturbed scene and pair the results per node.

    Returns {str(node): {'original': list, 'perturbed': list}} for every node predicted in
    the perturbed scene, or {} when the model returns no predictions for the timestep.
    """
    predict_kwargs = dict(num_samples=1, z_mode=True, gmm_mode=True)
    predictions_mm = model.predict(scene, timestep, ph, **predict_kwargs)
    predictions_mm_perturb = model.predict(scene_perturb, timestep, ph, **predict_kwargs)
    if not predictions_mm_perturb:
        # Nothing to compare for this scene; previously this crashed with IndexError.
        return {}
    # Only a single timestep is requested, so the result has exactly one key.
    pkey = next(iter(predictions_mm_perturb))
    # Node objects here come from the deep copy; the lookup into the clean predictions
    # relies on Node equality/hash not being identity-based (as in the original code).
    return {
        str(node): {
            'original': predictions_mm[pkey][node].tolist(),
            'perturbed': predictions_mm_perturb[pkey][node].tolist(),
        }
        for node in predictions_mm_perturb[pkey]
    }


output = {}
with torch.no_grad():
    timestep = np.array([2])

    for perturbation, pvalue in PERTURBATIONS.items():
        print(f'In round {perturbation}')
        out_dir = os.path.join('perturbated_results', 'change_all', perturbation)
        # open(..., 'w') does not create missing directories.
        os.makedirs(out_dir, exist_ok=True)
        # Start a fresh dict per perturbation so checkpoints never mix in results
        # from the previous round.
        output = {}
        for sid, scene in enumerate(scenes):
            scene_perturb = _perturb_scene(scene, pvalue)
            output[scene.name] = _paired_predictions(eval_stg_vel, scene, scene_perturb, timestep, ph)

            # periodic checkpoint of the partial results
            if (sid + 1) % SAVE_EVERY == 0:
                print(f'Saving at index {sid}, scene {scene.name}')
                with open(os.path.join(out_dir, f'saved_at_{scene.name}.json'), 'w') as fd:
                    json.dump(output, fd)
        with open(os.path.join(out_dir, 'saved_final.json'), 'w') as fd:
            json.dump(output, fd)

## Process output

In [None]:
def _rmse(a, b):
    """Root-mean-square difference between two equal-length numeric sequences."""
    diff = np.asarray(a, dtype=float) - np.asarray(b, dtype=float)
    return float(np.sqrt(np.mean(np.square(diff))))


def _nearest_quantile(values, q):
    """Return the q-quantile of `values` using the 'nearest' rule, or NaN when `values` is empty.

    NumPy 1.22 renamed the `interpolation=` keyword to `method=` and deprecated the old name;
    the old keyword is only used as a fallback for NumPy versions that lack `method=`.
    """
    if len(values) == 0:
        return float('nan')
    try:
        return float(np.quantile(values, q, method='nearest'))
    except TypeError:
        return float(np.quantile(values, q, interpolation='nearest'))


def compute_error(data_baseline, perturbation, perturbation_type, quantile=0.99):
    """Print and return high-quantile per-axis ADE/FDE between baseline and perturbed predictions.

    Loads `perturbated_results/<perturbation_type>/<perturbation>/saved_final.json`. For
    perturbation_type 'remove_once' it loads `perturbated_results/remove_once/saved_final.json`
    instead, and `perturbation` is only used as a label. Every non-ego observation's 'perturbed'
    trajectory is compared with the 'original' trajectory of the same scene/observation in
    `data_baseline`:
      * ADE per axis: RMSE over all predicted steps of that coordinate.
      * FDE per axis: absolute difference of that coordinate at the final step.

    Two lines are printed: `<perturbation> <ade_x> <ade_y>` and `<perturbation> <fde_x> <fde_y>`.
    Each value is the `quantile` (default 0.99, 'nearest' rule) across observations, rounded to
    4 decimals.

    Args:
        data_baseline: parsed results JSON whose 'original' entries act as the reference.
        perturbation: perturbation name (sub-directory of the results folder).
        perturbation_type: experiment folder, e.g. 'change_all' or 'remove_once'.
        quantile: quantile of the per-observation errors to report.

    Returns:
        dict with float keys 'ade_x', 'ade_y', 'fde_x', 'fde_y' (unrounded). A value is NaN
        when no non-ego observations were found.

    Raises:
        FileNotFoundError: if the results file does not exist.
        KeyError: if `data_baseline` lacks a scene/observation present in the results file.
    """
    final_file = f'perturbated_results/{perturbation_type}/{perturbation}/saved_final.json'
    if perturbation_type == 'remove_once':
        # this experiment writes a single results file rather than one per perturbation
        final_file = f'perturbated_results/{perturbation_type}/saved_final.json'
    with open(final_file, 'r') as fd:
        data = json.load(fd)

    ade_x, ade_y, fde_x, fde_y = [], [], [], []
    # Top-level keys are scene names, second-level keys are str(node).
    for scene_name, observations in data.items():
        for obs, preds in observations.items():
            if 'ego' in obs:
                continue
            # [0][0] drops the two leading singleton axes of the stored prediction, leaving
            # a (steps, 2) trajectory -- NOTE(review): layout assumed from the saving cell.
            orig = np.asarray(data_baseline[scene_name][obs]['original'][0][0], dtype=float)
            pert = np.asarray(preds['perturbed'][0][0], dtype=float)

            ade_x.append(_rmse(orig[:, 0], pert[:, 0]))
            ade_y.append(_rmse(orig[:, 1], pert[:, 1]))
            # RMSE of a single value is just its absolute difference
            fde_x.append(abs(float(orig[-1, 0] - pert[-1, 0])))
            fde_y.append(abs(float(orig[-1, 1] - pert[-1, 1])))

    result = {
        'ade_x': _nearest_quantile(ade_x, quantile),
        'ade_y': _nearest_quantile(ade_y, quantile),
        'fde_x': _nearest_quantile(fde_x, quantile),
        'fde_y': _nearest_quantile(fde_y, quantile),
    }
    print(perturbation, round(result['ade_x'], 4), round(result['ade_y'], 4))
    print(perturbation, round(result['fde_x'], 4), round(result['fde_y'], 4))
    return result

# Reference trajectories: the 'original' (unperturbed) predictions that were stored
# alongside the xy_max run.
with open('perturbated_results/change_all/xy_max/saved_final.json', 'r') as fd:
    data_baseline = json.load(fd)

for perturbation_type in ['change_all']:
    # header row, then two lines (ADE, FDE) per perturbation
    print(perturbation_type, 'ade', 'fde')
    for perturbation in PERTURBATIONS:
        compute_error(data_baseline, perturbation, perturbation_type)