In [3]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append('/u/shuhan/projects/vla')

In [62]:
from src.models.vlas.cont_obs_token_action_cot_unified_token_collision import ContObsTokenActionCOTVLAUnifiedTokenCollision
from src.auto_labeling.highway_env.lane_change import LaneChangeTaskSpecCollision
from transformers import AutoModelForCausalLM, AutoTokenizer

llm_model = 'HuggingFaceTB/SmolLM2-135M-Instruct'

# Tokenizer and causal-LM weights are both loaded from the same Hugging Face checkpoint id.
tokenizer = AutoTokenizer.from_pretrained(llm_model)
llm_backbone = AutoModelForCausalLM.from_pretrained(llm_model)

In [63]:
# Per-term loss weights passed to the VLA model ("obs" is switched off here).
loss_weight = {"action": 1.0, "obs": 0.0, 'reconst': 1.0, "cot": 1.0, "separator": 1.0, "rollout_stop": 1.0, "wm": 1.0}
# CoT mode used to construct the model (the step-by-step rollout cell rebinds `cot_mode` for inference).
cot_mode = 'start'

cot_cfg = {'lanes_count': 5, 'max_hop': 4, 'cot_index_mode': 'both', 'action_sample_mode': 'future', 'safe_reflect_rate': 0.3, 'collide_reflect_rate': 0.8, 'collide_rewind_rate': 0.8}

# Embedding width of the backbone. Read it from the loaded Hugging Face config instead of a
# hand-maintained per-model table, so changing `llm_model` needs no edit here
# (GPT-2 exposes its `n_embd` under the same `hidden_size` name).
hidden_dim = llm_backbone.config.hidden_size

obs_dim = 25
num_actions = 5
mlp_layers = 2

task_spec_func = LaneChangeTaskSpecCollision

use_wm = True  # passed to the model as `use_wm`; set to False for the no-WM variant

model = ContObsTokenActionCOTVLAUnifiedTokenCollision(llm_backbone, tokenizer, task_spec_func, obs_dim, num_actions, hidden_dim, mlp_layers, loss_weight, cot_mode, cot_cfg, max_obs_len=50, use_wm=use_wm)

In [68]:
# Select the compute device: CUDA when a GPU is visible, CPU otherwise.
# NOTE(review): `device` is not referenced by the visible cells; the model and observation
# tensors are never moved to it, so the rollouts appear to run on CPU.
import torch

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [64]:
import os
import torch

# Pick the Lightning checkpoint matching the `use_wm` setting the model was built with.
if use_wm:
  ckpt = '~/results/vla/quick_run_cot_unified_collision/with_wm_cr_0.8_re_0.2_sr_0.2/lightning_logs/version_0/checkpoints/test_model.ckpt'
else:
  ckpt = '~/results/vla/quick_run_cot_unified_collision/no_wm_cr_0.8_re_0.8_sr_0.2/lightning_logs/version_2/checkpoints/test_model.ckpt'

ckpt = os.path.expanduser(ckpt)

# The checkpoint's `state_dict` keys for the VLA carry a `vla.` prefix (presumably the attribute
# name of the model inside the LightningModule); strip it so keys match `model`'s own names.
# Keys without the prefix are skipped.
VLA_PREFIX = 'vla.'
lg_ckpt = torch.load(ckpt, map_location='cpu', weights_only=True)['state_dict']
ori_ckpt = {k[len(VLA_PREFIX):]: v for k, v in lg_ckpt.items() if k.startswith(VLA_PREFIX)}

# This notebook only runs inference: switch dropout/normalization layers (if any) to eval mode.
model.eval()
model.load_state_dict(ori_ckpt)

<All keys matched successfully>

In [34]:
import sys
import torch
from torch.utils.data import DataLoader

# Make `src.*` importable; the guard keeps re-runs of this cell from appending duplicates.
_VLA_ROOT = '/u/shuhan/projects/vla'
if _VLA_ROOT not in sys.path:
  sys.path.append(_VLA_ROOT)

from src.environments.highway_env.dataset import HighwayCollisionDataset, collate_fn_collision

# NOTE(review): overfit=True presumably restricts the dataset to a small debug subset — confirm.
dataset = HighwayCollisionDataset(data_dir='/storage/Datasets/highway_env/highway_fast_v0_dqn_meta_action_5_lanes/rollouts_train_collision', overfit=True)

