In [1]:
# Re-import edited project modules automatically before each cell runs, so changes
# to the `src` package are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
import os
import sys

# Make the project's `src` package importable. The root can be overridden with the
# VLA_PROJECT_ROOT environment variable; the default is the original author's checkout.
PROJECT_ROOT = os.environ.get('VLA_PROJECT_ROOT', '/u/shuhan/projects/vla')
# Guard so re-running this cell does not pile up duplicate sys.path entries.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [148]:
from src.models.vlas.cont_obs_token_action_cot_unified_token import ContObsTokenActionCOTVLAUnifiedToken
from src.auto_labeling.highway_env.lane_change import LaneChangeTaskSpec
from transformers import AutoModelForCausalLM, AutoTokenizer

# Backbone checkpoint; its hidden size is read from the loaded config below.
llm_model = 'HuggingFaceTB/SmolLM2-135M-Instruct'

llm_backbone = AutoModelForCausalLM.from_pretrained(llm_model)
tokenizer = AutoTokenizer.from_pretrained(llm_model)

# Loss weights and chain-of-thought settings handed to the VLA wrapper.
loss_weight = {"action": 1.0, "obs": 0.0, 'reconst': 1.0, "cot": 1.0, "separator": 1.0, "rollout_stop": 1.0}
cot_mode = 'start'
cot_cfg = {'lanes_count': 5, 'max_hop': 4, 'cot_index_mode': 'both'}

# Hidden width of the backbone (768 for gpt2, 576 for SmolLM2-135M). Reading it from
# the config replaces a hand-maintained per-model table that raised ValueError for
# any other backbone.
hidden_dim = llm_backbone.config.hidden_size

obs_dim = 25  # NOTE(review): presumably 5 vehicles x 5 observation features -- confirm
num_actions = 5
mlp_layers = 2

task_spec_func = LaneChangeTaskSpec

model = ContObsTokenActionCOTVLAUnifiedToken(
    llm_backbone, tokenizer, task_spec_func, obs_dim, num_actions, hidden_dim,
    mlp_layers, loss_weight, cot_mode, cot_cfg, max_obs_len=50,
)


In [149]:
import torch
import os

# Lightning checkpoint of the trained model (path is specific to the author's machine).
ckpt = os.path.expanduser('~/results/vla/quick_run_cot_unified/start_cot_smolLM/lightning_logs/version_0/checkpoints/test_model.ckpt')

lg_ckpt = torch.load(ckpt, map_location='cpu', weights_only=True)['state_dict']

# Keep only the wrapped model's entries, which carry a 'vla.' prefix in the Lightning
# state_dict, and strip that prefix so the keys line up with `model`.
ckpt_prefix = 'vla.'
ori_ckpt = {k[len(ckpt_prefix):]: v for k, v in lg_ckpt.items() if k.startswith(ckpt_prefix)}
if not ori_ckpt:
    raise ValueError(f"no '{ckpt_prefix}'-prefixed keys in checkpoint {ckpt}")

# nn.Module defaults to training mode; switch to eval so dropout-style layers behave
# deterministically during the rollout below.
model.eval()
model.load_state_dict(ori_ckpt)

<All keys matched successfully>

In [150]:
import gymnasium
import highway_env  # not referenced directly; presumably imported so gymnasium can resolve the highway-* env ids

# 5-lane configuration, reused by the rollout cells further down.
lanes_cnt_5_cfg = dict(lanes_count=5)
env = gymnasium.make(
    "highway-fast-v0",
    render_mode='rgb_array',
    config=lanes_cnt_5_cfg,
)

# Reset once and keep the first observation as a float32 tensor.
reset_obs, _ = env.reset()
curr_obs = torch.tensor(reset_obs, dtype=torch.float32)

In [192]:
from IPython.display import HTML
import tqdm
import numpy as np
import gymnasium
import highway_env
from matplotlib import pyplot as plt

from transformers.cache_utils import DynamicCache

