In [1]:
# Enable IPython's autoreload extension; mode 2 re-imports edited modules before each cell runs.
%load_ext autoreload
%autoreload 2
import sys
# Extend the import path, presumably so the `src.*` project imports in later cells resolve.
# NOTE(review): absolute, user-specific path — adjust when running from another checkout.
sys.path.append('/u/shuhan/projects/vla')

In [11]:
from src.models.vlas.cont_obs_token_action_cot_unified_token import ContObsTokenActionCOTVLAUnifiedToken
from src.auto_labeling.highway_env.lane_change import LaneChangeTaskSpec
from transformers import AutoModelForCausalLM, AutoTokenizer

# Checkpoint name shared by the backbone and its tokenizer.
llm_model = 'HuggingFaceTB/SmolLM2-135M-Instruct'

llm_backbone = AutoModelForCausalLM.from_pretrained(llm_model)
tokenizer = AutoTokenizer.from_pretrained(llm_model)

# Loss weights and chain-of-thought settings passed to the VLA wrapper below.
loss_weight = {
    'action': 1.0,
    'obs': 0.0,
    'reconst': 1.0,
    'cot': 1.0,
    'separator': 1.0,
    'rollout_stop': 1.0,
}
cot_mode = 'start'
cot_cfg = {
    'lanes_count': 5,
    'max_hop': 4,
    'cot_index_mode': 'both',
}

# hidden_dim lookup keyed by backbone name; any other name is rejected.
hidden_dim = {
    'gpt2': 768,
    'HuggingFaceTB/SmolLM2-135M-Instruct': 576,
}.get(llm_model)
if hidden_dim is None:
    raise ValueError(f'Unknown LLM model: {llm_model}')

obs_dim = 25
num_actions = 5
mlp_layers = 2

task_spec_func = LaneChangeTaskSpec

# All arguments except max_obs_len are positional, so their order must be kept.
model = ContObsTokenActionCOTVLAUnifiedToken(
    llm_backbone,
    tokenizer,
    task_spec_func,
    obs_dim,
    num_actions,
    hidden_dim,
    mlp_layers,
    loss_weight,
    cot_mode,
    cot_cfg,
    max_obs_len=50,
)


In [12]:
import torch
import os

ckpt = os.path.expanduser(
    '~/results/vla/quick_run_cot_unified/start_cot_smolLM/lightning_logs/version_0/checkpoints/test_model.ckpt'
)

# Read the checkpoint on CPU and keep only its 'state_dict' entry.
lg_ckpt = torch.load(ckpt, map_location='cpu', weights_only=True)['state_dict']

# Keep the entries whose key starts with 'vla.' and drop that prefix from the key;
# every other entry is discarded.
ori_ckpt = {
    name[len('vla.'):]: tensor
    for name, tensor in lg_ckpt.items()
    if name.startswith('vla.')
}

# Left as the last expression so the key-matching report is displayed.
model.load_state_dict(ori_ckpt)

<All keys matched successfully>

In [13]:

from src.environments.highway_env.dataset import HighwayDataset, collate_fn
from torch.utils.data import DataLoader

# Rollout directory read by HighwayDataset.
# NOTE(review): absolute, machine-specific path — adjust for other environments.
data_folder = '/storage/Datasets/highway_env/highway_fast_v0_dqn_meta_action_5_lanes/rollouts_train'
dataset = HighwayDataset(data_folder)

# Shuffling is kept so an arbitrary rollout is inspected, but the sampling order is
# driven by a seeded generator: a fresh kernel + Run All now always yields the same
# sequence of batches instead of a different sample on every run.
# `torch` is imported by the checkpoint-loading cell above.
shuffle_rng = torch.Generator().manual_seed(0)
dataloader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    collate_fn=collate_fn,
    generator=shuffle_rng,
)

In [23]:
# Pull exactly one batch from the loader. Using next(iter(...)) instead of a
# `for ...: break` loop means an empty loader fails right here (StopIteration)
# rather than leaving `batch` undefined or silently stale from an earlier run.
batch = next(iter(dataloader))
obs, act, valid_mask = batch
print(obs.shape, act.shape, valid_mask.shape)

torch.Size([1, 20, 5, 5]) torch.Size([1, 20]) torch.Size([1, 20])


In [55]:
# This pass is only used for inspection, so run it without building an autograd
# graph, and call the module itself (not .forward) so registered hooks are honoured.
# NOTE(review): the wrapper is not switched to eval mode here — confirm whether it
# contains dropout or other train-time behaviour before trusting exact values.
with torch.no_grad():
    loss_dict, batch_input_embeds, batch_label_ids, batch_input_ids, llm_output = model(obs, act, valid_mask)

# Positions of the first sequence whose label is -100.
padding_mask = batch_label_ids[0] == -100

