# Creation of collision rewind dataset

In [1]:
import os
from tqdm import tqdm
import gymnasium
import highway_env
import numpy as np
from stable_baselines3 import DQN
from multiprocessing import Pool, cpu_count

# Environment config override passed to gymnasium.make(): five-lane highway.
lanes_cnt_5_cfg = dict(lanes_count=5)

In [2]:
# Gymnasium environment id used for the rollouts in this notebook.
env_id = "highway-fast-v0"
# Checkpoint handed to DQN.load() below.
# NOTE(review): absolute, user-specific path; consider a configurable base directory.
model_path = "/u/shuhan/projects/vla/data/highway_env/highway_fast_v0_dqn_meta_action_5_lanes/model"


In [274]:
import copy

# Roll out the trained DQN policy for up to `rollout_length` steps, keeping a
# deep-copied snapshot of the environment *before* every step so that each
# step can later be "rewound" and replayed with a different action.

# Reuse `env_id` from the configuration cell instead of repeating the literal.
env = gymnasium.make(env_id, render_mode='rgb_array', config=lanes_cnt_5_cfg)

# Load the trained model
model = DQN.load(model_path, device='cpu')

observations = []
actions = []
rewind_envs = []

rollout_length = 30
seed = 42

np.random.seed(seed)

# Seeding NumPy's global RNG does not seed the environment's own generator;
# pass the seed to reset() so the initial traffic layout is reproducible.
obs, _ = env.reset(seed=seed)

# Pre-bind the loop flags so the post-loop check is valid even when the loop
# body never runs (rollout_length == 0).
rl_collision = False
truncated = False

for i in range(rollout_length):
    action, _states = model.predict(obs, deterministic=True)

    # Snapshot the pre-step state together with the observation/action pair.
    rewind_envs.append(copy.deepcopy(env))
    observations.append(obs)
    actions.append(action)

    # `rl_collision` is the `terminated` flag returned by env.step().
    obs, reward, rl_collision, truncated, info = env.step(action)

    if rl_collision:
        print("collision")
        break

    if truncated:
        print("truncated")
        break

# remove the last step if there is a collision
# this is because we do not have the "ground truth" action for the last step
if rl_collision:
    rewind_envs = rewind_envs[:-1]
    observations = observations[:-1]
    actions = actions[:-1]

env.close()


truncated


In [288]:
# Inspect one previously saved rollout archive.
# NOTE(review): absolute path; consider a configurable dataset directory.
saved_file = '/storage/Datasets/highway_env/highway_fast_v0_dqn_meta_action_5_lanes/rollouts_train_collision/rollout_0.npz'

# Pass the path directly: np.load then opens and owns the file handle (released
# by data.close()), instead of leaking a handle created by a bare open() call.
data = np.load(saved_file)
data.keys()

KeysView(NpzFile 'object' with keys: observations, actions, collision_rewind_steps, collision_observations, collision_actions)

In [289]:
# Display the 'collision_rewind_steps' array stored in the loaded archive.
data['collision_rewind_steps']

array([ 1,  2,  2, 27, 28])

In [290]:
# Display the 'collision_actions' array stored in the loaded archive.
data['collision_actions']

array([2, 0, 2, 0, 0])

In [282]:
# NOTE: this rebinds `tqdm` to the module (the first cell imported the function).
import tqdm

# For every recorded step, replay the pre-step snapshot once with each action
# the policy did NOT take, and record the branches whose single step terminates
# the episode.

all_rewind_steps = len(observations)

collision_rewind_steps = []
collision_observations = []
collision_actions = []

for rewind_step in tqdm.tqdm(range(all_rewind_steps)):
  snapshot = rewind_envs[rewind_step]
  # Take the number of discrete actions from the environment rather than
  # hard-coding it.
  for action in range(snapshot.action_space.n):
    if action == actions[rewind_step]:
      continue
    # Step a private deep copy of the snapshot. This leaves the snapshot intact
    # for the remaining actions and avoids resetting a closed environment and
    # overwriting its __dict__ in place.
    branch_env = copy.deepcopy(snapshot)
    try:
      obs, reward, done, truncated, info = branch_env.step(action)
    finally:
      branch_env.close()

    # `done` is the `terminated` flag returned by env.step().
    if done:
      collision_rewind_steps.append(rewind_step)
      collision_observations.append(obs)
      collision_actions.append(action)