In [35]:
# Grab a single shuffled batch for a forward-pass sanity check. next(iter(...)) raises
# StopIteration on an empty loader instead of silently keeping a stale `batch_data`.
dataloader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=collate_fn_collision)
batch_data = next(iter(dataloader))

In [36]:
# Forward pass on the sampled batch, unpacking the five values the model returns.
(loss_dict,
 batch_input_embeds,
 batch_label_ids,
 batch_input_ids,
 llm_output) = model(batch_data)


In [11]:
import pickle

# Goal-spec dataset, indexed by the ego's starting lane id (as the rollout cells use it).
# NOTE(review): pickle.load can execute arbitrary code — only load this file from a trusted source.
goal_spec_dataset_path = '/u/shuhan/projects/vla/data/highway_env/lane_change_goal_spec_data.pkl'
with open(goal_spec_dataset_path, mode='rb') as pkl_file:
  goal_spec_dataset = pickle.load(pkl_file)


In [12]:
from IPython.display import HTML
import tqdm
import numpy as np
import gymnasium
import highway_env
from matplotlib import pyplot as plt
import random
from difflib import SequenceMatcher
import copy

from transformers.cache_utils import DynamicCache

%matplotlib inline

def get_ego_lane_id(curr_obs, lanes_count=5):
  '''
  Return the lane index of the ego vehicle from a single kinematics observation.

  Args:
    curr_obs: array-like of shape (num_vehicles, num_features) with the ego vehicle in row 0;
      column 2 is read as its lateral position. NOTE(review): assumes that feature is
      normalized so lane k is centered at k / lanes_count — confirm against the env config.
    lanes_count: number of lanes on the road (defaults to the 5-lane setup used here).

  Returns:
    The 0-based ego lane id as a Python int.
  '''
  lane_width = 1.0 / lanes_count
  # Only the ego row matters; other vehicles' positions do not affect the result.
  ego_y = np.asarray(curr_obs)[0, 2]
  # Shift by half a lane so truncation to int snaps to the nearest lane center.
  return int((ego_y + lane_width / 2) / lane_width)

def compute_path_score(goal_path: list[int], ego_lane_ids: list[int]):
  '''
  Score how well a driven lane sequence follows the goal lane path.

  Args:
    goal_path: target sequence of lane ids (must be non-empty).
    ego_lane_ids: lane ids actually visited by the ego vehicle.

  Returns:
    (exact_match_score, subset_coverage):
      exact_match_score - fraction of goal positions whose lane id equals the ego lane id at the
        same index (positions beyond the shorter sequence count as misses);
      subset_coverage - length of the longest contiguous run shared by both sequences divided
        by len(goal_path).

  Raises:
    ValueError: if goal_path is empty (both scores would divide by zero).
  '''
  if not goal_path:
    raise ValueError('goal_path must contain at least one lane id')

  # exact match: position-wise comparison; zip stops at the shorter sequence
  exact_match_count = sum(g == e for g, e in zip(goal_path, ego_lane_ids))
  exact_match_score = exact_match_count / len(goal_path)

  # subset coverage: longest common contiguous block. autojunk=False is required: by default
  # SequenceMatcher treats "popular" items as junk once the second sequence has 200+ items, and
  # with only a few distinct lane ids every item qualifies, collapsing the match to 0.
  sequence_matcher = SequenceMatcher(None, goal_path, ego_lane_ids, autojunk=False)
  longest_match_length = sequence_matcher.find_longest_match(0, len(goal_path), 0, len(ego_lane_ids)).size
  subset_coverage = longest_match_length / len(goal_path)

  return exact_match_score, subset_coverage