In [111]:
# Full decoded input of the first sequence, special tokens included.
input_str = model.llm_tokenizer.decode(batch_input_ids[0], skip_special_tokens=False)

# Look the marker ids up directly in the vocabulary. Tokenizing the marker text and
# taking element [0] would return the wrong id whenever the tokenizer prepends a
# special token (e.g. BOS) or splits the marker into several pieces.
eoo_token_id = model.llm_tokenizer.convert_tokens_to_ids('<EOO>')
eot_token_id = model.llm_tokenizer.convert_tokens_to_ids('<EOT>')
eoa_token_id = model.llm_tokenizer.convert_tokens_to_ids('<EOA>')

# Index of the first <EOO> in the input ids; .index raises ValueError if it is absent.
init_context_len = batch_input_ids[0].cpu().tolist().index(eoo_token_id)

In [114]:
# Overwrite the -100 label slots with the pad id so the labels can be decoded.
# This writes into batch_label_ids in place.
batch_label_ids[0].masked_fill_(padding_mask, tokenizer.pad_token_id)

# Decode the labels from init_context_len onwards, special tokens included.
gt_str = model.llm_tokenizer.decode(
    batch_label_ids[0][init_context_len:],
    skip_special_tokens=False,
)
print(gt_str)

<BOT>Now at Lane 2. Follow Lane 1 -> Lane 2 -> Lane 3 -> Lane 4. Next is Lane 1. Action: turn left.<EOT><BOA><Act_3><EOA><|endoftext|><|endoftext|><|endoftext|><BOA><Act_0><EOA><|endoftext|><|endoftext|><|endoftext|><BOA><Act_3><EOA><|endoftext|><|endoftext|><|endoftext|><BOA><Act_2><EOA><|endoftext|><|endoftext|><|endoftext|><BOA><Act_1><EOA><|endoftext|><|endoftext|><|endoftext|><BOA><Act_1><EOA><|endoftext|><|endoftext|><|endoftext|><BOA><Act_2><EOA><|endoftext|><|endoftext|><|endoftext|><BOA><Act_2><EOA><|endoftext|><|endoftext|><|endoftext|><EndOfRollout>


In [116]:
# Highest-scoring token id at every position of the forward-pass logits.
model_pred_ids = llm_output.logits.argmax(dim=-1)

# Apply the same pad-id overwrite used for the labels so both strings line up.
model_pred_ids[0].masked_fill_(padding_mask, tokenizer.pad_token_id)

model_pred_str = model.llm_tokenizer.decode(
    model_pred_ids[0][init_context_len:],
    skip_special_tokens=False,
)

print(model_pred_str)

<BOT>Now at Lane 2. Follow Lane 1 -> Lane 2 -> Lane 3 -> Lane 4. Next is Lane 1. Action: turn left.<EOT><BOA><Act_3><EOA><|endoftext|><|endoftext|><|endoftext|><BOA><Act_0><EOA><|endoftext|><|endoftext|><|endoftext|><BOA><Act_3><EOA><|endoftext|><|endoftext|><|endoftext|><BOA><Act_2><EOA><|endoftext|><|endoftext|><|endoftext|><BOA><Act_1><EOA><|endoftext|><|endoftext|><|endoftext|><BOA><Act_1><EOA><|endoftext|><|endoftext|><|endoftext|><BOA><Act_2><EOA><|endoftext|><|endoftext|><|endoftext|><BOA><Act_2><EOA><|endoftext|><|endoftext|><|endoftext|><EndOfRollout>


In [117]:
# Show the input prefix up to and including the first <EOO> token.
model.llm_tokenizer.decode(
    batch_input_ids[0][:init_context_len+1],
    skip_special_tokens=False,
)

'Goal is to reach Lane 4. Need to go through path Lane 2 -> Lane 1 -> Lane 2 -> Lane 3 -> Lane 4.<BOO><Obs_0><EOO>'

In [118]:
# Input embeddings for the prefix that ends at index init_context_len (inclusive).
context_input_embeds = batch_input_embeds[:, :init_context_len+1]

# Generate without sampling from the prefix embeddings; generation ends at the
# <EOA> id or after 100 new tokens, whichever comes first.
rollout_output = model.llm_backbone.generate(
    inputs_embeds=context_input_embeds,
    max_new_tokens=100,
    do_sample=False,
    eos_token_id=eoa_token_id,
)
rollout_output_str = model.llm_tokenizer.decode(
    rollout_output[0],
    skip_special_tokens=False,
)
print(rollout_output_str)

<BOT>Now at Lane 2. Follow Lane 1 -> Lane 2 -> Lane 3 -> Lane 4. Next is Lane 1. Action: turn left.<EOT><BOA><Act_3><EOA>
