In [1]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
# disable GPU
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

import tensorflow as tf
# set tf to cpu only
tf.config.set_visible_devices([], 'GPU')
import jax
jax.config.update('jax_platform_name', 'cpu')

import glob
import argparse
import pickle
# from vbd.data.data_utils import *
from tqdm import tqdm
from tqdm.contrib.concurrent import process_map  # or thread_map

from waymax import dataloader
from waymax.config import DataFormat
import functools

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
dataset_type = "validation"
data_dir = f"/robin-west/womd_processed/vbd/{dataset_type}/processed/*.pkl"
data_file_list = glob.glob(data_dir)

In [None]:
def extract_high_level_motion_action(heading, acceleration):
    if acceleration > 1: 
        speed_action = 1 # acceleration
    elif acceleration < -1:
        speed_action = 2 # deceleration
    else:
        speed_action = 3 # keep speed
    
    heading = np.rad2deg(heading)
    if np.abs(heading) < 2.4:
        steering_action = 0 # go straight
    elif np.abs(heading) < 26.4: 
        if heading > 0 :
            steering_action = 1 # turn left
        else: 
            steering_action = 2 # turn right
    else:
        if heading > 0 :
            steering_action = 3 # left u turn
        else: 
            steering_action = 2 # turn right
    
    return np.array([speed_action, steering_action])

"""
steer 
0: [-2.4, 2.4]
1: [2.4, 26.4]
2: [-inf, -2.4]
3: [26.4, inf]

speed
1: [1., inf]
2: [-inf, -1.]
3: [-1, 1]
"""


def extract_patch_action(speed_patch, heading_patch):
    ## no need for sdc
    # first_valid_ts = -1
    # last_valid_ts = -1
    # for ts in range(valid_patch.shape[0]):
    #     if first_valid_ts==-1 and valid_patch[ts]:
    #         first_valid_ts = ts
    #     elif first_valid_ts!=-1 and last_valid_ts==-1:
    #         if not valid_patch[ts]:
    #             last_valid_ts = ts - 1
    #         elif ts == valid_patch.shape[0]-1:
    #             last_valid_ts = ts
    # if first_valid_ts==-1 and last_valid_ts==-1 or first_valid_ts==last_valid_ts:
    #     return np.array([-1, -1], dtype=np.float32)
    assert len(speed_patch) == len(heading_patch)
    speed_diff = 10 * (speed_patch[-1] - speed_patch[0]) / len(speed_patch)
    heading_diff = 10 * (heading_patch[-1] - heading_patch[0]) / len(heading_patch)
    patch_action = extract_high_level_motion_action(heading_diff, speed_diff)
    return patch_action