def get_wm_obs_from_env(env, action_id, env_id="highway-fast-v0", env_config=None, verbose=True):
  '''
  Get the WM observation from the environment. The environment itself is not affected.

  Simulates one step of `action_id` on a scratch environment carrying a deep copy of `env`'s
  state, i.e. a ground-truth one-step "world model".

  Args:
    env: the live environment (as returned by gymnasium.make) whose state is copied.
    action_id: discrete action to simulate.
    env_id: gymnasium id for the scratch environment; should match the one `env` was made with.
    env_config: config for the scratch environment; defaults to {"lanes_count": 5}.
    verbose: print 'wm collision!' / 'wm safe!' after the simulated step.

  Returns:
    (wm_obs, has_collision): the observation after the simulated step and the `terminated`
    flag returned by `step`, which is treated as "collision" here.
  '''
  if env_config is None:
    env_config = {"lanes_count": 5}

  # Deep-copy first so nothing done to the scratch env can reach the live one.
  start_env_state = copy.deepcopy(env.__dict__)

  wm_env = gymnasium.make(env_id, render_mode='rgb_array', config=env_config)
  try:
    wm_env.reset()
    # Replace the fresh wrapper's state with the copied state of `env`.
    wm_env.__dict__.update(start_env_state)
    wm_obs, _, has_collision, _, _ = wm_env.step(action_id)
  finally:
    # Release the scratch env; it only holds the copied state, so `env` is untouched.
    wm_env.close()

  if verbose:
    print('wm collision!' if has_collision else 'wm safe!')
  return wm_obs, has_collision

In [66]:
# Step-by-step rollout that runs the CoT / world-model (WM) pipeline stage by stage, printing each
# decoded string and collecting rewind statistics.
use_wm = True
# WM source: 'model' = the VLA's own WM prediction, 'env' = one-step simulation via get_wm_obs_from_env.
wm_mode = 'model'
# CoT start policy passed to model.cot_start_inference: 'pred', 'always' or 'never'.
cot_mode = 'pred'

if wm_mode not in ('model', 'env'):
  raise ValueError(f"wm_mode must be 'model' or 'env', got {wm_mode!r}")

wm_init_collision_cnt = 0 # CoT steps whose initial action collides in a one-step env simulation
model_wm_cnt = 0 # CoT steps where the WM was invoked
model_rewind_cnt = 0 # steps where the model rewound (<BACKSPACE>) to a new action
model_rewind_collision_cnt = 0 # rewinds whose replacement action still collides in simulation
wm_init_collision_model_rewind_cnt = 0 # rewinds whose initial action would have collided

env = gymnasium.make("highway-fast-v0", render_mode='rgb_array', config={"lanes_count": 5})
curr_obs, _ = env.reset()
ego_lane_id = get_ego_lane_id(curr_obs)

sampled_path_info = random.choice(goal_spec_dataset[ego_lane_id])

goal_spec = sampled_path_info['goal_spec']
hop_lane_ids = sampled_path_info['hop_lane_ids']

start_id = hop_lane_ids[0]
goal_id = hop_lane_ids[-1]

curr_obs = torch.tensor(curr_obs, dtype=torch.float32)

max_rollout_length = 30

ego_lane_ids = [start_id]
actions = []
model_failed = False
rollout_collision = False

past_input_str = goal_spec

print(goal_spec)

generate_cfg = {'max_new_tokens': 100, 'do_sample': False}

