### Introduction

This notebook trains a Decision Transformer on offline trajectories collected by running a pre-trained DQN agent in the CarRacing-v2 environment (discrete actions). It then evaluates the trained model by rolling it out in the environment.

### Set up the environment (Google Colab only)

In [1]:
import sys
import os

if 'google.colab' in sys.modules:
  if not os.path.exists('/content/.already_installed'):
    !git clone https://github.com/YakivGalkin/cnn_decision_transformer
    !apt-get install -y swig
    !pip install -r cnn_decision_transformer/requirements.txt
    with open('/content/.already_installed', 'w') as f:
        f.write('done')
  %cd /content/cnn_decision_transformer

### Load Dataset

In [2]:
import utils.storage as storage

# Offline datasets referenced in this project (names passed to storage.load_dataset):
#   'car_racing_15_100'
#   'offline_car_racing_150_1000'
DATASET_NAME = 'offline_car_racing_150_1000'

features = storage.load_dataset(DATASET_NAME)
# Number of entries under "observations" (presumably one per recorded episode — verify).
print(len(features["observations"]))

15


In [3]:
import gymnasium as gym

# Discrete-action CarRacing without rendering (add render_mode='human' to watch it).
env = gym.make(id='CarRacing-v2', continuous=False)

In [4]:
from dataclasses import asdict, dataclass
import wandb
import os


@dataclass
class TrainConfig:
    """Hyper-parameters for one Decision Transformer training run.

    The whole object is logged to Weights & Biases (via ``asdict``), and its
    fields are forwarded to the data collator, the model config and
    ``transformers.TrainingArguments``.
    """

    # WANDB CONFIG
    wandb_id: str = "dt_25"              # W&B run id (reused so the run can be resumed)
    wandb_name: str = "DT_25"            # W&B display name
    model_save_name: str = "DT_MODEL_25"  # not referenced in this notebook
    saved_model_version: str = "latest"   # not referenced in this notebook
    # Checkpoint every N optimizer steps (TrainingArguments.save_steps).
    # Previously declared twice (1000, then 30); the later value, 30, was the
    # one in effect, so it is kept here as the single definition.
    save_steps: int = 30

    # TRAINING DATA CONFIG
    num_train_epochs: int = 100
    max_ep_len: int = 1000               # passed to the collator and the model config
    max_length: int = 10                 # context length passed to the collator and the model config
    rtg_gamma: float = 1.0               # presumably a return-to-go discount; unused in this notebook

    prefix: str = 'DT'
    log_interval: int = 5                # TrainingArguments.logging_steps
    per_device_train_batch_size: int = 1
    learning_rate: float = 0.0001
    weight_decay: float = 0.0001
    warmup_ratio: float = 0.1
    max_grad_norm: float = 0.25


trainConfig = TrainConfig()

os.environ["WANDB_DISABLED"] = "false"
os.environ['WANDB_NOTEBOOK_NAME'] = 'DT_train.ipynb'
# Have the Trainer's W&B integration upload each saved checkpoint as an artifact.
os.environ["WANDB_LOG_MODEL"] = "checkpoint"

# Never hard-code the API key in the notebook: read it from the environment
# (e.g. a .env file or a Colab secret). With key=None, wandb.login falls back to
# stored credentials (~/.netrc) or an interactive prompt.
# NOTE: the key previously committed here should be revoked and rotated.
wandb.login(key=os.environ.get("WANDB_API_KEY"))
wandb.init(
    id=trainConfig.wandb_id,   # fixed run id so repeated sessions append to the same run
    resume="allow",            # resume that run if it exists, otherwise create it
    name=trainConfig.wandb_name,
    mode="online",
    entity="yakiv",
    project="CarRacingDT",
    config=asdict(trainConfig),
)