%matplotlib inline


lane_cnt = 5
env = gymnasium.make("highway-fast-v0", render_mode='rgb_array', config=lanes_cnt_5_cfg)
curr_obs, _ = env.reset()

lane_width = 1.0 / lane_cnt
abs_y = curr_obs[..., 2].copy()
abs_y[1:] += abs_y[:1]
abs_y += lane_width / 2

lane_ids = (abs_y / lane_width).astype(int)
ego_lane_id = lane_ids[0]

print(ego_lane_id)

3


In [193]:
def format_goal_spec(lane_path):
    """Build the natural-language goal prompt for a lane-change path.

    Args:
        lane_path: sequence of lane ids to traverse, starting with the current
            lane and ending at the goal lane.

    Returns:
        The prompt string, e.g. "Goal is to reach Lane 2. Need to go through
        path Lane 3 -> Lane 4 -> Lane 3 -> Lane 2."

    Raises:
        ValueError: if ``lane_path`` is empty.
    """
    if not lane_path:
        raise ValueError('lane_path must contain at least one lane')
    path_str = ' -> '.join(f'Lane {lane}' for lane in lane_path)
    return f'Goal is to reach Lane {lane_path[-1]}. Need to go through path {path_str}.'


# The first entry should match the ego's starting lane printed by the env-reset cell.
goal_path = [3, 4, 3, 2]
# goal_path = [4, 3, 2, 1]  # alternative: reach Lane 1 starting from Lane 4
goal_spec = format_goal_spec(goal_path)
goal_spec

'Goal is to reach Lane 2. Need to go through path Lane 3 -> Lane 4 -> Lane 3 -> Lane 2.'

In [194]:


import re

# Closed-loop rollout: at every step the model emits observation / optional CoT /
# action tokens, the parsed action id is applied to `env`, and the reward components
# reported in `info['rewards']` are recorded.


def ego_lane_from_obs(obs, lane_cnt):
    """Return the ego vehicle's integer lane id from a raw numpy observation.

    Takes row 0's column-2 value, shifts it by half a lane width and truncates it
    into one of ``lane_cnt`` equal-width bins of a unit-wide road.
    """
    lane_width = 1.0 / lane_cnt
    # Slice (not scalar index) so the arithmetic stays in the observation's dtype.
    ego_y = obs[..., 2][:1] + lane_width / 2
    return int((ego_y / lane_width).astype(int)[0])


observations = []  # raw env observations returned by each env.step
actions = []  # action ids applied to the env
reward_names = ['collision_reward', 'right_lane_reward', 'high_speed_reward', 'on_road_reward']
rewards = {name: [] for name in reward_names}

lane_cnt = lanes_cnt_5_cfg["lanes_count"]
# as_tensor accepts the numpy array from env.reset() and also an existing tensor,
# so re-running this cell does not trip torch.tensor's copy-construct warning.
curr_obs = torch.as_tensor(curr_obs, dtype=torch.float32)

past_input_str = goal_spec
past_key_value = DynamicCache()

generate_cfg = {'max_new_tokens': 100, 'do_sample': False}

rollout_length = 30  # Adjust
cot_inference_mode = 'start'

print(past_input_str)