# Pure inference: without no_grad every concatenated embedding keeps its autograd history,
# so the graph (and memory) grows with each step of the rollout.
with torch.no_grad():
  past_input_embeds = model.llm_backbone.get_input_embeddings()(model.llm_tokenizer(past_input_str, return_tensors='pt').input_ids.to(curr_obs.device))

  for _ in range(max_rollout_length):
    # step 1: obtain initial action prediction
    init_act_str, init_act_embeddings = model.init_action_inference(past_input_embeds, past_input_str, curr_obs, generate_cfg)

    print('\tinit_act_str:', init_act_str)

    if '<EndOfRollout>' in init_act_str:
      print('\tmodel called end of rollout!')
      break

    if '<Act_' not in init_act_str:
      print('\tno action token in the initial action inference string!')
      model_failed = True
      break

    # action ids are single digits here (num_actions = 5): one character after '<Act_'
    init_act_index = init_act_str.index('<Act_')
    init_act_id = int(init_act_str[init_act_index+5:init_act_index+6])

    past_input_str = past_input_str + init_act_str
    past_input_embeds = torch.cat([past_input_embeds, init_act_embeddings], dim=1)

    # step 2: obtain cot start token, decide whether to use cot or not
    cot_token_str, cot_token_embeddings = model.cot_start_inference(past_input_embeds, past_input_str, cot_mode, use_wm)

    if len(cot_token_str) > 0:
      past_input_str = past_input_str + cot_token_str
      past_input_embeds = torch.cat([past_input_embeds, cot_token_embeddings], dim=1)

    print('\tcot_token_str:', cot_token_str)

    if '<COMMIT>' in cot_token_str:
      # model skipped CoT and committed to its initial action
      final_act_id = init_act_id
    else:
      # step 3: obtain world model prediction
      if "<BWM>" in cot_token_str and use_wm:
        model_wm_cnt += 1
        if wm_mode == 'model':
          wm_str, wm_embeddings = model.cot_append_wm_embeddings(past_input_embeds, past_input_str, None)
        else:  # wm_mode == 'env' (validated above)
          wm_obs, _ = get_wm_obs_from_env(env, init_act_id)
          wm_obs = torch.tensor(wm_obs, dtype=torch.float32).to(curr_obs.device)
          wm_str, wm_embeddings = model.cot_append_wm_embeddings(past_input_embeds, past_input_str, wm_obs)

        past_input_str = past_input_str + wm_str
        past_input_embeds = torch.cat([past_input_embeds, wm_embeddings], dim=1)

      # step 4: obtain cot commit token
      commit_str, commit_embeddings = model.cot_commit_inference(past_input_embeds, past_input_str, generate_cfg)
      past_input_str = past_input_str + commit_str
      past_input_embeds = torch.cat([past_input_embeds, commit_embeddings], dim=1)

      print('\tcommit_str:', commit_str)

      # ground truth for the metrics: would the initial action collide?
      _, wm_init_collision = get_wm_obs_from_env(env, init_act_id)
      wm_init_collision_cnt += int(wm_init_collision)

      if '<COMMIT>' not in commit_str:
        print('\tcot commit token is not <COMMIT>!')
        model_failed = True
        break
      elif '<BACKSPACE>' in commit_str and '<Act_' in commit_str:
        # rewind and update action
        print('\trewind and update action!')
        model_rewind_cnt += 1
        final_act_index = commit_str.index('<Act_')
        final_act_id = int(commit_str[final_act_index+5:final_act_index+6])

        _, wm_final_collision = get_wm_obs_from_env(env, final_act_id)
        wm_init_collision_model_rewind_cnt += int(wm_init_collision)
        model_rewind_collision_cnt += int(wm_final_collision)
      else:
        print('\tsafe, continue to use the initial action!')
        final_act_id = init_act_id

    # step 5: take action
    obs, reward, has_collision, truncated, info = env.step(final_act_id)
    ego_lane_id = get_ego_lane_id(obs)

    print(f'step: {len(actions)}, action: {final_act_id}, ego_lane_id: {ego_lane_id}')

    actions.append(final_act_id)
    ego_lane_ids.append(ego_lane_id)

    curr_obs = torch.tensor(obs, dtype=torch.float32)

    # Check collision before truncation: both flags can be set on the same step, and checking
    # truncation first would end the rollout without recording the collision.
    if has_collision:
      rollout_collision = True
      print('rollout collision!')
      break

    if truncated:
      print('rollout finished!')
      break

env.close()

# None marks a metric whose denominator is 0.
# NOTE(review): collisions are only checked on steps that entered CoT, so "recall" ignores
# colliding initial actions the model committed to directly.
cot_stats = {}

cot_stats['collision_detect_recall'] = (wm_init_collision_model_rewind_cnt / wm_init_collision_cnt) if wm_init_collision_cnt > 0 else None
cot_stats['rewind_precision'] = (wm_init_collision_model_rewind_cnt / model_rewind_cnt) if model_rewind_cnt > 0 else None
cot_stats['rewind_collision_avoid_rate'] = 1 - (model_rewind_collision_cnt / model_rewind_cnt) if model_rewind_cnt > 0 else None
cot_stats['model_rewind_ratio'] = (model_rewind_cnt / model_wm_cnt) if model_wm_cnt > 0 else None

Goal is to reach Lane 3. Need to go through path Lane 1 -> Lane 0 -> Lane 1 -> Lane 2 -> Lane 3.
	init_act_str: <BOO><Obs_0><EOO><BOA><Act_3><EOA>
	cot_token_str: <COMMIT>
step: 0, action: 3, ego_lane_id: 1
	init_act_str: <BOO><Obs_1><EOO><BOA><Act_3><EOA>
	cot_token_str: <COMMIT>
step: 1, action: 3, ego_lane_id: 1
	init_act_str: <BOO><Obs_2><EOO><BOA><Act_1><EOA>
	cot_token_str: <BWM>
	commit_str: <BOT>Safe<EOT><COMMIT>
