### Introduction

This notebook is intended to train a Decision Transformer using offline data gathered from exploring the CarRacing-v2 environment with a pre-trained DQN model.

### Install initial environment in Google Colab

In [1]:
import sys
import os

if 'google.colab' in sys.modules:
  if not os.path.exists('/content/.already_installed'):
    !git clone https://github.com/FlutterbaseDotCom/hdt
    !apt-get install -y swig
    !pip install -r hdt/requirements.txt
    with open('/content/.already_installed', 'w') as f:
        f.write('done')
  %cd /content/hdt

### Load Dataset

In [2]:
import utils.storage as storage

# Dataset names used so far:
#   car_racing_15_100
#   offline_car_racing_150_1000
DATASET_NAME = 'car_racing_15_100'

features = storage.load_dataset(DATASET_NAME)
# Quick sanity check: size of the "observations" entry of the loaded dataset.
print(len(features["observations"]))

15


In [3]:
import gymnasium as gym

# CarRacing with the discrete action space.
# Add render_mode='human' to the call below to watch the environment live.
env = gym.make('CarRacing-v2', continuous=False)

In [6]:
from dataclasses import asdict, dataclass
import wandb
import os


@dataclass
class TrainConfig:
    """Run-level settings: Weights & Biases identifiers plus training hyperparameters.

    Every field has a default, so ``TrainConfig()`` yields the full
    configuration of this run; ``asdict`` turns it into a plain dict.
    """

    # WANDB CONFIG
    wandb_id: str = "dt_23"      # run identifier handed to wandb.init
    wandb_name: str = "DT_23"    # display name handed to wandb.init
    # NOTE(review): model_save_name / saved_model_version are not referenced
    # in the visible cells — presumably used when saving/loading model artifacts.
    model_save_name: str = "DT_MODEL_23"
    saved_model_version: str = "latest"
    # Checkpoint interval in optimizer steps (TrainingArguments.save_steps).
    # This field used to be declared twice (100 here, 30 further down); the
    # later declaration silently won, so 30 — the value actually in effect —
    # is kept as the single definition.
    save_steps: int = 30

    # TRAINING DATA CONFIG
    num_train_epochs: int = 100
    max_ep_len: int = 1000       # passed to the collator and the model config
    max_length: int = 10         # passed to the collator (max_len) and the model config
    rtg_gamma: float = 1.0       # NOTE(review): not referenced in the visible cells

    prefix: str = 'DT'           # NOTE(review): not referenced in the visible cells
    log_interval: int = 5        # TrainingArguments.logging_steps
    per_device_train_batch_size: int = 64
    learning_rate: float = 0.0001
    weight_decay: float = 0.0001
    warmup_ratio: float = 0.1
    max_grad_norm: float = 0.25


trainConfig = TrainConfig()

# Weights & Biases behaviour is controlled through environment variables.
os.environ["WANDB_DISABLED"] = "false"
os.environ['WANDB_NOTEBOOK_NAME'] = 'DT_train.ipynb'
os.environ["WANDB_LOG_MODEL"] = "checkpoint"

# SECURITY: the API key must never be committed with the notebook. It is read
# from the WANDB_API_KEY environment variable (e.g. a Colab secret or .env);
# when that is unset, wandb.login() falls back to its own credential lookup /
# interactive prompt. The key that was previously hardcoded here should be
# revoked in the W&B account settings.
wandb.login(key=os.environ.get("WANDB_API_KEY"))

# `id` selects the run, `resume` only selects the resume *mode*. With "allow",
# an existing run with this id is continued, otherwise a new run is started.
wandb.init(
    id=trainConfig.wandb_id,
    resume="allow",
    name=trainConfig.wandb_name,
    mode="online",
    entity="yakiv",
    project="CarRacingDT",
    config=asdict(trainConfig),
)




0,1
train/epoch,▁
train/global_step,▁
train/total_flos,▁
train/train_loss,▁
train/train_runtime,▁
train/train_samples_per_second,▁
train/train_steps_per_second,▁

0,1
train/epoch,2.0
train/global_step,2.0
train/total_flos,318621562675200.0
train/train_loss,1.78058
train/train_runtime,4.1127
train/train_samples_per_second,7.294
train/train_steps_per_second,0.486


### Train

In [8]:

from cnn_decision_transformer.cnn_decision_transformer_trainable import CnnDecisionTransformerGymDataCollator, TrainableCnnDecisionTransformer
from cnn_decision_transformer.configuration import DecisionTransformerConfig
from utils.dataset_wrappers import DummyDataset
from utils.dataset_wrappers import CarRacingFeatureDataset
from transformers import Trainer, TrainingArguments

# Wrap the loaded offline features and build the batch collator on top of them.
feature_dataset = CarRacingFeatureDataset(src=features)
collator = CnnDecisionTransformerGymDataCollator(
    feature_dataset,
    max_len=trainConfig.max_length,
    max_ep_len=trainConfig.max_ep_len,
)

