In [1]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
# disable GPU
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

import tensorflow as tf
# set tf to cpu only
tf.config.set_visible_devices([], 'GPU')
import jax
jax.config.update('jax_platform_name', 'cpu')

import glob
import argparse
import pickle
from vbd.data.data_utils import *
from tqdm import tqdm
from tqdm.contrib.concurrent import process_map  # or thread_map

from waymax import dataloader
from waymax.config import DataFormat
import functools

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
dataset_type = "training"
data_dir = f"/robin-west/womd_processed/vbd/{dataset_type}/processed/*.pkl"
data_file_list = glob.glob(data_dir)

In [3]:
def wrap_angle(angle):
    """
    Wrap the angle to [-pi, pi].

    Args:
        angle (torch.Tensor): Angle tensor.

    Returns:
        torch.Tensor: Wrapped angle.

    """
    # return torch.atan2(torch.sin(angle), torch.cos(angle))
    return (angle + torch.pi) % (2 * torch.pi) - torch.pi

    
def extract_high_level_motion_action(heading, acceleration):
    if acceleration > 1: 
        speed_action = 1 # acceleration
    elif acceleration < -1:
        speed_action = 2 # deceleration
    else:
        speed_action = 3 # keep speed
    
    heading = np.rad2deg(heading)
    if np.abs(heading) < 2.4:
        steering_action = 0 # go straight
    elif np.abs(heading) < 26.4: 
        if heading > 0 :
            steering_action = 1 # turn left
        else: 
            steering_action = 2 # turn right
    else:
        if heading > 0 :
            steering_action = 3 # left u turn
        else: 
            steering_action = 2 # turn right
    
    return np.array([speed_action, steering_action])


def extract_patch_action(speed_patch, heading_patch):
    ## no need for sdc
    # first_valid_ts = -1
    # last_valid_ts = -1
    # for ts in range(valid_patch.shape[0]):
    #     if first_valid_ts==-1 and valid_patch[ts]:
    #         first_valid_ts = ts
    #     elif first_valid_ts!=-1 and last_valid_ts==-1:
    #         if not valid_patch[ts]:
    #             last_valid_ts = ts - 1
    #         elif ts == valid_patch.shape[0]-1:
    #             last_valid_ts = ts
    # if first_valid_ts==-1 and last_valid_ts==-1 or first_valid_ts==last_valid_ts:
    #     return np.array([-1, -1], dtype=np.float32)
    assert len(speed_patch) == len(heading_patch)
    speed_diff = 10 * (speed_patch[-1] - speed_patch[0]) / len(speed_patch)
    heading_diff = 10 * (heading_patch[-1] - heading_patch[0]) / len(heading_patch)
    patch_action = extract_high_level_motion_action(heading_diff, speed_diff)
    return patch_action