wm safe!
	safe, continue to use the initial action!
step: 2, action: 1, ego_lane_id: 1
	init_act_str: <BOO><Obs_3><EOO><BOA><Act_1><EOA>
	cot_token_str: <BWM>
	commit_str: <BOT>Safe<EOT><COMMIT>
wm safe!
	safe, continue to use the initial action!
step: 3, action: 1, ego_lane_id: 1
	init_act_str: <BOO><Obs_4><EOO><BOA><Act_1><EOA>
	cot_token_str: <COMMIT>
step: 4, action: 1, ego_lane_id: 1
	init_act_str: <BOO><Obs_5><EOO><BOA><Act_2><EOA>
	cot_token_str: <COMMIT>
step: 5, action: 2, ego_lane_id: 2
	init_act_str: <BOO><Obs_6><EOO><BOA><Act_1><EOA>
	cot_token_str: <BWM>
	

In [67]:
# Rewind/collision metrics from the step-by-step rollout; None means the metric's denominator was 0.
cot_stats

{'collision_detect_recall': None,
 'rewind_precision': None,
 'rewind_collision_avoid_rate': None,
 'model_rewind_ratio': 0.0}

In [32]:

def _parse_last_action_id(text):
    '''Return the integer id of the last `<Act_N>` token in `text`, or None if there is none.'''
    import re

    matches = re.findall(r'<Act_(\d+)>', text)
    return int(matches[-1]) if matches else None


def _dedup_consecutive(lane_ids):
    '''Collapse runs of repeated lane ids, e.g. [1, 1, 0, 0, 1] -> [1, 0, 1].'''
    return [lane_id for i, lane_id in enumerate(lane_ids) if i == 0 or lane_id != lane_ids[i - 1]]


def rollout_one_episode(model, goal_spec_dataset, cot_inference_mode: str, max_rollout_length: int = 30, generate_cfg=None):
    '''
    Roll out one lane-change episode in highway-fast-v0 with the VLA model in closed loop.

    A goal spec is sampled for the ego's starting lane. At each step the model decodes an update
    string from the running context and current observation; the last action token in it is
    executed, and the new observation is fed back.

    Args:
        model: VLA model exposing `llm_backbone`, `llm_tokenizer` and `inference_step`.
        goal_spec_dataset: mapping from starting lane id to a list of dicts with keys
            'goal_spec' (prompt text) and 'hop_lane_ids' (target lane path).
        cot_inference_mode: forwarded unchanged to `model.inference_step`.
        max_rollout_length: maximum number of environment steps.
        generate_cfg: generation kwargs forwarded to `model.inference_step`; defaults to greedy
            decoding with at most 100 new tokens.

    Returns:
        (scores, past_input_str): `scores` has keys token_count, action_count, exact_match_score,
        subset_coverage, model_failed, rollout_collision, reached_goal, exceeded_length;
        `past_input_str` is the full decoded context of the episode.
    '''
    if generate_cfg is None:
        generate_cfg = {'max_new_tokens': 100, 'do_sample': False}

    env = gymnasium.make("highway-fast-v0", render_mode='rgb_array', config={"lanes_count": 5})
    try:
        curr_obs, _ = env.reset()
        ego_lane_id = get_ego_lane_id(curr_obs)

        sampled_path_info = random.choice(goal_spec_dataset[ego_lane_id])
        goal_spec = sampled_path_info['goal_spec']
        hop_lane_ids = sampled_path_info['hop_lane_ids']
        start_id = hop_lane_ids[0]
        goal_id = hop_lane_ids[-1]

        curr_obs = torch.tensor(curr_obs, dtype=torch.float32)

        ego_lane_ids = [start_id]
        actions = []
        model_failed = False
        rollout_collision = False

        past_input_str = goal_spec
        past_key_value = DynamicCache()

        # Inference only: without no_grad the concatenated embeddings keep their autograd
        # history, so memory grows with every step.
        with torch.no_grad():
            prompt_ids = model.llm_tokenizer(past_input_str, return_tensors='pt').input_ids.to(curr_obs.device)
            past_input_embeds = model.llm_backbone.get_input_embeddings()(prompt_ids)

            for _ in range(max_rollout_length):
                update_str, update_embeddings = model.inference_step(past_input_embeds, past_input_str, past_key_value, curr_obs, cot_inference_mode, generate_cfg)

                past_input_str = past_input_str + update_str
                past_input_embeds = torch.cat([past_input_embeds, update_embeddings], dim=1)

                if '<EndOfRollout>' in update_str:
                    print('model called end of rollout!')
                    break

                # Execute the LAST action token of the step: when the model rewinds (<BACKSPACE>),
                # the replacement action is emitted after the initial one (as in the
                # step-by-step rollout), so the first token would be the discarded action.
                act_id = _parse_last_action_id(update_str)
                if act_id is None:
                    print('no action token in the update string!')
                    model_failed = True
                    break

                obs, reward, has_collision, truncated, info = env.step(act_id)
                ego_lane_id = get_ego_lane_id(obs)

                actions.append(act_id)
                ego_lane_ids.append(ego_lane_id)

                curr_obs = torch.tensor(obs, dtype=torch.float32)

                # Check collision before truncation: both flags can be set on the same step.
                if has_collision:
                    rollout_collision = True
                    print('rollout collision!')
                    break

                if truncated:
                    print('rollout finished!')
                    break
    finally:
        env.close()

    # remove repeating lane ids so the visited path can be compared hop-by-hop with the goal path
    ego_lane_ids = _dedup_consecutive(ego_lane_ids)

    token_count = past_input_embeds.shape[1]
    action_count = len(actions)
    reached_goal = (ego_lane_ids[-1] == goal_id) and not (model_failed or rollout_collision) and len(ego_lane_ids) == len(hop_lane_ids)
    exact_match_score, subset_coverage = compute_path_score(hop_lane_ids, ego_lane_ids)

    exceeded_length = max(0, len(ego_lane_ids) - len(hop_lane_ids))

    scores = {'token_count': token_count, 'action_count': action_count, 'exact_match_score': exact_match_score, 'subset_coverage': subset_coverage, 'model_failed': model_failed, 'rollout_collision': rollout_collision, 'reached_goal': reached_goal, 'exceeded_length': exceeded_length}

    return scores, past_input_str


