In [1]:
from IPython.display import HTML
import tqdm
import gymnasium
import highway_env
from matplotlib import pyplot as plt
%matplotlib inline

# Create the environment in off-screen mode so env.render() returns RGB frames.
env = gymnasium.make('highway-fast-v0', render_mode='rgb_array')
env.reset()

import matplotlib.animation as animation

# The action index does not change between steps, so look it up once.
action = env.unwrapped.action_type.actions_indexes["FASTER"]

# Step the environment a few times with the same action, keeping one frame per step.
frames = []
for _ in tqdm.trange(3):
    obs, reward, done, truncated, info = env.step(action)
    frames.append(env.render())
    if done or truncated:
        # The episode is over; start a new one instead of stepping a finished episode.
        obs, info = env.reset()

# Turn the collected frames into an animation: one imshow artist per frame.
fig, ax = plt.subplots()
ani = animation.ArtistAnimation(fig, [[ax.imshow(frame)] for frame in frames], interval=200, blit=True, repeat_delay=1000)
plt.close(fig)  # Prevent the static image from displaying
ani.save('environment_steps.gif', writer='pillow')

HTML(ani.to_jshtml())

  0%|          | 0/3 [00:00<?, ?it/s]error: XDG_RUNTIME_DIR not set in the environment.
100%|██████████| 3/3 [00:00<00:00, 18.02it/s]


In [2]:
# Number of entries in the action-name -> index mapping used above to look up "FASTER".
len(env.unwrapped.action_type.actions_indexes)

5

In [2]:
# NOTE(review): `model` is not defined in any cell visible here, and this cell's
# execution count repeats In [2]. It presumably relies on kernel state from a cell
# that was removed or run out of order. TODO: define/load the model before this call.
# `obs` is presumably the observation from the last env.step() in the first cell;
# confirm against the actual execution order.
model.predict(obs, deterministic=True)

(array(0), None)

In [28]:
import numpy as np

# Inspect the observation stored at index 15 of one cached rollout archive.
# np.load on an .npz returns an NpzFile that keeps the archive open until it is
# closed, so read the array inside a `with` block to release the file handle.
with np.load('/u/shuhan/projects/vla/data/highway_env/highway_fast_v0_dqn_meta_action/rollouts/rollout_0.npz') as rollout:
    observation_15 = rollout['observations'][15]
observation_15

array([[ 1.00000000e+00,  1.00000000e+00,  6.66666687e-01,
         3.75000000e-01,  0.00000000e+00],
       [ 1.00000000e+00, -4.71003652e-02, -3.33333343e-01,
        -1.12765811e-01,  0.00000000e+00],
       [ 1.00000000e+00,  9.82110351e-02, -6.66666687e-01,
        -1.19686946e-01,  0.00000000e+00],
       [ 1.00000000e+00,  2.21382990e-01, -1.11022302e-16,
        -1.18464656e-01,  0.00000000e+00],
       [ 1.00000000e+00,  2.85188526e-01, -3.33333343e-01,
        -1.23373389e-01,  0.00000000e+00]], dtype=float32)

In [11]:
# Inspect the full action sequence stored in one cached rollout archive.
# np.load on an .npz returns an NpzFile that keeps the archive open until it is
# closed, so read the array inside a `with` block to release the file handle.
with np.load('/u/shuhan/projects/vla/data/highway_env/highway_fast_v0_dqn_meta_action/rollouts/rollout_0.npz') as rollout:
    rollout_actions = rollout['actions']
rollout_actions


array([3, 0, 0, 0, 4, 2, 2, 3, 3, 3, 3, 1, 1, 3, 1, 1, 0, 4, 4, 2, 0, 2,
       0, 2, 0, 2, 0, 2, 0, 2])

In [57]:

# build a dataset using cached observations and actions
import os
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader

class HighwayDataset(Dataset):
    """Dataset of cached rollouts, one ``.npz`` archive per item.

    Each item is the pair of arrays stored under the ``observations`` and
    ``actions`` keys of a single archive.

    Args:
        data_dir: Directory that directly contains the ``.npz`` rollout files.
    """

    def __init__(self, data_dir: str):
        self.data_dir = data_dir
        self.files = self._obtain_all_files()

    def __len__(self):
        """Return the number of ``.npz`` files found in ``data_dir``."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load rollout ``idx`` and return ``(observations, actions)`` as NumPy arrays."""
        file_path = self.files[idx]
        # np.load returns an NpzFile that keeps the archive open until closed.
        # Indexing it reads the array into memory, so the arrays remain valid
        # after the `with` block releases the file handle.
        with np.load(file_path) as data:
            observations = data['observations']
            actions = data['actions']
        return observations, actions

    def _obtain_all_files(self):
        """Return the sorted full paths of every ``.npz`` file directly inside ``data_dir``."""
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> file mapping stable across runs and machines.
        return sorted(
            os.path.join(self.data_dir, f)
            for f in os.listdir(self.data_dir)
            if f.endswith('.npz')
        )

