### Introduction

This notebook trains a Decision Transformer on offline trajectories collected by running a pre-trained DQN agent in the CarRacing-v2 environment (discrete actions), and then rolls the trained model out in the environment as a visual check.

### Install dependencies (Google Colab only)

In [1]:
import sys
import os

if 'google.colab' in sys.modules:
  if not os.path.exists('/content/.already_installed'):
    !git clone https://github.com/YakivGalkin/cnn_decision_transformer
    !apt-get install -y swig
    !pip install -r cnn_decision_transformer/requirements.txt
    with open('/content/.already_installed', 'w') as f:
        f.write('done')
  %cd /content/cnn_decision_transformer

Cloning into 'cnn_decision_transformer'...
remote: Enumerating objects: 79, done.[K
remote: Counting objects: 100% (79/79), done.[K
remote: Compressing objects: 100% (61/61), done.[K
remote: Total 79 (delta 38), reused 53 (delta 16), pack-reused 0[K
Receiving objects: 100% (79/79), 1.31 MiB | 21.99 MiB/s, done.
Resolving deltas: 100% (38/38), done.
Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
The following additional packages will be installed:
  swig4.0
Suggested packages:
  swig-doc swig-examples swig4.0-examples swig4.0-doc
The following NEW packages will be installed:
  swig swig4.0
0 upgraded, 2 newly installed, 0 to remove and 19 not upgraded.
Need to get 1,116 kB of archives.
After this operation, 5,542 kB of additional disk space will be used.
Get:1 http://archive.ubuntu.com/ubuntu jammy/universe amd64 swig4.0 amd64 4.0.2-1ubuntu1 [1,110 kB]
Get:2 http://archive.ubuntu.com/ubuntu jammy/universe amd64 swig all 4.0.2-1ubuntu

### Load Dataset

In [2]:
#car_racing_15_100
#offline_car_racing_150_1000

import utils.storage as storage
features = storage.load_dataset('offline_car_racing_150_1000')
print(len(features["observations"]))

Downloading file from https://storage.googleapis.com/yakiv-dt-public/datasets/offline_car_racing_150_1000.hdf5 to ./downloaded_datasets/offline_car_racing_150_1000.hdf5
Download complete.
150


In [3]:
import gymnasium as gym

# Discrete-action variant of CarRacing-v2 (continuous=False).
# Add render_mode='human' to the keyword arguments to watch episodes live.
env = gym.make(
    'CarRacing-v2',
    continuous=False,
)

In [10]:
from dataclasses import asdict, dataclass


@dataclass
class TrainConfig:
    """Hyperparameters for Decision Transformer training in this notebook.

    Fields are passed to the W&B run config (via ``asdict``), the data collator,
    ``DecisionTransformerConfig`` and ``TrainingArguments``.
    """

    # Sequence / episode geometry.
    num_train_epochs: int = 5_000
    max_ep_len: int = 1_000       # passed as max_ep_len to the collator and model config
    max_length: int = 10          # context length passed to the collator and model config
    rtg_gamma: float = 1.0        # return-to-go discount; NOTE(review): not read by any visible cell

    # Run bookkeeping.
    prefix: str = 'DT'            # NOTE(review): not read by any visible cell
    log_interval: int = 50        # -> TrainingArguments.logging_steps
    save_steps: int = 1_000       # -> TrainingArguments.save_steps

    # Optimisation.
    per_device_train_batch_size: int = 32
    learning_rate: float = 1e-4
    weight_decay: float = 1e-4
    warmup_ratio: float = 0.1
    max_grad_norm: float = 0.25


trainConfig = TrainConfig()


In [None]:

import os

import wandb

# W&B settings picked up by wandb / the HF Trainer integration at runtime.
os.environ["WANDB_DISABLED"] = "false"
os.environ["WANDB_NOTEBOOK_NAME"] = "DT_train.ipynb"
os.environ["WANDB_LOG_MODEL"] = "checkpoint"  # upload Trainer checkpoints as W&B artifacts

# Authentication: never hardcode an API key in the notebook (it ends up in git
# history and shared copies). wandb reads WANDB_API_KEY from the environment,
# e.g. a Colab secret or an untracked .env file, and otherwise prompts on first use.
wandb.init(
    name="vdt_001",
    mode="online",
    entity="yakiv",
    project="VDT",
    # resume="allow",
    config=asdict(trainConfig),
)


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112527844443321, max=1.0…

### Train

In [6]:

from visual_decision_transformer.visual_decision_transformer_trainable import (
    VisualDecisionTransformerGymDataCollator,
    TrainableVisualDecisionTransformer,
)
from visual_decision_transformer.configuration import DecisionTransformerConfig
from utils.dataset_wrappers import DummyDataset, CarRacingFeatureDataset
from transformers import Trainer, TrainingArguments

# --- Data: wrap the raw feature dict and build the batching collator. -------
feature_dataset = CarRacingFeatureDataset(src=features)
collator = VisualDecisionTransformerGymDataCollator(
    feature_dataset,
    max_len=trainConfig.max_length,
    max_ep_len=trainConfig.max_ep_len,
)

# --- Model: state/action sizes come from the collator, sequence limits from TrainConfig.
dt_config = DecisionTransformerConfig(
    state_dim=collator.state_dim,
    act_dim=collator.act_dim,
    max_length=trainConfig.max_length,
    max_ep_len=trainConfig.max_ep_len,
)
model = TrainableVisualDecisionTransformer(dt_config)

# --- Optimisation / logging. ------------------------------------------------
training_args = TrainingArguments(
    output_dir="output/",
    report_to="wandb",
    optim="adamw_torch",
    num_train_epochs=trainConfig.num_train_epochs,
    per_device_train_batch_size=trainConfig.per_device_train_batch_size,
    learning_rate=trainConfig.learning_rate,
    weight_decay=trainConfig.weight_decay,
    warmup_ratio=trainConfig.warmup_ratio,
    max_grad_norm=trainConfig.max_grad_norm,
    logging_steps=trainConfig.log_interval,
    save_steps=trainConfig.save_steps,
    # Don't let the Trainer strip dataset fields that the model's forward()
    # does not name explicitly (the HF default is True).
    remove_unused_columns=False,
)

# DummyDataset only gives the Trainer a length to iterate over; presumably the
# collator draws the actual trajectory windows from `feature_dataset` itself
# (no real data is passed through the dataset).
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=DummyDataset(len(feature_dataset)),
    data_collator=collator,
)

trainer.train()


  from tensorflow.tsl.python.lib.core import pywrap_ml_dtypes


Downloading file from https://storage.googleapis.com/yakiv-dt-public/models/nature_cnn_dql_pretrained.pt to ./downloaded_models/nature_cnn_dql_pretrained.pt
Download complete.


Step,Training Loss
50,1.494
100,1.2118
150,1.0118
200,0.9154
250,0.8294
300,0.7711
350,0.7082
400,0.6687
450,0.6268
500,0.6217


[34m[1mwandb[0m: Adding directory to artifact (./output/checkpoint-30)... Done. 0.1s
[34m[1mwandb[0m: Adding directory to artifact (./output/checkpoint-60)... Done. 0.1s
[34m[1mwandb[0m: Adding directory to artifact (./output/checkpoint-90)... Done. 0.1s
[34m[1mwandb[0m: Adding directory to artifact (./output/checkpoint-120)... Done. 0.1s
[34m[1mwandb[0m: Adding directory to artifact (./output/checkpoint-150)... Done. 0.1s
[34m[1mwandb[0m: Adding directory to artifact (./output/checkpoint-180)... Done. 0.1s
[34m[1mwandb[0m: Adding directory to artifact (./output/checkpoint-210)... Done. 0.1s
[34m[1mwandb[0m: Adding directory to artifact (./output/checkpoint-240)... Done. 0.1s
[34m[1mwandb[0m: Adding directory to artifact (./output/checkpoint-270)... Done. 0.1s
[34m[1mwandb[0m: Adding directory to artifact (./output/checkpoint-300)... Done. 0.1s
[34m[1mwandb[0m: Adding directory to artifact (./output/checkpoint-330)... Done. 0.1s
[34m[1mwandb[0m: Addi

TrainOutput(global_step=2500, training_loss=0.5026930274963379, metrics={'train_runtime': 154.4128, 'train_samples_per_second': 485.711, 'train_steps_per_second': 16.19, 'total_flos': 7.951838339684966e+17, 'train_loss': 0.5026930274963379, 'epoch': 500.0})

In [None]:
# Play: roll the trained model out in CarRacing-v2, conditioning on a target
# return, and render a frame every `print_every` steps.
import matplotlib.pyplot as plt
import torch
from IPython.display import display as ipy_display, clear_output

# NOTE(review): `prepare_observation_array` and `get_action` are not defined in
# any cell of this notebook; import them explicitly so the cell survives a
# Restart & Run All.

max_ep_len = trainConfig.max_ep_len  # same episode-length limit the model was trained with
device = 'cpu'
model = model.to(device)
# trainer.train() leaves the model in training mode; switch off dropout etc. for inference.
model.eval()
scale = 1000.0  # normalization for rewards/returns
TARGET_RETURN = 900 / scale  # evaluation is conditioned on a return of 900, scaled accordingly
print_every = 10  # render one frame every this many environment steps

env = gym.make('CarRacing-v2', render_mode='rgb_array', continuous=False)

state_dim = 96 * 96 * 3
act_dim = 1

episode_return, episode_length = 0, 0
state, _ = env.reset()
state = prepare_observation_array(state)
target_return = torch.tensor(TARGET_RETURN, device=device, dtype=torch.float32).reshape(1, 1)
states = torch.from_numpy(state).reshape(1, state_dim).to(device=device, dtype=torch.float32)
actions = torch.zeros((0, act_dim), device=device, dtype=torch.long)
rewards = torch.zeros(0, device=device, dtype=torch.float32)
timesteps = torch.tensor(0, device=device, dtype=torch.long).reshape(1, 1)

# One reusable figure for the rendered frames; closed at the end so frames do
# not pile up on a single axes and the figure is not re-shown after the cell.
fig, ax = plt.subplots()
try:
    with torch.no_grad():
        for t in range(max_ep_len):
            # Placeholders for this step's action/reward; filled in once known.
            actions = torch.cat([actions, torch.zeros((1, act_dim), dtype=torch.long, device=device)], dim=0)
            rewards = torch.cat([rewards, torch.zeros(1, device=device)])

            action_pred = get_action(
                model,
                states,
                actions,
                rewards,
                target_return,
                timesteps,
            )
            # Greedy discrete action: index of the highest predicted score.
            action = torch.argmax(action_pred).item()
            actions[-1] = action

            state, reward, terminated, truncated, _ = env.step(action)

            if (t + 1) % print_every == 0:
                clear_output(wait=True)
                ax.clear()
                ax.imshow(env.render())
                ax.axis('off')  # Hide the axis
                ipy_display(fig)

            state = prepare_observation_array(state)
            cur_state = torch.from_numpy(state).reshape(1, state_dim).to(device=device, dtype=torch.float32)
            states = torch.cat([states, cur_state], dim=0)
            rewards[-1] = reward

            # Return-to-go for the next step: remaining target after this reward.
            pred_return = target_return[0, -1] - (reward / scale)
            target_return = torch.cat([target_return, pred_return.reshape(1, 1)], dim=1)
            timesteps = torch.cat(
                [timesteps, torch.full((1, 1), t + 1, device=device, dtype=torch.long)], dim=1
            )

            episode_return += reward
            episode_length += 1

            # Gymnasium reports natural termination and time-limit truncation separately.
            if terminated or truncated:
                break
finally:
    plt.close(fig)
    env.close()

print(f"episode_return={episode_return:.1f}, episode_length={episode_length}")