def extract_patches_action(speed, heading, sample_rate=10):
    high_level_action = []
    for patch_id in range((speed.shape[0]) // sample_rate):
        speed_patch = speed[patch_id*sample_rate:(patch_id+1)*sample_rate]
        heading_patch = heading[patch_id*sample_rate:(patch_id+1)*sample_rate]
        high_level_action.append(extract_patch_action(speed_patch, heading_patch))
    return np.stack(high_level_action, axis=0)

In [4]:
def extract_sdc_action(data):
    scenario = data['scenario_raw']
    sdc_id = np.where(scenario.object_metadata.is_sdc)[0][0]
    sdc_id_in_processed = np.where(data["agents_id"]==sdc_id)[0][0]
    sdc_future = data["agents_future"][sdc_id_in_processed]
    assert sdc_future.shape[0] == 81 and sdc_future.shape[1] == 5, "sdc future traj shape is wrong"
    vel_xy = sdc_future[:, 3:]
    speed = np.linalg.norm(vel_xy, axis=-1)
    heading = sdc_future[:, 2]
    sdc_future_actions_4s = extract_patches_action(speed, heading, sample_rate=40)
    sdc_future_actions_1s = extract_patches_action(speed, heading, sample_rate=10)
    return sdc_id, sdc_future_actions_4s, sdc_future_actions_1s

In [5]:
action_labels = dict() 
for data_file_path in tqdm(data_file_list):
    scenario_id = data_file_path.split("/")[-1].rstrip(".pkl").split("_")[-1]
    with open(data_file_path, "rb") as data_f:
        data = pickle.load(data_f)
    sdc_id, sdc_future_actions_4s, sdc_future_actions_1s = extract_sdc_action(data)
    action_labels[scenario_id] = {
        'sdc_id': sdc_id,
        '1s_action': sdc_future_actions_1s,
        '4s_action': sdc_future_actions_4s,
    }
with open(f"/robin-west/womd_processed/vbd/{dataset_type}/action_labels.pkl", "wb") as action_labels_f:
    pickle.dump(action_labels, action_labels_f)

  0%|          | 0/486995 [00:00<?, ?it/s]An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
  0%|          | 126/486995 [00:29<31:43:04,  4.26it/s]


KeyboardInterrupt: 

In [6]:
len(data_file_list)

486995

In [3]:
num_buckets = 10
buckets = {i: [] for i in range(num_buckets)}

for idx, file_path in enumerate(data_file_list):
    bucket_index = idx % num_buckets
    buckets[bucket_index].append(file_path)

In [4]:
for k in buckets.keys():
    print(k)
    print(len(buckets[k]))

0
4410
1
4410
2
4410
3
4410
4
4410
5
4410
6
4410
7
4409
8
4409
9
4409


In [5]:
with open(f"/robin-west/VBD/script/{dataset_type}_file_buckets.pkl", "wb") as buckets_f:
    pickle.dump(buckets, buckets_f)

In [5]:
# merge buckets
import glob 
import pickle
dataset_type = "training"
bucket_dir = f"/robin-west/womd_processed/vbd/{dataset_type}/action_labels_bucket_*.pkl"
bucket_file_list = glob.glob(bucket_dir)
merged_action_labels_path = f"/robin-west/womd_processed/vbd/{dataset_type}/action_labels.pkl"

In [6]:
merged_action_labels = dict()
for bucket_file_path in bucket_file_list:
    with open(bucket_file_path, "rb") as bucket_f:
        bucket_action_labels = pickle.load(bucket_f)
    merged_action_labels.update(bucket_action_labels)

In [7]:
with open(merged_action_labels_path, "wb") as merged_action_labels_f:
    pickle.dump(merged_action_labels, merged_action_labels_f)

In [2]:
# dataset split
dataset_type = "training"
merged_action_labels_path = f"/robin-west/womd_processed/vbd/{dataset_type}/action_labels.pkl"
with open(merged_action_labels_path, "rb") as merged_action_labels_f:
    merged_action_labels = pickle.load(merged_action_labels_f)

# data_dir = f"/robin-west/womd_processed/vbd/{dataset_type}/processed/*.pkl"
# data_file_list = glob.glob(data_dir)
# existing_scenario_list = [data_file_path.split("/")[-1].rstrip(".pkl").split("_")[-1] for data_file_path in data_file_list]

In [3]:
from collections import defaultdict
action_to_scenario_id = defaultdict(list)
for scenario_id in merged_action_labels.keys():
    action_label = merged_action_labels[scenario_id]["4s_action"]
    ats_key = (action_label[0,0], action_label[0,1])
    action_to_scenario_id[ats_key].append(scenario_id)

In [4]:
for ats_key in sorted(list(action_to_scenario_id.keys())):
    print(f"{ats_key}:{len(action_to_scenario_id[ats_key])}")

(1, 0):31660
(1, 1):6164
(1, 2):5906
(1, 3):1027
(2, 0):41539
(2, 1):506
(2, 2):1615
(2, 3):766
(3, 0):344286
(3, 1):22549
(3, 2):24539
(3, 3):6438


In [5]:
num_scenarios_per_action_label = 3000

import random
def shuffle_in_place(my_list):
    random.shuffle(my_list)  # Shuffles the list in place
    return my_list

# get a list of of metadata
action_to_scenario_id_subset = {
    ats_key: shuffle_in_place(action_to_scenario_id[ats_key])[:num_scenarios_per_action_label] 
    for ats_key in action_to_scenario_id.keys()
}
for ats_key in sorted(list(action_to_scenario_id_subset.keys())):
    print(f"{ats_key}:{len(action_to_scenario_id_subset[ats_key])}")

(1, 0):3000
(1, 1):3000
(1, 2):3000
(1, 3):1027
(2, 0):3000
(2, 1):506
(2, 2):1615
(2, 3):766
(3, 0):3000
(3, 1):3000
(3, 2):3000
(3, 3):3000


In [6]:
scenario_id_subset_list = [scenario_id for ats_key in action_to_scenario_id_subset.keys() for scenario_id in action_to_scenario_id_subset[ats_key]]
print(len(scenario_id_subset_list))

# check duplicates
from collections import defaultdict

def get_duplicate_indices(my_list):
    positions = defaultdict(list)
    
    # Store the indices of each element
    for index, element in enumerate(my_list):
        positions[element].append(index)
    
    # Keep only those with more than one occurrence
    duplicates_with_indices = {item: idxs for item, idxs in positions.items() if len(idxs) > 1}
    return duplicates_with_indices
dup_indices = get_duplicate_indices(scenario_id_subset_list)
print(dup_indices) # should be empty

27914
{}


In [7]:
merged_action_labels_subset = {
    scenario_id: merged_action_labels[scenario_id]
    for scenario_id in scenario_id_subset_list
}

In [8]:
# move file 
import shutil
import os

def safe_copy(src, dst):
    if os.path.exists(dst):
        print(f"Destination {dst} already exists. Aborting.")
        return
    shutil.copy2(src, dst)
    # print(f"Moved {src} to {dst}")

missed_file = defaultdict(list)
for scenario_id in tqdm(scenario_id_subset_list):
    try:
        src = f"/robin-west/womd_processed/vbd/{dataset_type}/processed/scenario_{scenario_id}.pkl"
        dst = f"/robin-west/womd_processed/single_agent_subset/{dataset_type}/processed/scenario_{scenario_id}.pkl"
        safe_copy(src, dst)
    except FileNotFoundError:
        action_label = merged_action_labels[scenario_id]["4s_action"]
        ats_key = (action_label[0,0], action_label[0,1])
        missed_file[ats_key].append(scenario_id)

100%|██████████| 27914/27914 [1:18:31<00:00,  5.92it/s]  


In [9]:
for k in missed_file.keys():
    print(k, len(missed_file[k]))

(3, 0) 25


In [1]:
import pickle

In [2]:
dataset_type = 'validation'
# with open(f"/robin-west/womd_processed/vbd/{dataset_type}/action_to_scenario_id_subset.pkl", "rb") as action_to_scenario_id_subset_f:
#         action_to_scenario_id_subset = pickle.load(action_to_scenario_id_subset_f)
# scenario_id_subset_list = [scenario_id for ats_key in action_to_scenario_id_subset.keys() for scenario_id in action_to_scenario_id_subset[ats_key]]
    

In [4]:
for key in action_to_scenario_id_subset.keys():
    print(key, action_to_scenario_id_subset[key][:])

(3, 0) ['53dd6e908b017adf', 'b538157f8bc536e6', 'ba15809220e841a5', '59482d571cdd56e2', '6c350c45eaa0f5e1', 'ebd9718392aa8ef6', '40526c290d28ddd8', '7e43613ec6e9d36b', '3ca97b467286ae14', '85a4abc7491e8eb2', 'f711753cfc1788df', 'da80dfd334a2283d', 'c83a2d87b438769d', 'ac6dc1d6e8cb0ffc', '5e2b400c8f29bfa', '64255ac9516e8838', 'a9a891bcf1e54bc7', '7ae1f5c2c37aa1e9', 'da79c2fc67d45710', '632c02706f2e91dd', 'b551bddcc9c0e54e', '4ab41134d1841c11', '70130619f775ada3', '3c2a7ad51d906c11', 'b3ca0f35581e3779', 'f478b596dfb94941', 'b5b288cc9ca44157', 'd288f2a23215f977', '70594ba047900394', '4166583f6c066f3c', '6f840d3b60a7d7ee', '7946942b989b92c4', '6179e7a2c28597d4', 'e009673aa26420ae', 'a060090612579cab', 'f14a1f4dc64c9b06', '98e6381307b5eac8', 'a5d572dc75b24afc', '7bd4226fa00e331c', 'ea82f56e5e778796', '25c3819d77b0aef2', '1e9380eed97e68fe', 'dfa03fa2ccbebb44', '3cb9941148b8d2cc', 'c5f2774053460b70', '9cd57f9b31aba013', '29225baa2932df54', '43b955b744dd721a', '55adb841cd6bc7c0', '3c918e780533

In [3]:
action_labels_dst = f"/robin-west/womd_processed/vbd/{dataset_type}/action_labels_old.pkl"
with open(action_labels_dst, 'rb') as action_labels_f:
    action_labels = pickle.load(action_labels_f)

In [5]:
scenario_id = 'fcacbba048f80da1'
action_labels[scenario_id]

{'sdc_id': 1,
 '1s_action': array([[2, 0],
        [2, 0],
        [3, 1],
        [3, 1],
        [3, 0],
        [3, 1],
        [1, 1],
        [3, 1]]),
 '4s_action': array([[2, 1],
        [3, 1]])}

In [6]:
# # 8s_action parsing from 1s_action 

# # select a partial dataset to check 
# import random

# scenario_id_list = list(action_labels.keys())
# selected_id_list = random.choices(scenario_id_list, k=30)

In [4]:
import numpy as np
def action_1s_to_8s(action_1s, steer_thres = 2, speed_thres = 2):
    steer_action = action_1s[:, 1]
    speed_action = action_1s[:, 0]

    steer_dict = {}
    speed_dict = {}

    for steer in steer_action:
        steer_dict[steer] = steer_dict.get(steer, 0) + 1

    for speed in speed_action:
        speed_dict[speed] = speed_dict.get(speed, 0) + 1

    # parse steer
    steer_8s = 0
    steer_label_cnt = -1
    for key in [1, 2, 3]:
        if key != 0 and steer_dict.get(key, 0)>=steer_thres and steer_dict.get(key, 0) > steer_label_cnt:
            steer_8s = key
            steer_label_cnt = steer_dict.get(key, 0)


    # parse speed
    speed_8s = 3
    speed_label_cnt = -1
    for key in [1, 2]:
        if key != 3 and speed_dict.get(key,0)>=speed_thres and speed_dict.get(key,0) > speed_label_cnt:
            speed_8s = key
            speed_label_cnt = speed_dict.get(key,0)
    return np.array([speed_8s, steer_8s])



def parse_action_1s_to_8s(scenario_id_list, action_labels, steer_thres = 2, speed_thres = 2):
    action_labels_8s_dict = {}
    for scenario_id in scenario_id_list:
        action_labels_1s = action_labels[scenario_id]['1s_action']
        action_labels_8s = action_1s_to_8s(action_labels_1s, steer_thres = steer_thres, speed_thres = speed_thres)
        action_labels_4s = action_1s_to_4s(action_labels_1s, steer_thres = steer_thres, speed_thres = speed_thres)
        action_labels_8s_dict[scenario_id] = action_labels_8s_dict.get(scenario_id, {})
        action_labels_8s_dict[scenario_id]['8s_action'] = action_labels_8s
        action_labels_8s_dict[scenario_id]['4s_action'] = action_labels_4s
        action_labels_8s_dict[scenario_id]['1s_action'] = action_labels_1s
        action_labels_8s_dict[scenario_id]['4s_action_old'] = action_labels[scenario_id]['4s_action']
        action_labels_8s_dict[scenario_id]['sdc_id'] = action_labels[scenario_id]['sdc_id']
    return action_labels_8s_dict


def action_1s_to_4s(action_1s, steer_thres = 2, speed_thres = 2):
    labels_4s = np.zeros((2,2))
    for i in range(2):
        steer_action = action_1s[4*i:4*(i+1), 1]
        speed_action = action_1s[4*i:4*(i+1), 0]

        steer_dict = {}
        speed_dict = {}

        for steer in steer_action:
            steer_dict[steer] = steer_dict.get(steer, 0) + 1

        for speed in speed_action:
            speed_dict[speed] = speed_dict.get(speed, 0) + 1

        # parse steer
        steer_4s = 0
        steer_label_cnt = -1

        for key in [1, 2, 3]:
            if key != 0 and steer_dict.get(key, 0) >= steer_thres and steer_dict.get(key, 0) >= steer_label_cnt:
                steer_4s = key
                steer_label_cnt = steer_dict.get(key, 0)


        # parse speed
        speed_4s = 3
        speed_label_cnt = -1
        for key in [1,2]:
            if key != 3 and speed_dict.get(key,0) >= speed_thres and speed_dict.get(key,0) >= speed_label_cnt:
                speed_4s = key
                speed_label_cnt = speed_dict[key]
        labels_4s[i,0] = speed_4s
        labels_4s[i,1] = steer_4s
    return labels_4s






In [5]:
action_labels_8s_dict = parse_action_1s_to_8s(list(action_labels.keys()), action_labels, steer_thres = 2, speed_thres = 2)

In [6]:
len(action_labels_8s_dict.keys())

44097

In [24]:
for scenario_id in action_labels_8s_dict.keys():
    print(f"{scenario_id}\n{action_labels_8s_dict[scenario_id]['8s_action']}")
    print(f"4s\n{action_labels_8s_dict[scenario_id]['4s_action']}\n{action_labels_8s_dict[scenario_id]['4s_action_old']}")
    print(f"1s\n{action_labels_8s_dict[scenario_id]['1s_action']}")

53dd6e908b017adf
[3 0]
4s
[[3. 0.]
 [3. 0.]]
[[3 0]
 [3 0]]
1s
[[3 0]
 [3 0]
 [3 0]
 [3 0]
 [3 0]
 [3 0]
 [3 0]
 [3 0]]
b538157f8bc536e6
[3 0]
4s
[[3. 0.]
 [3. 0.]]
[[3 0]
 [3 0]]
1s
[[3 0]
 [3 0]
 [3 0]
 [3 0]
 [3 0]
 [3 0]
 [3 0]
 [3 0]]
ba15809220e841a5
[3 0]
4s
[[3. 0.]
 [3. 0.]]
[[3 0]
 [3 0]]
1s
[[2 0]
 [3 0]
 [3 0]
 [3 0]
 [3 0]
 [3 0]
 [3 0]
 [3 0]]
59482d571cdd56e2
[3 0]
4s
[[3. 0.]
 [3. 0.]]
[[3 0]
 [3 0]]
1s
[[3 0]
 [3 0]
 [3 0]
 [3 0]
 [3 0]
 [3 0]
 [3 0]
 [3 0]]
6c350c45eaa0f5e1
[3 1]
4s
[[3. 0.]
 [3. 1.]]
[[3 0]
 [3 1]]
1s
[[1 0]
 [3 0]
 [3 0]
 [3 0]
 [3 0]
 [3 1]
 [3 1]
 [3 1]]
ebd9718392aa8ef6
[3 0]
4s
[[3. 0.]
 [3. 0.]]
[[3 0]
 [3 0]]
1s
[[3 0]
 [3 0]
 [3 0]
 [3 0]
 [3 0]
 [3 0]
 [3 0]
 [1 0]]
40526c290d28ddd8
[1 0]
4s
[[1. 0.]
 [3. 0.]]
[[3 0]
 [3 0]]
1s
[[1 0]
 [1 0]
 [3 0]
 [3 0]
 [3 0]
 [2 0]
 [3 0]
 [3 0]]
7e43613ec6e9d36b
[3 0]
4s
[[3. 0.]
 [3. 0.]]
[[3 0]
 [3 0]]
1s
[[3 0]
 [3 0]
 [3 0]
 [3 0]
 [3 0]
 [3 0]
 [3 0]
 [3 0]]
3ca97b467286ae14
[3 0]
4s
[[3. 0.]
 [3. 

In [7]:
with open(f"/robin-west/womd_processed/vbd/{dataset_type}/action_labels.pkl", "wb") as action_labels_8s_f:
    pickle.dump(action_labels_8s_dict, action_labels_8s_f)

In [24]:
import torch
def wrap_angle(angle):
    """
    Wrap the angle to [-pi, pi].

    Args:
        angle (torch.Tensor): Angle tensor.

    Returns:
        torch.Tensor: Wrapped angle.

    """
    # return torch.atan2(torch.sin(angle), torch.cos(angle))
    return (angle + torch.pi) % (2 * torch.pi) - torch.pi

In [26]:
angle1 = -2*torch.pi
angle2 = 2*torch.pi + 1e-2
heading_diff = wrap_angle(angle1-angle2)


In [27]:
heading_diff

-0.009999999999999787

In [1]:
import torch
import yaml
import datetime
import argparse
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

# set tf to cpu only
import tensorflow as tf
tf.config.set_visible_devices([], "GPU")
import jax
jax.config.update("jax_platform_name", "cpu")

import sys
sys.path.append("/robin-west/VBD")

from vbd.data.dataset import WaymaxDataset
from vbd.model.VBD import VBD
from torch.utils.data import DataLoader

import lightning.pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor
from lightning.pytorch.loggers import WandbLogger, CSVLogger
from lightning.pytorch.strategies import DDPStrategy

from matplotlib import pyplot as plt

In [2]:
def load_config(file_path):
    with open(file_path, "r") as file:
        data = yaml.safe_load(file)
    return data

In [3]:
config_path = "/robin-west/VBD/config/_final_validate/vbd_ego_agent_future_len_40_input_action_normalize_true_prior_means_steer_and_speed_scale_15_no_cond_attn_ego_validate.yaml"
cfg = load_config(config_path)
cfg['num_workers'] = 1
cfg['batch_size'] = 1
dataset_dir = '/root/single_agent_subset/validation/processed'

In [4]:
# create dataset
from vbd.data.dataset import WaymaxTestDataset

val_dataset = WaymaxTestDataset(
    data_dir=dataset_dir,
    future_len = cfg["future_len"],
    anchor_path=cfg["anchor_path"],
    predict_ego_only=cfg["predict_ego_only"],
    action_labels_path=cfg["validation_action_labels_path"],
    max_object= cfg["agents_len"],
)

In [15]:
import random
random_labels = {}

for data_dict in val_dataset:
    scenario_id = data_dict['scenario_id']
    steer_label = data_dict['sdc_steer_label']
    speed_label = data_dict['sdc_speed_label']

    steer_list = [0,1,2,3]
    speed_list = [1,2,3]
    
    all_label_combo = [(speed, steer) for speed in speed_list for steer in steer_list]
    all_label_combo.remove((speed_label, steer_label))
    all_label_combo.remove((2,3))

    random_label = random.choice(all_label_combo)
    random_labels[scenario_id] = {
        'gt_steer': steer_label,
        'gt_speed': speed_label,
        'random_steer': random_label[1],
        'random_speed': random_label[0],
    }    

In [16]:
import pickle
with open('/robin-west/VBD/config/_table_2/random_labels.pkl', 'wb') as random_labels_f:
    pickle.dump(random_labels, random_labels_f)

In [17]:
random_labels

{'53dd6e908b017adf': {'gt_steer': 0,
  'gt_speed': 3,
  'random_steer': 1,
  'random_speed': 2},
 'b538157f8bc536e6': {'gt_steer': 0,
  'gt_speed': 3,
  'random_steer': 1,
  'random_speed': 1},
 'ba15809220e841a5': {'gt_steer': 0,
  'gt_speed': 3,
  'random_steer': 0,
  'random_speed': 2},
 '59482d571cdd56e2': {'gt_steer': 0,
  'gt_speed': 3,
  'random_steer': 3,
  'random_speed': 1},
 '6c350c45eaa0f5e1': {'gt_steer': 0,
  'gt_speed': 3,
  'random_steer': 3,
  'random_speed': 3},
 'ebd9718392aa8ef6': {'gt_steer': 0,
  'gt_speed': 3,
  'random_steer': 2,
  'random_speed': 1},
 '40526c290d28ddd8': {'gt_steer': 0,
  'gt_speed': 3,
  'random_steer': 0,
  'random_speed': 2},
 '7e43613ec6e9d36b': {'gt_steer': 0,
  'gt_speed': 3,
  'random_steer': 1,
  'random_speed': 2},
 '3ca97b467286ae14': {'gt_steer': 0,
  'gt_speed': 3,
  'random_steer': 2,
  'random_speed': 1},
 '85a4abc7491e8eb2': {'gt_steer': 0,
  'gt_speed': 3,
  'random_steer': 3,
  'random_speed': 1},
 'f711753cfc1788df': {'gt_stee