In [1]:
# Auto reload module
# https://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import yaml
import datetime
import argparse
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

# set tf to cpu only
import tensorflow as tf
tf.config.set_visible_devices([], "GPU")
import jax
jax.config.update("jax_platform_name", "cpu")

import sys
sys.path.append("/robin-west/VBD")

from vbd.data.dataset import WaymaxDataset
from vbd.model.VBD import VBD
from torch.utils.data import DataLoader

import lightning.pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor
from lightning.pytorch.loggers import WandbLogger, CSVLogger
from lightning.pytorch.strategies import DDPStrategy

from matplotlib import pyplot as plt

In [3]:
def load_config(file_path):
    """Parse the YAML file at ``file_path`` and return its contents.

    Args:
        file_path: Path of the YAML configuration file to read.

    Returns:
        The object produced by ``yaml.safe_load`` for the document.
    """
    with open(file_path, "r") as stream:
        return yaml.safe_load(stream)

In [4]:
# Exactly one config_path assignment is active; the commented-out lines are
# alternative validation configs that can be swapped in by (un)commenting.
# config_path = "/robin-west/VBD/config/diffuse_ego_only/validate/vbd_ego_agent_future_len_40_input_action_normalize_true_prior_means_steer_and_speed_scale_8_no_cond_attn_ego_validate.yaml"
config_path = "/robin-west/VBD/config/diffuse_ego_only/validate/vbd_ego_agent_future_len_40_input_action_normalize_true_prior_means_steer_and_speed_scale_15_no_cond_attn_ego_validate.yaml"
# config_path = "/robin-west/VBD/config/diffuse_ego_only/validate/vbd_ego_agent_future_len_40_input_action_normalize_true_prior_means_steer_and_speed_scale_0_no_cond_attn_ego_validate.yaml"
# config_path = "/robin-west/VBD/config/diffuse_ego_only/validate/vbd_ego_agent_future_len_80_input_action_normalize_true_prior_means_steer_and_speed_scale_15_no_cond_attn_ego_validate.yaml"
# config_path = "/robin-west/VBD/config/_final_validate/vbd_ego_agent_future_len_40_input_action_normalize_true_prior_means_steer_and_speed_scale_0_no_cond_attn_ego_classifier_validate.yaml"
# config_path = "/robin-west/VBD/config/mean/VBD_train_on_full_dataset_validate.yaml"
# config_path = "/robin-west/VBD/config/_table_2/vbd_ego_agent_future_len_40_input_action_normalize_true_prior_means_steer_and_speed_scale_0_no_cond_attn_ego_classifier_validate_10.yaml"
# config_path = "/robin-west/VBD/config/_final_validate/vbd_ego_agent_future_len_40_input_action_normalize_true_prior_means_steer_and_speed_scale_15_no_cond_attn_ego_validate.yaml"

# Load the YAML config, then override its num_workers / batch_size entries
# with 1 for this notebook session.
cfg = load_config(config_path)
cfg['num_workers'] = 1
cfg['batch_size'] = 1
# Directory handed to the dataset as data_dir when it is constructed.
dataset_dir = '/root/single_agent_subset/validation/processed'

In [5]:
# Seed every RNG Lightning knows about with the config's seed and select the
# "high" float32 matmul precision mode in PyTorch.
pl.seed_everything(cfg["seed"])
torch.set_float32_matmul_precision("high")

# Build the validation dataset from the directory and config values above.
from vbd.data.dataset import WaymaxTestDataset

val_dataset = WaymaxTestDataset(
    data_dir=dataset_dir,
    future_len=cfg["future_len"],
    anchor_path=cfg["anchor_path"],
    predict_ego_only=cfg["predict_ego_only"],
    action_labels_path=cfg["validation_action_labels_path"],
    max_object=cfg["agents_len"],
)

# val_loader = DataLoader(
#     val_dataset, 
#     batch_size=cfg["batch_size"],
#     pin_memory=True, 
#     num_workers=cfg["num_workers"],
#     shuffle=False
# )

Seed set to 42


In [6]:
# from vbd.waymax_visualization.plotting import plot_ego, plot_state
# import mediapy

# ids = random.sample(range(len(val_dataset)), 10)
# for id in ids:
#     print(id)
#     scenario_id, scenario, data_dict = val_dataset.get_scenario_by_index(id)
#     scenario.object_metadata.is_controlled = scenario.object_metadata.is_sdc
#     print(data_dict['sdc_steer_label'], data_dict['sdc_speed_label'])
#     mediapy.show_image(plot_state(scenario), width=800)


In [7]:
from copy import deepcopy
import random

from vbd.waymax_visualization.plotting import plot_ego, plot_state
import mediapy

In [8]:
import jax
import matplotlib.pyplot as plt
import numpy as np
def plot_batch_x_t_with_all_labels(scenario,agents_interested,x_t_history, timestep, batch, run_ids=None):
    """Print the labels and show the ego trajectory of selected batch entries.

    For every selected batch position ``i`` this prints a line of the form
    ``id_<i>_steer_<name>_speed_<name>`` built from ``batch[i]`` and displays
    the image returned by ``plot_ego`` with ``mediapy.show_image``.

    Args:
        scenario: Scenario object forwarded unchanged to ``plot_ego``.
        agents_interested: Indexed by batch position; entries ``> 0`` are used
            as a mask selecting which rows of the trajectory are plotted.
        x_t_history: Sequence indexed by ``timestep``; the selected element is
            indexed by batch position and must expose ``.shape[0]``.
        timestep: Index into ``x_t_history``. Negative values count from the
            end, so ``-1`` selects the last recorded entry.
        batch: Sequence of dicts holding ``'sdc_steer_label'`` and
            ``'sdc_speed_label'``.
        run_ids: Iterable of batch positions to show. ``None`` shows all of
            them.

    Raises:
        AssertionError: If ``timestep`` is outside the valid (positive or
            negative) index range of ``x_t_history``.
    """
    history_len = len(x_t_history)
    # Validate both ends: negative indices are legitimate here, but an
    # out-of-range one used to slip past the check and fail later.
    assert -history_len <= timestep < history_len, (
        'timestep {} is out of range for x_t_history of length {}'.format(timestep, history_len)
    )
    steer_label_name = {
        0: 'go straight',
        1: 'left turn',
        2: 'right turn',
        3: 'U-turn',
    }
    speed_label_name = {
        1: 'accelerate',
        2: 'decelerate',
        3: 'keep speed'
    }

    def _label_text(names, label):
        # Labels can arrive as Python ints, numpy scalars or 0-dim tensors;
        # coerce to int so the dict lookup works for all of them, and fall
        # back to a readable placeholder instead of raising on unknown values.
        try:
            label = int(label)
        except (TypeError, ValueError):
            return 'unknown({})'.format(label)
        return names.get(label, 'unknown({})'.format(label))

    x_ts = x_t_history[timestep]
    if run_ids is None:
        run_ids = np.arange(x_ts.shape[0])
    for i in run_ids:
        sample = batch[i]
        steer_label = sample['sdc_steer_label']
        # Author's reminder: a "+ 1" offset on the speed label was used for
        # older checkpoints and is not wanted for models after 0224.
        speed_label = sample['sdc_speed_label']
        print('id_{}_steer_{}_speed_{}'.format(
            i,
            _label_text(steer_label_name, steer_label),
            _label_text(speed_label_name, speed_label),
        ))
        traj = x_ts[i]
        img = plot_ego(
            scenario,
            ego_traj=traj[agents_interested[i] > 0].detach().cpu().numpy(),
            log_traj=True,
        )
        mediapy.show_image(img, width=400)


In [None]:
def sample_with_random_id(val_dataset, i):
    """Sample the VBD model on one scenario under every steer/speed label pair.

    The scenario's own labels are printed, then the scenario data is copied
    once per (steer, speed) combination with both labels overwritten. The
    copies are collated into one batch and passed to
    ``model.sample_denoiser_for_plot``. The model is built from the
    module-level ``cfg`` and its weights are loaded from ``cfg["init_from"]``
    on every call.

    Args:
        val_dataset: Dataset providing ``get_scenario_by_index``,
            ``get_scenario_by_id`` and ``__collate_fn__``.
        i: Scenario selector. An ``int`` is used as a dataset index, a ``str``
            as a scenario id.

    Returns:
        Tuple ``(scenario, agents_interested, x_t_history, batch)`` where
        ``batch`` is the un-collated list of relabelled data dicts.

    Raises:
        TypeError: If ``i`` is neither an ``int`` nor a ``str``.
        ValueError: If ``cfg`` contains no ``"init_from"`` checkpoint path.
    """
    if isinstance(i, str):
        scenario_id, scenario, data_dict = val_dataset.get_scenario_by_id(i)
    elif isinstance(i, int):
        scenario_id, scenario, data_dict = val_dataset.get_scenario_by_index(i)
    else:
        # Previously this fell through and failed later on an unbound name.
        raise TypeError(
            'i must be an int index or a str scenario id, got {}'.format(type(i).__name__)
        )
    scenario.object_metadata.is_controlled = scenario.object_metadata.is_sdc
    print(data_dict['sdc_speed_label'], data_dict['sdc_steer_label'])

    # One copy of the scenario data per label combination (4 steer x 3 speed).
    batch = []
    for steer_label in [0, 1, 2, 3]:
        for speed_label in [1, 2, 3]:
            data_dict_ = deepcopy(data_dict)
            data_dict_['sdc_steer_label'] = steer_label
            data_dict_['sdc_speed_label'] = speed_label
            # Author's reminder: older (pre-0224) models needed the speed
            # label shifted by -1 here.
            batch.append(data_dict_)
    batch_collated = val_dataset.__collate_fn__(batch)

    ckpt_path = cfg.get("init_from", None)
    if ckpt_path is None:
        raise ValueError('cfg["init_from"] must be set to a checkpoint path')

    model = VBD(cfg=cfg)
    print("Load Weights from ", ckpt_path)
    # torch.load unpickles the file; only load checkpoints from trusted sources.
    model.load_state_dict(torch.load(ckpt_path, map_location=torch.device("cuda"))["state_dict"])

    model.cuda()
    # Per the original note, log_dict holds the mean over all label combinations.
    log_dict, denoiser_outputs, agents_interested, x_t_history = model.sample_denoiser_for_plot(batch_collated, calc_loss=True)
    print(log_dict)
    return scenario, agents_interested, x_t_history, batch

In [10]:
# Number of items the validation dataset reports.
print(len(val_dataset))

2309


In [11]:
# [2233, 2232, 1728, 456]
# right turn, decelerate: 2228
# 2227, 2225, 2223
# Fixed scenario index for this run. The commented-out alternative draws a
# random index; random.randint includes its upper bound, so it would
# presumably need len(val_dataset) - 1 to stay in range (verify against
# get_scenario_by_index).
i = 2223  # random.randint(0, len(val_dataset))
print(i)
# Run the sampler for every steer/speed label combination on that scenario.
scenario, agents_interested, x_t_history, batch = sample_with_random_id(val_dataset, i)

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


2223
2 2
Load Weights from  /robin-west/VBD/train_log_0225/vbd_ego_only_type_sample_schedule_linear_future_len_40_input_type_action_normalize_action_True_label_True_type_steer_and_speed_scale_0.0_cond_embed_None_diffuse_ego_True/epoch=63.ckpt
{}


In [None]:
# Show the last entry of x_t_history for every element of the batch.
timestep = -1
plot_batch_x_t_with_all_labels(
    scenario=scenario,
    agents_interested=agents_interested,
    x_t_history=x_t_history,
    timestep=timestep,
    batch=batch,
    run_ids=None,
)

In [14]:
# i = "6ece5198f531a353"
# _, _, _, data = val_dataset.get_scenario_by_id(i)


# def extract_high_level_motion_action(heading, acceleration):
#     if acceleration > 1: 
#         speed_action = 1 # acceleration
#     elif acceleration < -1:
#         speed_action = 2 # deceleration
#     else:
#         speed_action = 3 # keep speed
    
#     heading = np.rad2deg(heading)
#     if np.abs(heading) < 2.4:
#         steering_action = 0 # go straight
#     elif np.abs(heading) < 26.4: 
#         if heading > 0 :
#             steering_action = 1 # turn left
#         else: 
#             steering_action = 2 # turn right
#     else:
#         if heading > 0 :
#             steering_action = 3 # left u turn
#         else: 
#             steering_action = 2 # turn right
    
#     return np.array([speed_action, steering_action])


# def wrap_angle(angle):
#     """
#     Wrap the angle to [-pi, pi].

#     Args:
#         angle (torch.Tensor): Angle tensor.

#     Returns:
#         torch.Tensor: Wrapped angle.

#     """
#     # return torch.atan2(torch.sin(angle), torch.cos(angle))
#     return (angle + torch.pi) % (2 * torch.pi) - torch.pi


# def extract_patch_action(speed_patch, heading_patch):
#     ## no need for sdc
#     # first_valid_ts = -1
#     # last_valid_ts = -1
#     # for ts in range(valid_patch.shape[0]):
#     #     if first_valid_ts==-1 and valid_patch[ts]:
#     #         first_valid_ts = ts
#     #     elif first_valid_ts!=-1 and last_valid_ts==-1:
#     #         if not valid_patch[ts]:
#     #             last_valid_ts = ts - 1
#     #         elif ts == valid_patch.shape[0]-1:
#     #             last_valid_ts = ts
#     # if first_valid_ts==-1 and last_valid_ts==-1 or first_valid_ts==last_valid_ts:
#     #     return np.array([-1, -1], dtype=np.float32)
#     assert len(speed_patch) == len(heading_patch)
#     speed_diff = 10 * (speed_patch[-1] - speed_patch[0]) / len(speed_patch)
#     heading_diff = 10 * wrap_angle(heading_patch[-1] - heading_patch[0]) / len(heading_patch)
#     heading_diff = wrap_angle(heading_diff)
#     print(speed_patch[-1], speed_patch[0])
#     print(heading_diff, heading_patch[-1], heading_patch[0])
#     patch_action = extract_high_level_motion_action(heading_diff, speed_diff)
#     return patch_action


# def extract_patches_action(speed, heading, sample_rate=10):
#     high_level_action = []
#     for patch_id in range((speed.shape[0]) // sample_rate):
#         print("***")
#         speed_patch = speed[patch_id*sample_rate:(patch_id+1)*sample_rate]
#         heading_patch = heading[patch_id*sample_rate:(patch_id+1)*sample_rate]
#         high_level_action.append(extract_patch_action(speed_patch, heading_patch))
#     return np.stack(high_level_action, axis=0)


# def extract_sdc_action(data):
#     scenario = data['scenario_raw']
#     sdc_id = np.where(scenario.object_metadata.is_sdc)[0][0]
#     sdc_id_in_processed = np.where(data["agents_id"]==sdc_id)[0][0]
#     sdc_future = data["agents_future"][sdc_id_in_processed]
#     assert sdc_future.shape[0] == 81 and sdc_future.shape[1] == 5, "sdc future traj shape is wrong"
#     vel_xy = sdc_future[:, 3:]
#     speed = np.linalg.norm(vel_xy, axis=-1)
#     heading = sdc_future[:, 2]
#     # sdc_future_actions_4s = extract_patches_action(speed, heading, sample_rate=40)
#     sdc_future_actions_1s = extract_patches_action(speed, heading, sample_rate=10)
#     # return sdc_id, sdc_future_actions_4s, sdc_future_actions_1s

In [1]:
from glob import glob 
import pickle
import numpy as np
from tqdm import tqdm

def table_2_stats(results_dir):
    """Collect per-scenario metrics from pickled result files.

    Args:
        results_dir: Despite the name, a glob pattern that matches the
            per-scenario pickle files (for example ``".../*.pkl"``).

    Returns:
        Tuple of six lists with one entry per matched file, ordered by sorted
        file path: ``(speed_acc, steer_acc, speed_label, steer_label,
        offroad, combined_acc)``. The accuracy and offroad entries are
        per-file means; the label entries are numpy arrays. All six lists are
        empty when the pattern matches nothing.
    """
    # glob returns paths in file-system order; sort so that repeated runs and
    # different result directories yield the same scenario ordering.
    scenario_results_path = sorted(glob(results_dir))
    speed_acc_ds = []
    steer_acc_ds = []
    combined_acc_ds = []
    OR = []
    steer_key = []
    speed_key = []

    if not scenario_results_path:
        # Nothing matched: report it instead of failing on the [0] lookup.
        print('No result files matched {}'.format(results_dir))
        return speed_acc_ds, steer_acc_ds, speed_key, steer_key, OR, combined_acc_ds
    print(scenario_results_path[0])

    for scenario_result_path in tqdm(scenario_results_path):
        # pickle.load can execute arbitrary code; only read trusted files.
        with open(scenario_result_path, 'rb') as scenario_result_f:
            scenario_result = pickle.load(scenario_result_f)
        speed_acc_ds.append(np.mean(scenario_result['speed_acc']))
        steer_acc_ds.append(np.mean(scenario_result['steer_acc']))
        speed_key.append(scenario_result['speed_label'].detach().cpu().numpy())
        steer_key.append(scenario_result['steer_label'].detach().cpu().numpy())
        # Element-wise product of the two accuracies, averaged over the file.
        combined_acc_ds.append(
            np.mean(scenario_result['speed_acc'] * scenario_result['steer_acc'])
        )
        OR.append(scenario_result['offroad'].mean())
    return speed_acc_ds, steer_acc_ds, speed_key, steer_key, OR, combined_acc_ds

In [4]:
# Glob pattern of the per-scenario result pickles for each evaluated variant.
# Going by the directory names: '1' / '01' / '10' are gradients_scale 1 / 0.1 / 10,
# 'c' is cond_20, and 'ms8' / 'ms15' are scale 8.0 / 15.0.
results_dir_dict = {
    '1': '/robin-west/VBD/config/_table_2/results/scale_0.0_cond_None_means_type_steer_and_speed_gradients_scale_1/*.pkl',
    '01': '/robin-west/VBD/config/_table_2/results/scale_0.0_cond_None_means_type_steer_and_speed_gradients_scale_0.1/*.pkl',
    '10': '/robin-west/VBD/config/_table_2/results/scale_0.0_cond_None_means_type_steer_and_speed_gradients_scale_10/*.pkl',
    'c': '/robin-west/VBD/config/_table_2/results/scale_0.0_cond_20_means_type_steer_and_speed_gradients_scale_0.0/*.pkl',
    'ms8': '/robin-west/VBD/config/_table_2/results/scale_8.0_cond_None_means_type_steer_and_speed_gradients_scale_0.0/*.pkl',
    'ms15': '/robin-west/VBD/config/_table_2/results/scale_15.0_cond_None_means_type_steer_and_speed_gradients_scale_0.0/*.pkl',
}

# Gather the per-scenario lists of every variant and print the mean speed
# accuracy, steer accuracy and off-road value for each.
combined_results = {}
for key, results_dir in results_dir_dict.items():
    speed_acc_ds, steer_acc_ds, speed_key, steer_key, OR, combined_acc = table_2_stats(results_dir)
    combined_results[key] = dict(
        steer_key=steer_key,
        speed_key=speed_key,
        steer_acc=steer_acc_ds,
        speed_acc=speed_acc_ds,
        offroad=OR,
        acc=combined_acc,
    )
    print(key, np.mean(speed_acc_ds), np.mean(steer_acc_ds), np.mean(OR))

/robin-west/VBD/config/_table_2/results/scale_0.0_cond_None_means_type_steer_and_speed_gradients_scale_1/1350abaa358b6f2a.pkl


  0%|          | 0/2309 [00:00<?, ?it/s]

100%|██████████| 2309/2309 [00:44<00:00, 52.20it/s] 


1 0.3490688609787787 0.40291612530677057 0.2887253
/robin-west/VBD/config/_table_2/results/scale_0.0_cond_None_means_type_steer_and_speed_gradients_scale_0.1/1350abaa358b6f2a.pkl


100%|██████████| 2309/2309 [04:30<00:00,  8.52it/s] 


01 0.3269813772195756 0.26403926663779415 0.19792117
/robin-west/VBD/config/_table_2/results/scale_0.0_cond_None_means_type_steer_and_speed_gradients_scale_10/1350abaa358b6f2a.pkl


100%|██████████| 2309/2309 [05:09<00:00,  7.46it/s]


10 0.4280352244839035 0.4192291035080121 0.37924066
/robin-west/VBD/config/_table_2/results/scale_0.0_cond_20_means_type_steer_and_speed_gradients_scale_0.0/1350abaa358b6f2a.pkl


100%|██████████| 2309/2309 [05:28<00:00,  7.04it/s]


c 0.6955391944564747 0.23126894759636207 0.31009096
/robin-west/VBD/config/_table_2/results/scale_8.0_cond_None_means_type_steer_and_speed_gradients_scale_0.0/1350abaa358b6f2a.pkl


100%|██████████| 2309/2309 [08:18<00:00,  4.63it/s]


ms8 0.4660025985275011 0.23054713440161684 0.14436264
/robin-west/VBD/config/_table_2/results/scale_15.0_cond_None_means_type_steer_and_speed_gradients_scale_0.0/1350abaa358b6f2a.pkl


100%|██████████| 2309/2309 [05:46<00:00,  6.67it/s]

ms15 0.5484336653674029 0.23285693662480147 0.14176412





In [3]:
# Write combined_results (all variants) to a single pickle file.
with open('/robin-west/VBD/config/_table_2/results/table2.pkl', 'wb') as table_2_f:
    pickle.dump(combined_results, table_2_f)

In [6]:
from copy import deepcopy
# Independent copy of the 'ms15' entry of combined_results.
ref = deepcopy(combined_results['ms15'])

In [16]:
# Select the scenarios whose off-road value is below 0.4, then show the mean
# off-road value of that subset.
# NOTE(review): 0.4 looks like an ad-hoc threshold; confirm its intent.
mask = np.asarray(ref['offroad']) < 0.4

np.mean(np.asarray(ref['offroad'])[mask])

0.0043771043

In [17]:
# Display the per-scenario off-road values as an array.
np.asarray(ref['offroad'])

array([0., 0., 0., ..., 0., 0., 0.], dtype=float32)

In [55]:
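# Recorded (mean speed_acc, mean steer_acc) per variant:
# 0.1: (0.3269813772195756, 0.26403926663779415)
# 1: (0.3490688609787787, 0.40291612530677057)
# 10: (0.4280352244839037, 0.41922910350801257)
# condition: (0.6955391944564747, 0.23126894759636207)
# ours scale 8: (0.47667198864644306, 0.23054713440161684)
#   NOTE: the aggregation cell output in this notebook reports 0.4660025985275011
#   as the 'ms8' speed accuracy; verify which run this value came from.
# ours scale 15: (0.5484336653674027, 0.2328569366248014)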

0.8, 0.6

(0.8, 0.6)

In [5]:
import pickle
import torch
# Load the pickled result of a single scenario from the scale_8.0 run.
# NOTE: pickle.load can execute arbitrary code; only open trusted files.
with open('/robin-west/VBD/config/_table_2/results/scale_8.0_cond_None_means_type_steer_and_speed_gradients_scale_0.0/1e619f364eb0d19c.pkl', 'rb') as f:
    results = pickle.load(f)

In [6]:
# Display the loaded per-scenario result object.
results

{'ADE': array([0.6624671 , 0.43391162, 0.6664012 , 5.1569314 , 4.5021873 ,
        4.0214095 , 2.3454378 , 2.4551225 , 2.2114265 , 3.6436558 ,
        4.108943  , 4.1375575 , 6.824962  , 6.7768946 , 6.6980376 ,
        5.436253  , 6.2017007 , 5.315212  , 1.0896425 , 1.6023048 ,
        1.2920474 , 6.2474937 , 6.92875   , 6.4747086 , 5.86187   ,
        5.315156  , 6.645833  , 6.7562876 , 6.6236224 , 5.895626  ,
        2.1064236 , 1.311551  , 1.4465517 , 4.9791474 , 5.480845  ,
        3.8807266 ], dtype=float32),
 'FDE': array([ 1.5384046,  1.2605006,  1.4976586, 15.010744 , 13.680626 ,
        12.539918 ,  7.182234 ,  6.925472 ,  6.7745695, 10.522559 ,
        11.302546 , 11.236442 , 19.63363  , 19.611345 , 19.418152 ,
        16.109837 , 18.127075 , 15.389492 ,  3.9241576,  5.5619245,
         4.794102 , 18.313982 , 19.82665  , 18.903933 , 17.882793 ,
        16.568302 , 20.59802  , 19.381147 , 18.805323 , 17.67942  ,
         6.403363 ,  4.4282694,  5.2266417, 14.616061 , 15.446487