def extract_patches_action(speed, heading, sample_rate=10):
    high_level_action = []
    for patch_id in range((speed.shape[0]) // sample_rate):
        speed_patch = speed[patch_id*sample_rate:(patch_id+1)*sample_rate]
        heading_patch = heading[patch_id*sample_rate:(patch_id+1)*sample_rate]
        high_level_action.append(extract_patch_action(speed_patch, heading_patch))
    return np.stack(high_level_action, axis=0)

In [4]:
def extract_sdc_action(data):
    scenario = data['scenario_raw']
    sdc_id = np.where(scenario.object_metadata.is_sdc)[0][0]
    sdc_id_in_processed = np.where(data["agents_id"]==sdc_id)[0][0]
    sdc_future = data["agents_future"][sdc_id_in_processed]
    assert sdc_future.shape[0] == 81 and sdc_future.shape[1] == 5, "sdc future traj shape is wrong"
    vel_xy = sdc_future[:, 3:]
    speed = np.linalg.norm(vel_xy, axis=-1)
    heading = sdc_future[:, 2]
    sdc_future_actions_4s = extract_patches_action(speed, heading, sample_rate=40)
    sdc_future_actions_1s = extract_patches_action(speed, heading, sample_rate=10)
    return sdc_id, sdc_future_actions_4s, sdc_future_actions_1s

In [5]:
action_labels = dict() 
for data_file_path in tqdm(data_file_list):
    scenario_id = data_file_path.split("/")[-1].rstrip(".pkl").split("_")[-1]
    with open(data_file_path, "rb") as data_f:
        data = pickle.load(data_f)
    sdc_id, sdc_future_actions_4s, sdc_future_actions_1s = extract_sdc_action(data)
    action_labels[scenario_id] = {
        'sdc_id': sdc_id,
        '1s_action': sdc_future_actions_1s,
        '4s_action': sdc_future_actions_4s,
    }
with open(f"/robin-west/womd_processed/vbd/{dataset_type}/action_labels.pkl", "wb") as action_labels_f:
    pickle.dump(action_labels, action_labels_f)

  0%|          | 0/486995 [00:00<?, ?it/s]An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
  0%|          | 126/486995 [00:29<31:43:04,  4.26it/s]


KeyboardInterrupt: 

In [3]:
len(data_file_list)

486995

In [4]:
num_buckets = 100
buckets = {i: [] for i in range(num_buckets)}

for idx, file_path in enumerate(data_file_list):
    bucket_index = idx % num_buckets
    buckets[bucket_index].append(file_path)

In [5]:
for k in buckets.keys():
    print(k)
    print(len(buckets[k]))

0
4870
1
4870
2
4870
3
4870
4
4870
5
4870
6
4870
7
4870
8
4870
9
4870
10
4870
11
4870
12
4870
13
4870
14
4870
15
4870
16
4870
17
4870
18
4870
19
4870
20
4870
21
4870
22
4870
23
4870
24
4870
25
4870
26
4870
27
4870
28
4870
29
4870
30
4870
31
4870
32
4870
33
4870
34
4870
35
4870
36
4870
37
4870
38
4870
39
4870
40
4870
41
4870
42
4870
43
4870
44
4870
45
4870
46
4870
47
4870
48
4870
49
4870
50
4870
51
4870
52
4870
53
4870
54
4870
55
4870
56
4870
57
4870
58
4870
59
4870
60
4870
61
4870
62
4870
63
4870
64
4870
65
4870
66
4870
67
4870
68
4870
69
4870
70
4870
71
4870
72
4870
73
4870
74
4870
75
4870
76
4870
77
4870
78
4870
79
4870
80
4870
81
4870
82
4870
83
4870
84
4870
85
4870
86
4870
87
4870
88
4870
89
4870
90
4870
91
4870
92
4870
93
4870
94
4870
95
4869
96
4869
97
4869
98
4869
99
4869


In [18]:
with open("/robin-west/womd_processed/vbd/validation/action_labels/action_label_arxiv/action_labels.pkl", 'rb') as f:
    data = pickle.load(f)

In [20]:
print(len(data))

1


In [6]:
with open(f"/robin-west/VBD/script/{dataset_type}_file_buckets.pkl", "wb") as buckets_f:
    pickle.dump(buckets, buckets_f)

In [10]:
# merge buckets
dataset_type = "training"
bucket_dir = f"/robin-west/womd_processed/vbd/{dataset_type}/action_labels/action_labels_bucket_*.pkl"
bucket_file_list = glob.glob(bucket_dir)
merged_action_labels_path = f"/robin-west/womd_processed/vbd/{dataset_type}/action_labels.pkl"

In [11]:
print(len(bucket_file_list))

100


In [13]:
merged_action_labels = dict()
for bucket_file_path in bucket_file_list:
    with open(bucket_file_path, "rb") as bucket_f:
        bucket_action_labels = pickle.load(bucket_f)
    merged_action_labels.update(bucket_action_labels)

In [14]:
with open(merged_action_labels_path, "wb") as merged_action_labels_f:
    pickle.dump(merged_action_labels, merged_action_labels_f)

In [15]:
len(merged_action_labels.values())

486995

In [21]:
# dataset split
dataset_type = "training"
merged_action_labels_path = f"/robin-west/womd_processed/vbd/{dataset_type}/action_labels.pkl"
with open(merged_action_labels_path, "rb") as merged_action_labels_f:
    merged_action_labels = pickle.load(merged_action_labels_f)

# data_dir = f"/robin-west/womd_processed/vbd/{dataset_type}/processed/*.pkl"
# data_file_list = glob.glob(data_dir)
# existing_scenario_list = [data_file_path.split("/")[-1].rstrip(".pkl").split("_")[-1] for data_file_path in data_file_list]

In [22]:
from collections import defaultdict
action_to_scenario_id = defaultdict(list)
for scenario_id in merged_action_labels.keys():
    action_label = merged_action_labels[scenario_id]["action_label"]
    action_to_scenario_id[action_label].append(scenario_id)

In [23]:
for ats_key in sorted(list(action_to_scenario_id.keys())):
    print(f"{ats_key}:{len(action_to_scenario_id[ats_key])}")

0:127324
1:255146
2:12404
3:11920
4:67
5:23429
6:1869
7:54836


In [12]:
num_scenarios_per_action_label = 300

import random
def shuffle_in_place(my_list):
    random.shuffle(my_list)  # Shuffles the list in place
    return my_list

# get a list of of metadata
action_to_scenario_id_subset = {
    ats_key: shuffle_in_place(action_to_scenario_id[ats_key])[:num_scenarios_per_action_label] 
    for ats_key in action_to_scenario_id.keys()
}
for ats_key in sorted(list(action_to_scenario_id_subset.keys())):
    print(f"{ats_key}:{len(action_to_scenario_id_subset[ats_key])}")


(1, 0):300
(1, 1):300
(1, 2):300
(1, 3):7
(2, 0):300
(2, 1):54
(2, 2):56
(3, 0):300
(3, 1):300
(3, 2):300
(3, 3):92


In [13]:
scenario_id_subset_list = [scenario_id for ats_key in action_to_scenario_id_subset.keys() for scenario_id in action_to_scenario_id_subset[ats_key]]
print(len(scenario_id_subset_list))

# check duplicates
from collections import defaultdict

def get_duplicate_indices(my_list):
    positions = defaultdict(list)
    
    # Store the indices of each element
    for index, element in enumerate(my_list):
        positions[element].append(index)
    
    # Keep only those with more than one occurrence
    duplicates_with_indices = {item: idxs for item, idxs in positions.items() if len(idxs) > 1}
    return duplicates_with_indices
dup_indices = get_duplicate_indices(scenario_id_subset_list)
print(dup_indices) # should be empty

2309
{}


In [14]:
subset_list_path = f"/robin-west/womd_processed/vbd/{dataset_type}/action_to_scenario_id_subset.pkl"
with open(subset_list_path, "wb") as subset_list_f:
    pickle.dump(action_to_scenario_id_subset, subset_list_f)

In [15]:
merged_action_labels_subset = {
    scenario_id: merged_action_labels[scenario_id]
    for scenario_id in scenario_id_subset_list
}
subset_action_labels_path = f"/robin-west/womd_processed/vbd/{dataset_type}/subset_action_labels.pkl"
with open(subset_action_labels_path, "wb") as subset_action_labels_f:
    pickle.dump(merged_action_labels_subset, subset_action_labels_f)

In [15]:
# move file 
import shutil
import os

def safe_copy(src, dst):
    if os.path.exists(dst):
        print(f"Destination {dst} already exists. Aborting.")
        return
    shutil.copy2(src, dst)
    # print(f"Moved {src} to {dst}")

missed_file = defaultdict(list)
for scenario_id in tqdm(scenario_id_subset_list):
    try:
        src = f"/robin-west/womd_processed/vbd/{dataset_type}/processed/scenario_{scenario_id}.pkl"
        dst = f"/robin-west/womd_processed/single_agent_subset_v2/{dataset_type}/processed/scenario_{scenario_id}.pkl"
        safe_copy(src, dst)
    except FileNotFoundError:
        action_label = merged_action_labels[scenario_id]["4s_action"]
        ats_key = (action_label[0,0], action_label[0,1])
        missed_file[ats_key].append(scenario_id)

  0%|          | 7/23714 [00:33<31:58:45,  4.86s/it]


KeyboardInterrupt: 

In [9]:
for k in missed_file.keys():
    print(k, len(missed_file[k]))

(3, 0) 25


In [9]:
with open(f"/robin-west/womd_processed/single_agent_subset/{dataset_type}/processed/action_labels.pkl", "wb") as f:
    pickle.dump(merged_action_labels_subset, f)

In [2]:
# investigate action labels
# 1a0bd6424027a059
import pickle
with open("/robin-west/womd_processed/single_agent_subset/validation/action_labels.pkl", "rb") as action_labels_f:
    action_labels = pickle.load(action_labels_f)

from collections import defaultdict
action_to_scenario_id = defaultdict(list)
for scenario_id in action_labels.keys():
    action_label = action_labels[scenario_id]["4s_action"]
    ats_key = (action_label[0,0], action_label[0,1])
    action_to_scenario_id[ats_key].append(scenario_id)

In [4]:
for ats_key in sorted(list(action_to_scenario_id.keys())):
    print(ats_key, action_to_scenario_id[ats_key][1])

(1, 0) 7ea1371b05066892
(1, 1) ec5fd750cfde767c
(1, 2) ca1bb3609957057a
(1, 3) 888f97beca5cca8f
(2, 0) 3bba22c6e511b539
(2, 1) af68115f030eb304
(2, 2) 3a6f9cabab35542f
(2, 3) 38479dfd28740cd0
(3, 0) e569f6d6d9c1d67d
(3, 1) 3f380493b10424df
(3, 2) 876cd6af7c7c749
(3, 3) 86b4b920ba9fb858