100%|██████████| 30/30 [00:05<00:00,  5.15it/s]


In [283]:
# Rewind-step indices at which an alternative action terminated the episode.
collision_rewind_steps

[9, 14, 14, 14, 14, 15, 15, 15, 15, 16, 18, 21, 22, 26, 26, 27, 29]

In [286]:
# Stack the recorded post-step observations into a single array for display.
# NOTE(review): this prints the whole array; `.shape` or a slice may be enough.
np.array(collision_observations)

array([[[ 1.00000000e+00,  1.00000000e+00,  6.34615123e-01,
          1.88461617e-01, -3.66908498e-02],
        [ 1.00000000e+00, -1.30286766e-02, -9.99094918e-02,
         -5.46646379e-02,  3.66908498e-02],
        [ 1.00000000e+00,  1.64383035e-02, -2.34615147e-01,
          7.48892426e-02,  3.66908498e-02],
        [ 1.00000000e+00,  1.52968511e-01, -4.34615135e-01,
          7.41318315e-02,  3.66908498e-02],
        [ 1.00000000e+00,  2.42071480e-01,  1.65384859e-01,
          7.88360387e-02,  3.66908498e-02]],

       [[ 1.00000000e+00,  1.00000000e+00,  8.00000012e-01,
          3.00000012e-01,  0.00000000e+00],
        [ 1.00000000e+00,  2.50000004e-02,  0.00000000e+00,
         -8.65638703e-02,  0.00000000e+00],
        [ 1.00000000e+00,  6.63988143e-02, -2.00000003e-01,
         -5.55987619e-02,  0.00000000e+00],
        [ 1.00000000e+00,  1.77511960e-01, -4.00000006e-01,
         -3.19540761e-02,  0.00000000e+00],
        [ 1.00000000e+00,  3.34855288e-01, -6.00000024e-01,
  

In [285]:
# Alternative actions that terminated the episode, aligned with collision_rewind_steps.
np.array(collision_actions)

array([0, 1, 2, 3, 4, 1, 2, 3, 4, 2, 0, 0, 0, 0, 2, 0, 0])

In [105]:
# Print row 0 of every recorded observation from index `rewind_step` onwards.
# `rewind_step` is whatever value the previous loop left behind.
for i, step_obs in enumerate(observations[rewind_step:], start=rewind_step):
  print(step_obs[0, :])

[ 1.000000e+00  1.000000e+00  6.000015e-01  3.750000e-01 -5.741917e-07]
[1.         1.         0.7851608  0.37477875 0.01287973]
[1.0000000e+00 1.0000000e+00 7.9982656e-01 3.7499994e-01 1.9848753e-04]
[ 1.0000000e+00  1.0000000e+00  8.0000144e-01  3.7500000e-01
 -5.0904777e-07]
[ 1.000000e+00  1.000000e+00  8.000001e-01  3.750000e-01 -6.304766e-08]


In [163]:

import matplotlib.animation as animation
from IPython.display import HTML
import tqdm
import gymnasium
import highway_env
from matplotlib import pyplot as plt
%matplotlib inline

env = gymnasium.make('highway-fast-v0', render_mode='rgb_array',
                     config={"lanes_count": 5})
env.reset()


frames = []

for i in range(100):
    action = np.random.randint(0, 5)
    observations.append(obs)
    actions.append(action)
    rewind_envs.append(copy.deepcopy(env))
    
    obs, reward, done, truncated, info = env.step(action)
    
    frames.append(env.render())
  
    if done == True:
        print("collision")
        break
  
    if truncated == True:
        print("truncated")
        break


fig, ax = plt.subplots()
ani = animation.ArtistAnimation(fig, [[ax.imshow(frame)] for frame in frames], interval=200, blit=True, repeat_delay=1000)
plt.close(fig)  # Prevent the static image from displaying
ani.save('environment_steps.gif', writer='pillow')

HTML(ani.to_jshtml())

collision


In [164]:
# Second-to-last observation recorded by the rollout above.
observations[-2]

array([[ 1.00000000e+00,  1.00000000e+00,  8.00000012e-01,
         3.74999970e-01,  0.00000000e+00],
       [ 1.00000000e+00, -1.35655021e-02, -6.00000024e-01,
        -1.11171626e-01,  0.00000000e+00],
       [ 1.00000000e+00,  2.84598358e-02, -8.00000012e-01,
        -1.28287390e-01,  0.00000000e+00],
       [ 1.00000000e+00,  1.14917599e-01, -1.99982673e-01,
        -1.29980654e-01, -1.24809849e-05],
       [ 1.00000000e+00,  3.05102468e-01, -8.00000012e-01,
        -1.02223828e-01,  0.00000000e+00]], dtype=float32)

In [165]:
# Last observation recorded by the rollout above (taken before the final step).
observations[-1]

array([[ 1.00000000e+00,  1.00000000e+00,  8.00000012e-01,
         3.75000000e-01,  0.00000000e+00],
       [ 1.00000000e+00, -2.25737542e-02, -8.00000012e-01,
        -1.26561195e-01,  0.00000000e+00],
       [ 1.00000000e+00,  6.31088838e-02, -1.99998885e-01,
        -1.28871754e-01, -8.07550805e-07],
       [ 1.00000000e+00,  2.64100343e-01, -8.00000012e-01,
        -1.02909751e-01,  0.00000000e+00],
       [ 1.00000000e+00,  3.17254305e-01, -6.00000024e-01,
        -1.01854384e-01, -8.70270522e-11]], dtype=float32)

# Loading collision rewind dataset

In [360]:
%load_ext autoreload
%autoreload 2
import sys
import torch
from torch.utils.data import DataLoader
sys.path.append('/u/shuhan/projects/vla')

from src.environments.highway_env.dataset import HighwayCollisionDataset, collate_fn_collision

dataset = HighwayCollisionDataset(data_dir='/storage/Datasets/highway_env/highway_fast_v0_dqn_meta_action_5_lanes/rollouts_train_collision')

# define the dataloader
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate_fn_collision)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [361]:
# Number of samples exposed by the dataset.
len(dataset)

181440

In [355]:
%load_ext autoreload
%autoreload 2

for batch in dataloader:
  observations, actions, valid_mask, collision_rewind_steps, collision_observations, collision_actions, collision_valid_mask = batch
  break

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [356]:
# Shape of the batched rewind-step tensor returned by the dataloader.
collision_rewind_steps.shape

torch.Size([32, 20])

In [357]:
# Shape of the batched collision-observation tensor returned by the dataloader.
collision_observations.shape

torch.Size([32, 20, 5, 5])

In [358]:
# Shape of the batched collision-action tensor returned by the dataloader.
collision_actions.shape

torch.Size([32, 20])

In [359]:
# For each sample in the batch, print the shapes of its collision observations
# and collision actions after selecting entries with that sample's mask.
for i, sample_mask in enumerate(collision_valid_mask):
  masked_obs = collision_observations[i][sample_mask]
  masked_actions = collision_actions[i][sample_mask]
  print(masked_obs.shape)
  print(masked_actions.shape)

torch.Size([13, 5, 5])
torch.Size([13])
torch.Size([14, 5, 5])
torch.Size([14])
torch.Size([3, 5, 5])
torch.Size([3])
torch.Size([15, 5, 5])
torch.Size([15])
torch.Size([9, 5, 5])
torch.Size([9])
torch.Size([6, 5, 5])
torch.Size([6])
torch.Size([11, 5, 5])
torch.Size([11])
torch.Size([14, 5, 5])
torch.Size([14])
torch.Size([11, 5, 5])
torch.Size([11])
torch.Size([7, 5, 5])
torch.Size([7])
torch.Size([7, 5, 5])
torch.Size([7])
torch.Size([1, 5, 5])
torch.Size([1])
torch.Size([10, 5, 5])
torch.Size([10])
torch.Size([7, 5, 5])
torch.Size([7])
torch.Size([16, 5, 5])
torch.Size([16])
torch.Size([1, 5, 5])
torch.Size([1])
torch.Size([13, 5, 5])
torch.Size([13])
torch.Size([10, 5, 5])
torch.Size([10])
torch.Size([5, 5, 5])
torch.Size([5])
torch.Size([15, 5, 5])
torch.Size([15])
torch.Size([15, 5, 5])
torch.Size([15])
torch.Size([17, 5, 5])
torch.Size([17])
torch.Size([10, 5, 5])
torch.Size([10])
torch.Size([12, 5, 5])
torch.Size([12])
torch.Size([5, 5, 5])
torch.Size([5])
torch.Size([0, 5, 5]