[34m[1mwandb[0m: Currently logged in as: [33myakiv[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /Users/jacob/.netrc


### Train

In [5]:

from visual_decision_transformer.visual_decision_transformer_trainable import VisualDecisionTransformerGymDataCollator, TrainableVisualDecisionTransformer
from visual_decision_transformer.configuration import DecisionTransformerConfig
from utils.dataset_wrappers import DummyDataset
from utils.dataset_wrappers import CarRacingFeatureDataset
from transformers import Trainer, TrainingArguments

# Wrap the loaded features; the collator builds training batches from this dataset.
feature_dataset = CarRacingFeatureDataset(src=features)
collator = VisualDecisionTransformerGymDataCollator(
    feature_dataset,
    max_len=trainConfig.max_length,
    max_ep_len=trainConfig.max_ep_len,
)

# State/action sizes are taken from the collator so the model matches the data.
dt_config = DecisionTransformerConfig(
    state_dim=collator.state_dim,
    act_dim=collator.act_dim,
    max_length=trainConfig.max_length,
    max_ep_len=trainConfig.max_ep_len,
)
model = TrainableVisualDecisionTransformer(dt_config)

training_args = TrainingArguments(
    # Output / reporting
    output_dir="output/",
    report_to="wandb",
    logging_steps=trainConfig.log_interval,
    save_steps=trainConfig.save_steps,
    remove_unused_columns=False,
    # Optimisation
    optim="adamw_torch",
    num_train_epochs=trainConfig.num_train_epochs,
    per_device_train_batch_size=trainConfig.per_device_train_batch_size,
    learning_rate=trainConfig.learning_rate,
    weight_decay=trainConfig.weight_decay,
    warmup_ratio=trainConfig.warmup_ratio,
    max_grad_norm=trainConfig.max_grad_norm,
)

# The 'hack': DummyDataset only provides a length (one item per entry of
# feature_dataset); the collator, which already holds feature_dataset,
# supplies the actual batch contents.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=DummyDataset(len(feature_dataset)),
    data_collator=collator,
)

trainer.train()


  from .autonotebook import tqdm as notebook_tqdm
  0%|          | 0/1500 [00:00<?, ?it/s]

Collator sz:1: started at 12:32:24
Collator sz:1: finished at 12:32:24 and took 3ms
Collator sz:1: started at 12:32:24
Collator sz:1: finished at 12:32:24 and took 1ms
Trainable Forward pass: started at 12:32:25
loss: 1.5905131101608276
Trainable Forward pass: finished at 12:32:25 and took 113ms


  0%|          | 1/1500 [00:01<44:19,  1.77s/it]

Collator sz:1: started at 12:32:26
Collator sz:1: finished at 12:32:26 and took 2ms
Trainable Forward pass: started at 12:32:26
loss: 1.6781355142593384
Trainable Forward pass: finished at 12:32:26 and took 21ms


  0%|          | 2/1500 [00:02<25:53,  1.04s/it]

Collator sz:1: started at 12:32:27
Collator sz:1: finished at 12:32:27 and took 1ms
Trainable Forward pass: started at 12:32:27
loss: 1.6786863803863525
Trainable Forward pass: finished at 12:32:27 and took 20ms


  0%|          | 3/1500 [00:02<18:02,  1.38it/s]

Collator sz:1: started at 12:32:27
Collator sz:1: finished at 12:32:27 and took 1ms
Trainable Forward pass: started at 12:32:27
loss: 1.6703516244888306
Trainable Forward pass: finished at 12:32:27 and took 20ms


  0%|          | 4/1500 [00:02<14:21,  1.74it/s]

Collator sz:1: started at 12:32:27
Collator sz:1: finished at 12:32:27 and took 1ms
Trainable Forward pass: started at 12:32:28
loss: 1.6779769659042358
Trainable Forward pass: finished at 12:32:28 and took 20ms


  0%|          | 5/1500 [00:03<12:18,  2.02it/s]

{'loss': 1.6591, 'learning_rate': 3.3333333333333333e-06, 'epoch': 0.33}
Collator sz:1: started at 12:32:28
Collator sz:1: finished at 12:32:28 and took 1ms
Trainable Forward pass: started at 12:32:28
loss: 1.6769405603408813
Trainable Forward pass: finished at 12:32:28 and took 21ms


  0%|          | 6/1500 [00:03<11:05,  2.24it/s]

Collator sz:1: started at 12:32:28
Collator sz:1: finished at 12:32:28 and took 1ms
Trainable Forward pass: started at 12:32:28
loss: 1.7159079313278198
Trainable Forward pass: finished at 12:32:28 and took 21ms


  0%|          | 7/1500 [00:04<10:05,  2.47it/s]

Collator sz:1: started at 12:32:28
Collator sz:1: finished at 12:32:28 and took 1ms
Trainable Forward pass: started at 12:32:29
loss: 1.6235859394073486
Trainable Forward pass: finished at 12:32:29 and took 20ms


  1%|          | 8/1500 [00:04<09:33,  2.60it/s]

Collator sz:1: started at 12:32:29
Collator sz:1: finished at 12:32:29 and took 1ms
Trainable Forward pass: started at 12:32:29
loss: 1.5625314712524414
Trainable Forward pass: finished at 12:32:29 and took 21ms


  1%|          | 9/1500 [00:04<09:22,  2.65it/s]

Collator sz:1: started at 12:32:29
Collator sz:1: finished at 12:32:29 and took 1ms
Trainable Forward pass: started at 12:32:29
loss: 1.5953996181488037
Trainable Forward pass: finished at 12:32:29 and took 20ms


  1%|          | 10/1500 [00:05<09:50,  2.52it/s]

{'loss': 1.6349, 'learning_rate': 6.666666666666667e-06, 'epoch': 0.67}
Collator sz:1: started at 12:32:30
Collator sz:1: finished at 12:32:30 and took 1ms
Trainable Forward pass: started at 12:32:30
loss: 1.54945969581604
Trainable Forward pass: finished at 12:32:30 and took 19ms


  1%|          | 11/1500 [00:05<09:43,  2.55it/s]

Collator sz:1: started at 12:32:30
Collator sz:1: finished at 12:32:30 and took 1ms
Trainable Forward pass: started at 12:32:30
loss: 1.5813130140304565
Trainable Forward pass: finished at 12:32:30 and took 19ms


  1%|          | 12/1500 [00:05<09:28,  2.62it/s]

Collator sz:1: started at 12:32:30
Collator sz:1: finished at 12:32:30 and took 1ms
Trainable Forward pass: started at 12:32:30
loss: 1.663477897644043
Trainable Forward pass: finished at 12:32:30 and took 20ms


  1%|          | 13/1500 [00:06<09:22,  2.64it/s]

Collator sz:1: started at 12:32:31
Collator sz:1: finished at 12:32:31 and took 1ms
Trainable Forward pass: started at 12:32:31
loss: 1.601077675819397
Trainable Forward pass: finished at 12:32:31 and took 20ms


  1%|          | 14/1500 [00:06<09:15,  2.68it/s]

Trainable Forward pass: started at 12:32:31
loss: 1.58486008644104
Trainable Forward pass: finished at 12:32:31 and took 20ms


  1%|          | 15/1500 [00:06<08:58,  2.76it/s]

{'loss': 1.596, 'learning_rate': 1e-05, 'epoch': 1.0}
Collator sz:1: started at 12:32:31
Collator sz:1: finished at 12:32:31 and took 1ms
Collator sz:1: started at 12:32:31
Collator sz:1: finished at 12:32:31 and took 1ms
Trainable Forward pass: started at 12:32:32
loss: 1.5787230730056763
Trainable Forward pass: finished at 12:32:32 and took 21ms


  1%|          | 16/1500 [00:07<09:05,  2.72it/s]

Collator sz:1: started at 12:32:32
Collator sz:1: finished at 12:32:32 and took 1ms
Trainable Forward pass: started at 12:32:32
loss: 1.63589608669281
Trainable Forward pass: finished at 12:32:32 and took 20ms


  1%|          | 17/1500 [00:07<08:58,  2.75it/s]

Collator sz:1: started at 12:32:32
Collator sz:1: finished at 12:32:32 and took 1ms
Trainable Forward pass: started at 12:32:32
loss: 1.7696762084960938
Trainable Forward pass: finished at 12:32:32 and took 19ms


  1%|          | 18/1500 [00:08<09:18,  2.65it/s]

Collator sz:1: started at 12:32:33
Collator sz:1: finished at 12:32:33 and took 1ms
Trainable Forward pass: started at 12:32:33
loss: 1.5661016702651978
Trainable Forward pass: finished at 12:32:33 and took 21ms


  1%|▏         | 19/1500 [00:08<08:57,  2.76it/s]

Collator sz:1: started at 12:32:33
Collator sz:1: finished at 12:32:33 and took 1ms
Trainable Forward pass: started at 12:32:33
loss: 1.5201486349105835
Trainable Forward pass: finished at 12:32:33 and took 21ms


  1%|▏         | 20/1500 [00:08<08:51,  2.79it/s]

{'loss': 1.6141, 'learning_rate': 1.3333333333333333e-05, 'epoch': 1.33}
Collator sz:1: started at 12:32:33
Collator sz:1: finished at 12:32:33 and took 1ms
Trainable Forward pass: started at 12:32:33
loss: 1.598671317100525
Trainable Forward pass: finished at 12:32:33 and took 19ms


  1%|▏         | 21/1500 [00:09<08:46,  2.81it/s]

Collator sz:1: started at 12:32:34
Collator sz:1: finished at 12:32:34 and took 1ms
Trainable Forward pass: started at 12:32:34
loss: 1.6489006280899048
Trainable Forward pass: finished at 12:32:34 and took 19ms


  1%|▏         | 22/1500 [00:09<08:39,  2.84it/s]

Collator sz:1: started at 12:32:34
Collator sz:1: finished at 12:32:34 and took 1ms
Trainable Forward pass: started at 12:32:34
loss: 1.6202372312545776
Trainable Forward pass: finished at 12:32:34 and took 21ms


  2%|▏         | 23/1500 [00:09<08:45,  2.81it/s]

Collator sz:1: started at 12:32:34
Collator sz:1: finished at 12:32:34 and took 1ms
Trainable Forward pass: started at 12:32:34
loss: 1.538067102432251
Trainable Forward pass: finished at 12:32:34 and took 22ms


  2%|▏         | 24/1500 [00:10<08:51,  2.78it/s]

Collator sz:1: started at 12:32:35
Collator sz:1: finished at 12:32:35 and took 1ms
Trainable Forward pass: started at 12:32:35
loss: 1.566703200340271
Trainable Forward pass: finished at 12:32:35 and took 21ms


  2%|▏         | 25/1500 [00:10<09:09,  2.69it/s]

{'loss': 1.5945, 'learning_rate': 1.6666666666666667e-05, 'epoch': 1.67}
Collator sz:1: started at 12:32:35
Collator sz:1: finished at 12:32:35 and took 1ms
Trainable Forward pass: started at 12:32:35
loss: 1.4728535413742065
Trainable Forward pass: finished at 12:32:35 and took 20ms


  2%|▏         | 26/1500 [00:10<09:03,  2.71it/s]

Collator sz:1: started at 12:32:35
Collator sz:1: finished at 12:32:35 and took 1ms
Trainable Forward pass: started at 12:32:36
loss: 1.3268377780914307
Trainable Forward pass: finished at 12:32:36 and took 21ms


  2%|▏         | 27/1500 [00:11<08:57,  2.74it/s]

Collator sz:1: started at 12:32:36
Collator sz:1: finished at 12:32:36 and took 1ms
Trainable Forward pass: started at 12:32:36
loss: 1.4907573461532593
Trainable Forward pass: finished at 12:32:36 and took 20ms


  2%|▏         | 28/1500 [00:11<08:51,  2.77it/s]

Collator sz:1: started at 12:32:36
Collator sz:1: finished at 12:32:36 and took 1ms
Trainable Forward pass: started at 12:32:36
loss: 1.515650987625122
Trainable Forward pass: finished at 12:32:36 and took 21ms


  2%|▏         | 29/1500 [00:12<08:46,  2.80it/s]

Trainable Forward pass: started at 12:32:37
loss: 1.397913932800293
Trainable Forward pass: finished at 12:32:37 and took 20ms


  2%|▏         | 30/1500 [00:12<08:38,  2.83it/s]

{'loss': 1.4408, 'learning_rate': 2e-05, 'epoch': 2.0}


[34m[1mwandb[0m: Adding directory to artifact (./output/checkpoint-30)... Done. 0.2s


Collator sz:1: started at 12:32:38
Collator sz:1: finished at 12:32:38 and took 3ms
Collator sz:1: started at 12:32:38
Collator sz:1: finished at 12:32:38 and took 1ms
Trainable Forward pass: started at 12:32:38
loss: 1.476318597793579
Trainable Forward pass: finished at 12:32:38 and took 20ms


  2%|▏         | 31/1500 [00:14<20:01,  1.22it/s]

Collator sz:1: started at 12:32:39
Collator sz:1: finished at 12:32:39 and took 1ms
Trainable Forward pass: started at 12:32:39
loss: 1.5301636457443237
Trainable Forward pass: finished at 12:32:39 and took 21ms


  2%|▏         | 32/1500 [00:14<17:04,  1.43it/s]

Collator sz:1: started at 12:32:39
Collator sz:1: finished at 12:32:39 and took 1ms
Trainable Forward pass: started at 12:32:39
loss: 1.5929062366485596
Trainable Forward pass: finished at 12:32:39 and took 21ms


  2%|▏         | 33/1500 [00:15<14:44,  1.66it/s]

Collator sz:1: started at 12:32:40
Collator sz:1: finished at 12:32:40 and took 1ms
Trainable Forward pass: started at 12:32:40
loss: 1.512257695198059
Trainable Forward pass: finished at 12:32:40 and took 20ms


  2%|▏         | 34/1500 [00:15<13:08,  1.86it/s]

Collator sz:1: started at 12:32:40
Collator sz:1: finished at 12:32:40 and took 1ms
Trainable Forward pass: started at 12:32:40
loss: 1.4689832925796509
Trainable Forward pass: finished at 12:32:40 and took 22ms


  2%|▏         | 35/1500 [00:15<12:09,  2.01it/s]

{'loss': 1.5161, 'learning_rate': 2.3333333333333336e-05, 'epoch': 2.33}
Collator sz:1: started at 12:32:40
Collator sz:1: finished at 12:32:40 and took 1ms
Trainable Forward pass: started at 12:32:40
loss: 1.552587866783142
Trainable Forward pass: finished at 12:32:40 and took 20ms


  2%|▏         | 36/1500 [00:16<11:13,  2.17it/s]

Collator sz:1: started at 12:32:41
Collator sz:1: finished at 12:32:41 and took 1ms
Trainable Forward pass: started at 12:32:41
loss: 1.4916743040084839
Trainable Forward pass: finished at 12:32:41 and took 19ms


  2%|▏         | 37/1500 [00:16<10:29,  2.32it/s]

Collator sz:1: started at 12:32:41
Collator sz:1: finished at 12:32:41 and took 1ms
Trainable Forward pass: started at 12:32:41
loss: 1.5712090730667114
Trainable Forward pass: finished at 12:32:41 and took 20ms


  3%|▎         | 38/1500 [00:16<09:53,  2.46it/s]

Collator sz:1: started at 12:32:41
Collator sz:1: finished at 12:32:41 and took 1ms
Trainable Forward pass: started at 12:32:41
loss: 1.4545549154281616
Trainable Forward pass: finished at 12:32:42 and took 19ms


  3%|▎         | 39/1500 [00:17<09:31,  2.56it/s]

Collator sz:1: started at 12:32:42
Collator sz:1: finished at 12:32:42 and took 1ms
Trainable Forward pass: started at 12:32:42
loss: 1.4468265771865845
Trainable Forward pass: finished at 12:32:42 and took 20ms


  3%|▎         | 40/1500 [00:17<09:50,  2.47it/s]

{'loss': 1.5034, 'learning_rate': 2.6666666666666667e-05, 'epoch': 2.67}
Collator sz:1: started at 12:32:42
Collator sz:1: finished at 12:32:42 and took 1ms
Trainable Forward pass: started at 12:32:42
loss: 1.8253564834594727
Trainable Forward pass: finished at 12:32:42 and took 24ms


  3%|▎         | 41/1500 [00:18<09:58,  2.44it/s]

Collator sz:1: started at 12:32:43
Collator sz:1: finished at 12:32:43 and took 1ms
Trainable Forward pass: started at 12:32:43
loss: 1.5075757503509521
Trainable Forward pass: finished at 12:32:43 and took 19ms


  3%|▎         | 42/1500 [00:18<09:34,  2.54it/s]

Collator sz:1: started at 12:32:43
Collator sz:1: finished at 12:32:43 and took 1ms
Trainable Forward pass: started at 12:32:43
loss: 1.3516607284545898
Trainable Forward pass: finished at 12:32:43 and took 22ms


  3%|▎         | 43/1500 [00:18<09:13,  2.63it/s]

Collator sz:1: started at 12:32:43
Collator sz:1: finished at 12:32:43 and took 1ms
Trainable Forward pass: started at 12:32:43
loss: 1.292670726776123
Trainable Forward pass: finished at 12:32:43 and took 21ms


  3%|▎         | 44/1500 [00:19<09:11,  2.64it/s]

Trainable Forward pass: started at 12:32:44
loss: 1.3786541223526
Trainable Forward pass: finished at 12:32:44 and took 19ms


  3%|▎         | 45/1500 [00:19<08:57,  2.71it/s]

{'loss': 1.4712, 'learning_rate': 3e-05, 'epoch': 3.0}
Collator sz:1: started at 12:32:44
Collator sz:1: finished at 12:32:44 and took 1ms
Collator sz:1: started at 12:32:44
Collator sz:1: finished at 12:32:44 and took 5ms
Trainable Forward pass: started at 12:32:44
loss: 1.6032527685165405
Trainable Forward pass: finished at 12:32:44 and took 21ms


  3%|▎         | 46/1500 [00:19<09:00,  2.69it/s]

Collator sz:1: started at 12:32:44
Collator sz:1: finished at 12:32:44 and took 1ms
Trainable Forward pass: started at 12:32:44
loss: 1.3370803594589233
Trainable Forward pass: finished at 12:32:45 and took 20ms


  3%|▎         | 47/1500 [00:20<08:53,  2.72it/s]

Collator sz:1: started at 12:32:45
Collator sz:1: finished at 12:32:45 and took 1ms
Trainable Forward pass: started at 12:32:45
loss: 1.5976251363754272
Trainable Forward pass: finished at 12:32:45 and took 19ms


  3%|▎         | 48/1500 [00:20<09:19,  2.59it/s]

Collator sz:1: started at 12:32:45
Collator sz:1: finished at 12:32:45 and took 2ms
Trainable Forward pass: started at 12:32:45
loss: 1.6022061109542847
Trainable Forward pass: finished at 12:32:45 and took 19ms


  3%|▎         | 49/1500 [00:21<09:09,  2.64it/s]

Collator sz:1: started at 12:32:46
Collator sz:1: finished at 12:32:46 and took 1ms
Trainable Forward pass: started at 12:32:46
loss: 1.3489559888839722
Trainable Forward pass: finished at 12:32:46 and took 19ms


  3%|▎         | 50/1500 [00:21<08:58,  2.69it/s]

{'loss': 1.4978, 'learning_rate': 3.3333333333333335e-05, 'epoch': 3.33}
Collator sz:1: started at 12:32:46
Collator sz:1: finished at 12:32:46 and took 1ms
Trainable Forward pass: started at 12:32:46
loss: 1.3601080179214478
Trainable Forward pass: finished at 12:32:46 and took 19ms


  3%|▎         | 51/1500 [00:21<08:54,  2.71it/s]

Collator sz:1: started at 12:32:46
Collator sz:1: finished at 12:32:46 and took 1ms
Trainable Forward pass: started at 12:32:46
loss: 1.2271019220352173
Trainable Forward pass: finished at 12:32:46 and took 19ms


  3%|▎         | 52/1500 [00:22<08:46,  2.75it/s]

Collator sz:1: started at 12:32:47
Collator sz:1: finished at 12:32:47 and took 1ms
Trainable Forward pass: started at 12:32:47
loss: 1.518980622291565
Trainable Forward pass: finished at 12:32:47 and took 21ms


  4%|▎         | 53/1500 [00:22<08:40,  2.78it/s]

Collator sz:1: started at 12:32:47
Collator sz:1: finished at 12:32:47 and took 1ms
Trainable Forward pass: started at 12:32:47
loss: 1.2056242227554321
Trainable Forward pass: finished at 12:32:47 and took 20ms


  4%|▎         | 54/1500 [00:22<08:32,  2.82it/s]

Collator sz:1: started at 12:32:47
Collator sz:1: finished at 12:32:47 and took 1ms
Trainable Forward pass: started at 12:32:47
loss: 1.2630228996276855
Trainable Forward pass: finished at 12:32:47 and took 19ms


  4%|▎         | 55/1500 [00:23<08:34,  2.81it/s]

{'loss': 1.315, 'learning_rate': 3.6666666666666666e-05, 'epoch': 3.67}
Collator sz:1: started at 12:32:48
Collator sz:1: finished at 12:32:48 and took 1ms
Trainable Forward pass: started at 12:32:48
loss: 1.4154913425445557
Trainable Forward pass: finished at 12:32:48 and took 19ms


  4%|▎         | 56/1500 [00:23<09:08,  2.63it/s]

Collator sz:1: started at 12:32:48
Collator sz:1: finished at 12:32:48 and took 1ms
Trainable Forward pass: started at 12:32:48
loss: 1.6266533136367798
Trainable Forward pass: finished at 12:32:48 and took 19ms


  4%|▍         | 57/1500 [00:24<08:57,  2.68it/s]

Collator sz:1: started at 12:32:48
Collator sz:1: finished at 12:32:48 and took 1ms
Trainable Forward pass: started at 12:32:49
loss: 1.3279751539230347
Trainable Forward pass: finished at 12:32:49 and took 20ms


  4%|▍         | 58/1500 [00:24<08:48,  2.73it/s]

Collator sz:1: started at 12:32:49
Collator sz:1: finished at 12:32:49 and took 1ms
Trainable Forward pass: started at 12:32:49
loss: 1.3094629049301147
Trainable Forward pass: finished at 12:32:49 and took 20ms


  4%|▍         | 59/1500 [00:24<08:43,  2.75it/s]

Trainable Forward pass: started at 12:32:49
loss: 1.260124921798706
Trainable Forward pass: finished at 12:32:49 and took 19ms


  4%|▍         | 60/1500 [00:25<08:41,  2.76it/s]

{'loss': 1.3879, 'learning_rate': 4e-05, 'epoch': 4.0}


[34m[1mwandb[0m: Adding directory to artifact (./output/checkpoint-60)... Done. 0.3s


Collator sz:1: started at 12:32:51
Collator sz:1: finished at 12:32:51 and took 1ms
Collator sz:1: started at 12:32:51
Collator sz:1: finished at 12:32:51 and took 1ms
Trainable Forward pass: started at 12:32:51
loss: 1.307854175567627
Trainable Forward pass: finished at 12:32:51 and took 20ms


  4%|▍         | 61/1500 [00:27<20:16,  1.18it/s]

Collator sz:1: started at 12:32:51
Collator sz:1: finished at 12:32:51 and took 1ms
Trainable Forward pass: started at 12:32:52
loss: 1.26534104347229
Trainable Forward pass: finished at 12:32:52 and took 21ms


  4%|▍         | 62/1500 [00:27<16:27,  1.46it/s]

Collator sz:1: started at 12:32:52
Collator sz:1: finished at 12:32:52 and took 1ms
Trainable Forward pass: started at 12:32:52
loss: 1.3363059759140015
Trainable Forward pass: finished at 12:32:52 and took 44ms


  4%|▍         | 63/1500 [00:27<14:24,  1.66it/s]

Collator sz:1: started at 12:32:52
Collator sz:1: finished at 12:32:52 and took 1ms
Trainable Forward pass: started at 12:32:52
loss: 1.382799744606018
Trainable Forward pass: finished at 12:32:52 and took 22ms


  4%|▍         | 64/1500 [00:28<12:42,  1.88it/s]

Collator sz:1: started at 12:32:53
Collator sz:1: finished at 12:32:53 and took 1ms
Trainable Forward pass: started at 12:32:53
loss: 1.9589922428131104
Trainable Forward pass: finished at 12:32:53 and took 22ms


  4%|▍         | 65/1500 [00:28<11:29,  2.08it/s]

{'loss': 1.4503, 'learning_rate': 4.3333333333333334e-05, 'epoch': 4.33}
Collator sz:1: started at 12:32:53
Collator sz:1: finished at 12:32:53 and took 1ms
Trainable Forward pass: started at 12:32:53
loss: 1.4045876264572144
Trainable Forward pass: finished at 12:32:53 and took 20ms


  4%|▍         | 66/1500 [00:28<10:37,  2.25it/s]

Collator sz:1: started at 12:32:53
Collator sz:1: finished at 12:32:53 and took 1ms
Trainable Forward pass: started at 12:32:53
loss: 1.383438229560852
Trainable Forward pass: finished at 12:32:53 and took 20ms


  4%|▍         | 67/1500 [00:29<09:51,  2.42it/s]

Collator sz:1: started at 12:32:54
Collator sz:1: finished at 12:32:54 and took 1ms
Trainable Forward pass: started at 12:32:54
loss: 1.1762030124664307
Trainable Forward pass: finished at 12:32:54 and took 19ms


  5%|▍         | 68/1500 [00:29<09:28,  2.52it/s]

Collator sz:1: started at 12:32:54
Collator sz:1: finished at 12:32:54 and took 1ms
Trainable Forward pass: started at 12:32:54
loss: 1.3797554969787598
Trainable Forward pass: finished at 12:32:54 and took 19ms


  5%|▍         | 69/1500 [00:29<09:03,  2.63it/s]

Collator sz:1: started at 12:32:54
Collator sz:1: finished at 12:32:54 and took 1ms
Trainable Forward pass: started at 12:32:54
loss: 1.573770523071289
Trainable Forward pass: finished at 12:32:54 and took 20ms


  5%|▍         | 70/1500 [00:30<08:52,  2.68it/s]

{'loss': 1.3836, 'learning_rate': 4.666666666666667e-05, 'epoch': 4.67}
Collator sz:1: started at 12:32:55
Collator sz:1: finished at 12:32:55 and took 1ms
Trainable Forward pass: started at 12:32:55
loss: 1.1850768327713013
Trainable Forward pass: finished at 12:32:55 and took 20ms


  5%|▍         | 71/1500 [00:30<09:22,  2.54it/s]

Collator sz:1: started at 12:32:55
Collator sz:1: finished at 12:32:55 and took 1ms
Trainable Forward pass: started at 12:32:55
loss: 1.2807880640029907
Trainable Forward pass: finished at 12:32:55 and took 20ms


  5%|▍         | 72/1500 [00:31<09:06,  2.61it/s]

Collator sz:1: started at 12:32:56
Collator sz:1: finished at 12:32:56 and took 1ms
Trainable Forward pass: started at 12:32:56
loss: 1.233148217201233
Trainable Forward pass: finished at 12:32:56 and took 19ms


  5%|▍         | 73/1500 [00:31<09:00,  2.64it/s]

Collator sz:1: started at 12:32:56
Collator sz:1: finished at 12:32:56 and took 1ms
Trainable Forward pass: started at 12:32:56
loss: 1.4500131607055664
Trainable Forward pass: finished at 12:32:56 and took 19ms


  5%|▍         | 74/1500 [00:31<08:45,  2.71it/s]

Trainable Forward pass: started at 12:32:56
loss: 1.2411035299301147
Trainable Forward pass: finished at 12:32:56 and took 22ms


  5%|▌         | 75/1500 [00:32<08:36,  2.76it/s]

{'loss': 1.278, 'learning_rate': 5e-05, 'epoch': 5.0}
Collator sz:1: started at 12:32:57
Collator sz:1: finished at 12:32:57 and took 2ms
Collator sz:1: started at 12:32:57
Collator sz:1: finished at 12:32:57 and took 1ms
Trainable Forward pass: started at 12:32:57
loss: 1.2934502363204956
Trainable Forward pass: finished at 12:32:57 and took 17ms


  5%|▌         | 76/1500 [00:32<08:23,  2.83it/s]

Collator sz:1: started at 12:32:57
Collator sz:1: finished at 12:32:57 and took 2ms
Trainable Forward pass: started at 12:32:57
loss: 1.3198665380477905
Trainable Forward pass: finished at 12:32:57 and took 20ms


  5%|▌         | 77/1500 [00:32<08:10,  2.90it/s]

Collator sz:1: started at 12:32:57
Collator sz:1: finished at 12:32:57 and took 1ms
Trainable Forward pass: started at 12:32:57
loss: 1.2301809787750244
Trainable Forward pass: finished at 12:32:57 and took 20ms


  5%|▌         | 78/1500 [00:33<08:02,  2.95it/s]

Collator sz:1: started at 12:32:58
Collator sz:1: finished at 12:32:58 and took 1ms
Trainable Forward pass: started at 12:32:58
loss: 1.133134126663208
Trainable Forward pass: finished at 12:32:58 and took 19ms


  5%|▌         | 79/1500 [00:33<08:17,  2.86it/s]

Collator sz:1: started at 12:32:58
Collator sz:1: finished at 12:32:58 and took 1ms
Trainable Forward pass: started at 12:32:58
loss: 1.4344013929367065
Trainable Forward pass: finished at 12:32:58 and took 21ms


  5%|▌         | 80/1500 [00:33<08:27,  2.80it/s]

{'loss': 1.2822, 'learning_rate': 5.333333333333333e-05, 'epoch': 5.33}
Collator sz:1: started at 12:32:58
Collator sz:1: finished at 12:32:58 and took 1ms
Trainable Forward pass: started at 12:32:58
loss: 1.0338090658187866
Trainable Forward pass: finished at 12:32:58 and took 20ms


  5%|▌         | 81/1500 [00:34<08:34,  2.76it/s]

Collator sz:1: started at 12:32:59
Collator sz:1: finished at 12:32:59 and took 1ms
Trainable Forward pass: started at 12:32:59
loss: 1.3116143941879272
Trainable Forward pass: finished at 12:32:59 and took 20ms


  5%|▌         | 82/1500 [00:34<08:39,  2.73it/s]

Collator sz:1: started at 12:32:59
Collator sz:1: finished at 12:32:59 and took 1ms
Trainable Forward pass: started at 12:32:59
loss: 1.2601479291915894
Trainable Forward pass: finished at 12:32:59 and took 21ms


  6%|▌         | 83/1500 [00:34<08:43,  2.71it/s]

Collator sz:1: started at 12:32:59
Collator sz:1: finished at 12:32:59 and took 2ms
Trainable Forward pass: started at 12:33:00
loss: 1.172621488571167
Trainable Forward pass: finished at 12:33:00 and took 20ms


  6%|▌         | 84/1500 [00:35<08:49,  2.67it/s]

Collator sz:1: started at 12:33:00
Collator sz:1: finished at 12:33:00 and took 1ms
Trainable Forward pass: started at 12:33:00
loss: 1.0856399536132812
Trainable Forward pass: finished at 12:33:00 and took 19ms


  6%|▌         | 85/1500 [00:35<08:40,  2.72it/s]

{'loss': 1.1728, 'learning_rate': 5.666666666666667e-05, 'epoch': 5.67}
Collator sz:1: started at 12:33:00
Collator sz:1: finished at 12:33:00 and took 1ms
Trainable Forward pass: started at 12:33:00
loss: 1.2920852899551392
Trainable Forward pass: finished at 12:33:00 and took 19ms


  6%|▌         | 86/1500 [00:36<08:45,  2.69it/s]

Collator sz:1: started at 12:33:01
Collator sz:1: finished at 12:33:01 and took 1ms
Trainable Forward pass: started at 12:33:01
loss: 0.9838230013847351
Trainable Forward pass: finished at 12:33:01 and took 20ms


  6%|▌         | 87/1500 [00:36<08:52,  2.65it/s]

Collator sz:1: started at 12:33:01
Collator sz:1: finished at 12:33:01 and took 1ms
Trainable Forward pass: started at 12:33:01
loss: 1.1814709901809692
Trainable Forward pass: finished at 12:33:01 and took 18ms


  6%|▌         | 88/1500 [00:36<08:54,  2.64it/s]

Collator sz:1: started at 12:33:01
Collator sz:1: finished at 12:33:01 and took 1ms
Trainable Forward pass: started at 12:33:01
loss: 1.4173166751861572
Trainable Forward pass: finished at 12:33:01 and took 21ms


  6%|▌         | 89/1500 [00:37<08:56,  2.63it/s]

Trainable Forward pass: started at 12:33:02
loss: 1.2503912448883057
Trainable Forward pass: finished at 12:33:02 and took 18ms


  6%|▌         | 90/1500 [00:37<08:55,  2.63it/s]

{'loss': 1.225, 'learning_rate': 6e-05, 'epoch': 6.0}


KeyboardInterrupt: 

In [None]:
# Play: roll out the trained Decision Transformer in CarRacing and show frames inline.
import matplotlib.pyplot as plt
import torch
from IPython.display import display as ipy_display, clear_output

max_ep_len = 1000             # rollout step cap; same value as TrainConfig.max_ep_len
device = 'cpu'
scale = 1000.0                # normalization for rewards/returns
TARGET_RETURN = 900 / scale   # condition the policy on an episode return of 900, scaled accordingly
print_every = 10              # render a frame every N environment steps

# Flattened observation size; assumes prepare_observation_array keeps the
# 96x96x3 frame (the reshape below fails otherwise).
state_dim = 96 * 96 * 3
act_dim = 1                   # one discrete action index per timestep

# Trainer.train() leaves the model in training mode; switch to eval mode so
# dropout (and similar layers) behave deterministically during the rollout.
model = model.to(device)
model.eval()

env = gym.make('CarRacing-v2', render_mode='rgb_array', continuous=False)

# One figure reused for every rendered frame (avoids stacking images on one axes).
fig, ax = plt.subplots()
ax.axis('off')
frame_artist = None

episode_return, episode_length = 0, 0
state, _ = env.reset()
state = prepare_observation_array(state)
target_return = torch.tensor(TARGET_RETURN, device=device, dtype=torch.float32).reshape(1, 1)
states = torch.from_numpy(state).reshape(1, state_dim).to(device=device, dtype=torch.float32)
actions = torch.zeros((0, act_dim), device=device, dtype=torch.long)
rewards = torch.zeros(0, device=device, dtype=torch.float32)
timesteps = torch.tensor(0, device=device, dtype=torch.long).reshape(1, 1)

try:
    with torch.no_grad():
        for t in range(max_ep_len):
            # Append placeholders for this step's action/reward; filled in after acting.
            actions = torch.cat([actions, torch.zeros((1, act_dim), dtype=torch.long, device=device)], dim=0)
            rewards = torch.cat([rewards, torch.zeros(1, device=device)])

            action_scores = get_action(
                model,
                states,
                actions,
                rewards,
                target_return,
                timesteps,
            )
            # Greedy choice over the model's per-action outputs.
            action = torch.argmax(action_scores).item()
            actions[-1] = torch.tensor(action, dtype=torch.long)

            # Gymnasium step API: an episode ends on termination OR truncation.
            state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated

            if (t + 1) % print_every == 0:
                frame = env.render()
                if frame_artist is None:
                    frame_artist = ax.imshow(frame)
                else:
                    frame_artist.set_data(frame)
                clear_output(wait=True)
                ipy_display(fig)

            state = prepare_observation_array(state)
            # Same dtype as the initial state so torch.cat does not promote `states`.
            cur_state = torch.from_numpy(state).to(device=device, dtype=torch.float32).reshape(1, state_dim)
            states = torch.cat([states, cur_state], dim=0)
            rewards[-1] = reward

            # Return-to-go decreases by the (scaled) reward just received.
            pred_return = target_return[0, -1] - (reward / scale)
            target_return = torch.cat([target_return, pred_return.reshape(1, 1)], dim=1)
            timesteps = torch.cat([timesteps, torch.ones((1, 1), device=device, dtype=torch.long) * (t + 1)], dim=1)

            episode_return += reward
            episode_length += 1

            if done:
                break
finally:
    env.close()
    plt.close(fig)  # prevent the inline backend from displaying the figure a second time

print(f"episode return: {episode_return:.1f}, episode length: {episode_length}")