# define the collate function
def collate_fn(batch):
    """Pad a list of ``(observations, actions)`` pairs to a common length and stack them.

    Observations are padded at the end of their third-from-last axis (the
    leading axis for 3-D observation arrays) and actions at the end of their
    last axis, both with the fill value -100.

    Args:
        batch: Sequence of ``(observations, actions)`` array pairs.

    Returns:
        A tuple ``(observations, actions, valid_mask)``. ``valid_mask`` is a
        boolean tensor of shape ``(len(batch), longest observation length)``
        that is True where an observation is real and False where it is padding.
    """
    fill = -100
    pad = torch.nn.functional.pad

    obs_lengths = [obs.shape[0] for obs, _ in batch]
    act_lengths = [act.shape[0] for _, act in batch]
    longest_obs = max(obs_lengths)
    longest_act = max(act_lengths)

    padded_obs = []
    padded_act = []
    for (obs, act), n_obs, n_act in zip(batch, obs_lengths, act_lengths):
        # The six-element spec leaves the last two axes alone and extends the
        # third-from-last axis at its end.
        padded_obs.append(
            pad(torch.tensor(obs), (0, 0, 0, 0, 0, longest_obs - n_obs), value=fill)
        )
        padded_act.append(
            pad(torch.tensor(act), (0, longest_act - n_act), value=fill)
        )

    observations = torch.stack(padded_obs)
    actions = torch.stack(padded_act)

    # A position is valid when its step index is below that sample's own length.
    step_index = torch.arange(longest_obs).unsqueeze(0)
    valid_mask = step_index < torch.tensor(obs_lengths).unsqueeze(1)

    return observations, actions, valid_mask

# Directory holding the cached rollout archives.
data_folder = '/storage/Datasets/highway_env/highway_fast_v0_dqn_meta_action/rollouts'
dataset = HighwayDataset(data_folder)

# Shuffled mini-batches of 32 rollouts; collate_fn pads each batch to a common length.
dataloader = DataLoader(
    dataset,
    collate_fn=collate_fn,
    shuffle=True,
    batch_size=32,
)


In [58]:
# Fetch a single batch for inspection. next(iter(...)) states the intent directly
# and raises StopIteration on an empty loader instead of leaving `batch` unbound.
batch = next(iter(dataloader))


In [59]:
# Shape of the padded observation tensor, the first element returned by collate_fn.
batch[0].shape

torch.Size([32, 30, 5, 5])

In [67]:
# Index the padded actions (batch[1]) with valid_mask (batch[-1]) to keep only
# the positions the mask marks True, as a flat tensor.
# NOTE(review): valid_mask is built from observation lengths, so this assumes each
# rollout has as many actions as observations. TODO confirm.
batch[1][batch[-1]]


tensor([3, 0, 2, 2, 3, 1, 1, 0, 0, 0, 1, 0, 2, 3, 3, 2, 1, 1, 2, 1, 0, 3, 1, 2,
        3, 1, 2, 0, 3, 3, 2, 3, 3, 3, 1, 0, 1, 1, 0, 2, 2, 0, 0, 2, 4, 4, 0, 0,
        0, 0, 0, 0, 0, 0, 1, 1, 4, 4, 4, 4, 3, 0, 1, 2, 1, 1, 2, 1, 1, 0, 4, 4,
        3, 4, 1, 3, 1, 2, 1, 3, 1, 0, 3, 3, 3, 3, 3, 4, 2, 3, 3, 1, 1, 2, 0, 0,
        1, 2, 2, 3, 1, 1, 0, 1, 2, 1, 2, 1, 0, 1, 2, 0, 3, 0, 0, 0, 0, 3, 2, 2,
        3, 3, 1, 1, 1, 1, 2, 2, 1, 1, 4, 4, 4, 2, 2, 0, 2, 3, 3, 1, 1, 2, 1, 1,
        1, 0, 3, 3, 3, 4, 3, 0, 0, 1, 1, 2, 2, 0, 3, 2, 0, 4, 4, 1, 1, 1, 3, 2,
        1, 2, 0, 1, 0, 1, 2, 3, 1, 2, 1, 3, 2, 1, 2, 0, 2, 0, 0, 3, 4, 2, 3, 3,
        1, 1, 1, 1, 1, 4, 1, 0, 0, 1, 1, 1, 0, 2, 3, 2, 1, 1, 3, 1, 1, 0, 1, 3,
        1, 0, 1, 1, 3, 0, 1, 2, 2, 4, 1, 4, 3, 1, 0, 1, 1, 1, 4, 1, 4, 4, 4, 4,
        4, 3, 1, 1, 1, 0, 3, 3, 1, 1, 2, 3, 1, 0, 3, 1, 0, 1, 4, 3, 2, 2, 1, 1,
        4, 2, 3, 0, 1, 2, 0, 3, 2, 0, 0, 0, 0, 0, 2, 3, 1, 2, 1, 4, 4, 1, 2, 2,
        2, 3, 2, 3, 0, 1, 2, 1, 3, 1, 1,