try:
    # Pure inference: without no_grad the repeatedly concatenated prompt embeddings
    # would keep an ever-growing autograd graph alive across steps.
    with torch.no_grad():
        goal_ids = model.llm_tokenizer(past_input_str, return_tensors='pt').input_ids.to(curr_obs.device)
        past_input_embeds = model.llm_backbone.get_input_embeddings()(goal_ids)

        for _ in range(rollout_length):
            update_str, update_embeddings = model.inference_step(
                past_input_embeds, past_input_str, past_key_value, curr_obs, cot_inference_mode, generate_cfg)
            print(update_str)

            if '<EndOfRollout>' in update_str:
                print('model called end of rollout!')
                break

            # Regex accepts multi-digit ids and fails loudly if no action token was produced.
            act_match = re.search(r'<Act_(\d+)>', update_str)
            if act_match is None:
                raise ValueError(f'model output contains no <Act_k> token: {update_str!r}')
            act_id = int(act_match.group(1))

            past_input_str = past_input_str + update_str
            past_input_embeds = torch.cat([past_input_embeds, update_embeddings], dim=1)

            obs, reward, done, truncated, info = env.step(act_id)
            ego_lane_id = ego_lane_from_obs(obs, lane_cnt)
            print('ego_lane_id:', ego_lane_id)

            actions.append(act_id)
            observations.append(obs)
            curr_obs = torch.tensor(obs, dtype=torch.float32)

            for name in reward_names:
                rewards[name].append(info['rewards'][name])

            if done or truncated:
                # If done and truncated are both set on the same step, report failure:
                # termination must not be masked by reaching the time limit.
                print('rollout failed!' if done else 'rollout successfully finished!')
                break

    print(past_input_str)
finally:
    # Release the env even if generation or action parsing raises mid-rollout.
    env.close()

# Mean of each reward component over the executed steps (nan if no step was taken).
avg_rewards = {name: np.mean(rewards[name]) if rewards[name] else np.nan for name in reward_names}


Goal is to reach Lane 2. Need to go through path Lane 3 -> Lane 4 -> Lane 3 -> Lane 2.
<BOO><Obs_0><EOO><BOT>Now at Lane 3. Follow Lane 4 -> Lane 3 -> Lane 2. Next is Lane 4. Action: turn right.<EOT><BOA><Act_3><EOA>
ego_lane_id: 3
<BOO><Obs_1><EOO><BOA><Act_3><EOA>
ego_lane_id: 3
<BOO><Obs_2><EOO><BOA><Act_1><EOA>
ego_lane_id: 3
<BOO><Obs_3><EOO><BOA><Act_1><EOA>
ego_lane_id: 3
<BOO><Obs_4><EOO><BOA><Act_1><EOA>
ego_lane_id: 3
<BOO><Obs_5><EOO><BOA><Act_2><EOA>
ego_lane_id: 4
<BOO><Obs_6><EOO><BOA><Act_2><EOA>
ego_lane_id: 4
<BOO><Obs_7><EOO><BOA><Act_1><EOA>
ego_lane_id: 4
<BOO><Obs_8><EOO><BOA><Act_2><EOA>
ego_lane_id: 4
<BOO><Obs_9><EOO><BOA><Act_0><EOA>
ego_lane_id: 3
<BOO><Obs_10><EOO><BOA><Act_4><EOA>
ego_lane_id: 3
<BOO><Obs_11><EOO><BOA><Act_3><EOA>
ego_lane_id: 3
<BOO><Obs_12><EOO><BOA><Act_3><EOA>
ego_lane_id: 3
<BOO><Obs_13><EOO><BOA><Act_3><EOA>
ego_lane_id: 3
<BOO><Obs_14><EOO><BOA><Act_0><EOA>
ego_lane_id: 2
<BOO><Obs_15><EOO><EndOfRollout><EndOfRollout><Act_2><EOA>
mode

In [139]:
# Inspect the raw text of the most recent model step (set by the rollout loop).
# NOTE(review): the saved output below is from an earlier run -- this cell's
# execution count (139) precedes the rollout cell's (194), and the text does not
# match the rollout log above.
update_str

'<BOO><Obs_7><EOO><BOT>Now at Lane 2. Follow Lane 3 -> Lane 2 -> Lane 3 -> Lane 3 -> Lane 2 -> Lane 3 -> Lane 2 -> Lane 3 -> Lane 2 -> Lane 3 -> Lane 3 -> Lane 2 -> Lane 3 -> Lane 2 -> Lane 3 -> Lane 2 -> Lane 3 -> Lane 2 -> Lane 2 -> Lane 3 -> Lane 2 -> Lane 3 -> Lane 2. Next'