# Model dimensions are taken from the collator so they match the dataset.
dt_config = DecisionTransformerConfig(
    state_dim=collator.state_dim,
    act_dim=collator.act_dim,
    max_length=trainConfig.max_length,
    max_ep_len=trainConfig.max_ep_len,
)
model = TrainableCnnDecisionTransformer(dt_config)

# Optimisation and logging settings all come from trainConfig.
training_args = TrainingArguments(
    output_dir="output/",
    report_to="wandb",
    remove_unused_columns=False,
    optim="adamw_torch",
    num_train_epochs=trainConfig.num_train_epochs,
    per_device_train_batch_size=trainConfig.per_device_train_batch_size,
    learning_rate=trainConfig.learning_rate,
    weight_decay=trainConfig.weight_decay,
    warmup_ratio=trainConfig.warmup_ratio,
    max_grad_norm=trainConfig.max_grad_norm,
    logging_steps=trainConfig.log_interval,
    save_steps=trainConfig.save_steps,
)

# 'Hack': the Trainer only receives a placeholder dataset of the right length;
# the collator was constructed with feature_dataset, so no actual data needs
# to be passed through train_dataset.
placeholder_dataset = DummyDataset(len(feature_dataset))

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=placeholder_dataset,
    data_collator=collator,
)

trainer.train()


In [None]:
#play
# Roll out the trained model for one episode in CarRacing-v2, conditioning it
# on a target return, and show a rendered frame every `print_every` steps.
import matplotlib.pyplot as plt
import torch
from IPython.display import display as ipy_display, clear_output

# NOTE(review): `get_action` and `prepare_observation_array` are neither
# defined nor imported in any cell visible in this notebook; import them from
# the project module that provides them, otherwise this cell fails on a fresh
# kernel. `gym` and `model` come from the earlier cells.

max_ep_len = 1000
device = 'cpu'
model = model.to(device)
# Evaluation mode, so train-only behaviour (e.g. dropout, if the model has
# any) does not add noise to the predicted actions.
model.eval()
scale = 1000.0  # normalization for rewards/returns
TARGET_RETURN = 900 / scale  # evaluation is conditioned on a return of 900, scaled accordingly

# rgb_array rendering lets env.render() return frames we can plot inline.
env = gym.make('CarRacing-v2', render_mode='rgb_array', continuous=False)

state_dim = 96*96*3  # flattened 96x96 RGB observation
act_dim = 1          # one discrete action index per step

episode_return, episode_length = 0, 0
state, _ = env.reset()
state = prepare_observation_array(state)
target_return = torch.tensor(TARGET_RETURN, device=device, dtype=torch.float32).reshape(1, 1)
states = torch.from_numpy(state).reshape(1, state_dim).to(device=device, dtype=torch.float32)
actions = torch.zeros((0, act_dim), device=device, dtype=torch.long)
rewards = torch.zeros(0, device=device, dtype=torch.float32)
print_every = 10

# A single figure is reused for every preview frame instead of stacking a new
# image onto the implicit current figure at each refresh.
fig, ax = plt.subplots()

timesteps = torch.tensor(0, device=device, dtype=torch.long).reshape(1, 1)
try:
    for t in range(max_ep_len):
        # Append placeholder slots for the action/reward of the current step;
        # they are filled in once the model has acted and the env has stepped.
        actions = torch.cat([actions, torch.zeros((1, act_dim), dtype=torch.long, device=device)], dim=0)
        rewards = torch.cat([rewards, torch.zeros(1, device=device)])

        action = get_action(
            model,
            states,
            actions,
            rewards,
            target_return,
            timesteps,
        )

        # Take the index of the largest entry of the model output as the action.
        action = torch.argmax(action).item()
        actions[-1] = torch.tensor(action, dtype=torch.long)

        # Gymnasium reports natural termination and time-limit truncation
        # separately; either one ends the episode.
        state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated

        if (t + 1) % print_every == 0:
            image = env.render()
            clear_output(wait=True)
            ax.clear()
            ax.imshow(image)
            ax.axis('off')  # Hide the axis
            ipy_display(fig)

        state = prepare_observation_array(state)
        # Cast to float32 so appended states match the dtype of the initial one.
        cur_state = torch.from_numpy(state).to(device=device, dtype=torch.float32).reshape(1, state_dim)
        states = torch.cat([states, cur_state], dim=0)
        rewards[-1] = reward

        # Remaining return-to-go: subtract the (scaled) reward just received.
        pred_return = target_return[0, -1] - (reward / scale)
        target_return = torch.cat([target_return, pred_return.reshape(1, 1)], dim=1)
        timesteps = torch.cat([timesteps, torch.ones((1, 1), device=device, dtype=torch.long) * (t + 1)], dim=1)

        episode_return += reward
        episode_length += 1

        if done:
            break
finally:
    # Release the figure and the environment even if the rollout fails midway.
    plt.close(fig)
    env.close()

print(f"episode_return={episode_return:.1f}, episode_length={episode_length}")