In [None]:
# Local directory for log files; created if it does not exist yet.
# NOTE(review): the evaluation cell writes its log under `save_dir`, so `log_dir` looks unused.
log_dir = 'logs'
if not os.path.isdir(log_dir):
    os.makedirs(log_dir, exist_ok=True)

In [114]:
import logging
import os
from pathlib import Path

import tqdm


# CoT behaviour at inference time, forwarded to rollout_one_episode: 'pred', 'start' or 'never'.
cot_inference_mode = 'never'

# Name the experiment after the checkpoint's run directory
# (<run_name>/lightning_logs/version_N/checkpoints/<file>.ckpt). Derived here because no other
# cell defines `model_name`, so a fresh kernel would otherwise raise NameError.
model_name = Path(ckpt).parents[3].name
exp_name = f'{model_name}_{cot_inference_mode}'

save_dir = '/u/shuhan/projects/vla/data/highway_env/rollout_experiment'
os.makedirs(save_dir, exist_ok=True)

# force=True: basicConfig is a no-op once the root logger has handlers, so without it re-running
# this cell with a new exp_name would keep writing to the first run's log file.
logging.basicConfig(
    filename=os.path.join(save_dir, f'{exp_name}.log'),
    level=logging.INFO,
    format='%(asctime)s - %(message)s',
    force=True,
)

all_scores = {'token_count': [], 'action_count': [], 'exact_match_score': [], 'subset_coverage': [], 'model_failed': [], 'rollout_collision': [], 'reached_goal': [], 'exceeded_length': []}
all_past_input_str = []

rollout_count = 2


for rollout_idx in tqdm.tqdm(range(rollout_count)):
    scores, past_input_str = rollout_one_episode(model, goal_spec_dataset, cot_inference_mode)
    for k, v in scores.items():
        all_scores[k].append(v)
    all_past_input_str.append(past_input_str)

    # running means over the rollouts completed so far
    logging.info(f'rollout {rollout_idx} done')

    for k, v in all_scores.items():
        logging.info(f'\t {k}: {np.mean(v)}')

logging.info('final results:')
for k, v in all_scores.items():
    logging.info(f'\t {k}: {np.mean(v)}')

 50%|█████     | 1/2 [00:01<00:01,  1.51s/it]

rollout collision!


100%|██████████| 2/2 [00:07<00:00,  3.97s/it]

rollout collision!



