### Introduction

This notebook trains a Decision Transformer on offline trajectories that were collected by running a pre-trained DQN agent in the CarRacing-v2 environment.

### Install initial environment in Google Colab

In [1]:
import sys
import os

if 'google.colab' in sys.modules:
  if not os.path.exists('/content/.already_installed'):
    !git clone https://github.com/YakivGalkin/cnn_decision_transformer
    !apt-get install -y swig
    !pip install -r cnn_decision_transformer/requirements.txt
    with open('/content/.already_installed', 'w') as f:
        f.write('done')
  %cd /content/cnn_decision_transformer

Cloning into 'cnn_decision_transformer'...
remote: Enumerating objects: 53, done.[K
remote: Counting objects: 100% (53/53), done.[K
remote: Compressing objects: 100% (43/43), done.[K
remote: Total 53 (delta 19), reused 39 (delta 8), pack-reused 0[K
Receiving objects: 100% (53/53), 1.28 MiB | 24.31 MiB/s, done.
Resolving deltas: 100% (19/19), done.
Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
The following additional packages will be installed:
  swig4.0
Suggested packages:
  swig-doc swig-examples swig4.0-examples swig4.0-doc
The following NEW packages will be installed:
  swig swig4.0
0 upgraded, 2 newly installed, 0 to remove and 19 not upgraded.
Need to get 1,116 kB of archives.
After this operation, 5,542 kB of additional disk space will be used.
Get:1 http://archive.ubuntu.com/ubuntu jammy/universe amd64 swig4.0 amd64 4.0.2-1ubuntu1 [1,110 kB]
Get:2 http://archive.ubuntu.com/ubuntu jammy/universe amd64 swig all 4.0.2-1ubuntu1

### Load Dataset

In [2]:
import utils.storage as storage

# Dataset names noted by the author:
#   car_racing_15_100
#   offline_car_racing_150_1000
features = storage.load_dataset('offline_car_racing_150_1000')

# Sanity check: report how many entries the "observations" field holds.
observation_count = len(features["observations"])
print(observation_count)

Downloading file from https://storage.googleapis.com/yakiv-dt-public/datasets/offline_car_racing_150_1000.hdf5 to ./downloaded_datasets/offline_car_racing_150_1000.hdf5
Download complete.
150


In [3]:
import gymnasium as gym

# Discrete-action variant of CarRacing. No render_mode is requested here;
# pass render_mode='human' to gym.make to watch the car in a window.
env = gym.make(
    'CarRacing-v2',
    continuous=False,
)

In [8]:
from dataclasses import asdict, dataclass
import wandb
import os


@dataclass
class TrainConfig:
    """Hyper-parameters and experiment-tracking settings for one training run.

    The notebook converts an instance with ``asdict`` and sends it to
    Weights & Biases as the run config, and forwards the optimisation fields
    to ``transformers.TrainingArguments``.

    ``save_steps`` used to be declared twice (10000, then 30). In a class body
    the later assignment wins, so the effective value was always 30; it is now
    declared once, in its original field position, with that effective value.
    """

    # WANDB CONFIG
    wandb_id: str = "dt_25"                # W&B run id (used to resume the run)
    wandb_name: str = "DT_25"              # display name of the W&B run
    model_save_name: str = "DT_MODEL_25"   # not referenced in the visible cells
    saved_model_version: str = "latest"    # not referenced in the visible cells
    save_steps: int = 30                   # checkpoint every N optimisation steps

    # TRAINING DATA CONFIG
    num_train_epochs: int = 100
    max_ep_len: int = 1000                 # passed to the collator and the model config
    max_length: int = 10                   # passed as the collator's max_len and the model's max_length
    rtg_gamma: float = 1.0                 # not referenced in the visible cells

    prefix: str = 'DT'                     # not referenced in the visible cells
    log_interval: int = 5                  # forwarded as TrainingArguments.logging_steps
    per_device_train_batch_size: int = 1
    learning_rate: float = 0.0001
    weight_decay: float = 0.0001
    warmup_ratio: float = 0.1
    max_grad_norm: float = 0.25


trainConfig = TrainConfig()

# Weights & Biases integration switches read by wandb / the HF Trainer.
os.environ["WANDB_DISABLED"] = "false"
os.environ['WANDB_NOTEBOOK_NAME'] = 'DT_train.ipynb'
# Upload every saved checkpoint directory to W&B as an artifact.
os.environ["WANDB_LOG_MODEL"] = "checkpoint"

# SECURITY: never hardcode the API key in the notebook. wandb.login() picks the
# key up from the WANDB_API_KEY environment variable (or a previous login) and
# otherwise prompts for it interactively. The key that used to be committed
# here must be treated as leaked and revoked in the W&B account settings.
wandb.login()

# Resume the run identified by `wandb_id` if it already exists, otherwise
# create it. The run id belongs in `id`; `resume` only selects the policy.
wandb.init(
    id=trainConfig.wandb_id,
    resume="allow",
    name=trainConfig.wandb_name,
    mode="online",
    entity="yakiv",
    project="CarRacingDT",
    config=asdict(trainConfig),
)




VBox(children=(Label(value='2062.665 MB of 2062.665 MB uploaded (0.198 MB deduped)\r'), FloatProgress(value=1.…

0,1
train/epoch,▁▁▂▂▂▂▃▃▃▃▃▄▄▄▅▅▅▅▆▆▆▆▆▇▇▇███▁▂▂▂▂▂▃▁▁▁▂
train/global_step,▁▁▂▂▂▂▃▃▃▃▃▄▄▄▅▅▅▅▆▆▆▆▆▇▇▇███▁▂▂▂▂▂▃▁▁▁▂
train/learning_rate,▁▁▂▂▂▂▃▃▃▃▃▄▄▄▅▅▅▅▆▆▆▆▆▇▇▇███▁▂▂▂▂▂▃▁▁▁▂
train/loss,▇█▆▆▆▅▅▄▅▅▄▄▄▅▃▃▄▃▂▂▂▃▂▂▂▃▁▂▂▇▆▅▆▅▄▄█▇█▆

0,1
train/epoch,0.8
train/global_step,120.0
train/learning_rate,1e-05
train/loss,1.532


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112740255555309, max=1.0…

### Train

In [9]:

from visual_decision_transformer.visual_decision_transformer_trainable import VisualDecisionTransformerGymDataCollator, TrainableVisualDecisionTransformer
from visual_decision_transformer.configuration import DecisionTransformerConfig
from utils.dataset_wrappers import DummyDataset
from utils.dataset_wrappers import CarRacingFeatureDataset
from transformers import Trainer, TrainingArguments

# Wrap the loaded episodes and hand them to the collator together with the
# window sizes from the config.
feature_dataset = CarRacingFeatureDataset(src=features)
collator = VisualDecisionTransformerGymDataCollator(
    feature_dataset,
    max_len=trainConfig.max_length,
    max_ep_len=trainConfig.max_ep_len,
)

# The model dimensions are read back from the collator so that both agree.
dt_config = DecisionTransformerConfig(
    state_dim=collator.state_dim,
    act_dim=collator.act_dim,
    max_length=trainConfig.max_length,
    max_ep_len=trainConfig.max_ep_len,
)
model = TrainableVisualDecisionTransformer(dt_config)

# Optimisation, logging and checkpointing options, all taken from TrainConfig.
training_args = TrainingArguments(
    output_dir="output/",
    report_to="wandb",
    remove_unused_columns=False,
    optim="adamw_torch",
    num_train_epochs=trainConfig.num_train_epochs,
    per_device_train_batch_size=trainConfig.per_device_train_batch_size,
    learning_rate=trainConfig.learning_rate,
    weight_decay=trainConfig.weight_decay,
    warmup_ratio=trainConfig.warmup_ratio,
    max_grad_norm=trainConfig.max_grad_norm,
    logging_steps=trainConfig.log_interval,
    save_steps=trainConfig.save_steps,
)

# 'Hack' kept from the original: the Trainer only receives a placeholder
# dataset sized like the real one, because the collator was already given the
# actual data above.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collator,
    train_dataset=DummyDataset(len(feature_dataset)),
)

trainer.train()


loss: 1.5263149738311768


Step,Training Loss
5,1.5183
10,1.6731
15,1.6638
20,1.5698
25,1.5781
30,1.662
35,1.555
40,1.5789
45,1.5767
50,1.5977


loss: 1.4841265678405762
loss: 1.4726547002792358
loss: 1.430295705795288
loss: 1.6781971454620361
loss: 1.692381501197815
loss: 1.7963730096817017
loss: 1.69246506690979
loss: 1.5979727506637573
loss: 1.5861690044403076
loss: 1.6122392416000366
loss: 1.7260528802871704
loss: 1.578924298286438
loss: 1.8102591037750244
loss: 1.591295599937439
loss: 1.5294163227081299
loss: 1.6025301218032837
loss: 1.5057637691497803
loss: 1.5245697498321533
loss: 1.6867077350616455
loss: 1.6493905782699585
loss: 1.599858283996582
loss: 1.5805222988128662
loss: 1.6042280197143555
loss: 1.4566867351531982
loss: 1.626922607421875
loss: 1.7308788299560547
loss: 1.7103157043457031
loss: 1.5793192386627197
loss: 1.6627591848373413


[34m[1mwandb[0m: Adding directory to artifact (./output/checkpoint-30)... Done. 0.1s


loss: 1.5696275234222412
loss: 1.4377272129058838
loss: 1.5644044876098633
loss: 1.624131441116333
loss: 1.5789417028427124
loss: 1.6229280233383179
loss: 1.5636639595031738
loss: 1.5959123373031616
loss: 1.6173051595687866
loss: 1.4949276447296143
loss: 1.369848370552063
loss: 1.5550800561904907
loss: 1.5491259098052979
loss: 1.7963275909423828
loss: 1.613073706626892
loss: 1.6293083429336548
loss: 1.4958410263061523
loss: 1.6783708333969116
loss: 1.6510636806488037
loss: 1.533854365348816
loss: 1.6599557399749756
loss: 1.5881978273391724
loss: 1.5249254703521729
loss: 1.485162615776062
loss: 1.7259899377822876
loss: 1.5857763290405273
loss: 1.6361608505249023
loss: 1.5330660343170166
loss: 1.588605284690857
loss: 1.6586637496948242


[34m[1mwandb[0m: Adding directory to artifact (./output/checkpoint-60)... Done. 0.1s


loss: 1.6212069988250732
loss: 1.533406376838684
loss: 1.465402364730835
loss: 1.5403659343719482
loss: 1.5860280990600586
loss: 1.494494915008545
loss: 1.5497632026672363
loss: 1.5809857845306396
loss: 1.5117409229278564
loss: 1.498069167137146
loss: 1.6126686334609985
loss: 1.622205138206482
loss: 1.7241230010986328
loss: 1.5349634885787964
loss: 1.4993031024932861
loss: 1.527475118637085
loss: 1.4258577823638916
loss: 1.6695398092269897
loss: 1.5455307960510254
loss: 1.4868814945220947
loss: 1.5305838584899902
loss: 1.5329278707504272
loss: 1.5228084325790405
loss: 1.6043853759765625
loss: 1.491986632347107
loss: 1.4077303409576416
loss: 1.6352697610855103
loss: 1.4658483266830444
loss: 1.405568242073059
loss: 1.3767811059951782


[34m[1mwandb[0m: Adding directory to artifact (./output/checkpoint-90)... Done. 0.1s


loss: 1.4556289911270142
loss: 1.3182940483093262
loss: 1.5275439023971558
loss: 1.604789137840271
loss: 1.5215797424316406
loss: 1.4653016328811646
loss: 1.3787506818771362
loss: 1.259239912033081
loss: 1.4703904390335083
loss: 1.4476935863494873
loss: 1.3318122625350952
loss: 1.5615615844726562
loss: 1.519359827041626
loss: 1.4509456157684326
loss: 1.4437472820281982
loss: 1.5130422115325928
loss: 1.5324429273605347
loss: 1.5853134393692017
loss: 1.3655668497085571
loss: 1.3190215826034546
loss: 1.529413104057312
loss: 1.388573408126831
loss: 1.2872602939605713
loss: 1.4541960954666138
loss: 1.5123573541641235
loss: 1.4340938329696655
loss: 1.4804202318191528
loss: 1.3974424600601196
loss: 1.4323989152908325
loss: 1.4137967824935913


[34m[1mwandb[0m: Adding directory to artifact (./output/checkpoint-120)... Done. 0.1s


loss: 1.316589593887329
loss: 1.4648994207382202
loss: 1.3679029941558838
loss: 1.4647639989852905
loss: 1.1881442070007324
loss: 1.253233551979065
loss: 1.607110619544983
loss: 1.3860037326812744
loss: 1.2359752655029297
loss: 1.4076035022735596
loss: 1.4695355892181396
loss: 1.4350711107254028
loss: 1.4874889850616455
loss: 1.3985174894332886
loss: 1.7107738256454468
loss: 1.2683342695236206
loss: 1.4874727725982666
loss: 1.4001386165618896
loss: 1.515300989151001
loss: 1.4397203922271729
loss: 1.348631501197815
loss: 1.5763051509857178
loss: 1.1948517560958862
loss: 1.713636040687561
loss: 1.4903777837753296
loss: 1.1882679462432861
loss: 1.361457109451294
loss: 1.2471435070037842
loss: 1.4827487468719482
loss: 1.5134947299957275


[34m[1mwandb[0m: Adding directory to artifact (./output/checkpoint-150)... Done. 0.1s


loss: 1.423744559288025
loss: 1.4432096481323242
loss: 1.3875213861465454
loss: 1.4576433897018433
loss: 1.228283405303955
loss: 1.4117640256881714
loss: 1.3903907537460327
loss: 1.4832549095153809
loss: 1.3637592792510986
loss: 1.2673165798187256
loss: 1.3995697498321533
loss: 1.489499568939209
loss: 1.532233476638794
loss: 1.4976232051849365
loss: 1.3121780157089233
loss: 1.3444048166275024
loss: 1.340494155883789
loss: 1.4564319849014282
loss: 1.3636983633041382
loss: 1.4250749349594116
loss: 1.4614956378936768
loss: 1.4100459814071655
loss: 1.3474009037017822
loss: 1.4166133403778076
loss: 1.4191772937774658
loss: 1.4696800708770752
loss: 1.3149757385253906
loss: 1.226745843887329
loss: 1.4199990034103394
loss: 1.4370462894439697


[34m[1mwandb[0m: Adding directory to artifact (./output/checkpoint-180)... Done. 0.1s


loss: 1.2144824266433716
loss: 1.2076414823532104
loss: 1.3988901376724243
loss: 1.25264573097229
loss: 1.1983683109283447
loss: 1.5281083583831787
loss: 1.321560263633728
loss: 1.477105736732483
loss: 1.3370001316070557
loss: 1.3632018566131592
loss: 2.0860562324523926
loss: 1.3927682638168335
loss: 1.425455927848816
loss: 1.2850019931793213
loss: 1.4907419681549072
loss: 1.0961191654205322
loss: 1.2190783023834229
loss: 1.266646146774292
loss: 1.1873366832733154
loss: 1.1772124767303467
loss: 1.3998481035232544
loss: 1.2644765377044678
loss: 1.4838107824325562
loss: 1.408496379852295
loss: 1.4087146520614624
loss: 1.567411184310913
loss: 1.2460334300994873
loss: 1.253469705581665
loss: 1.374891996383667
loss: 1.2997392416000366


[34m[1mwandb[0m: Adding directory to artifact (./output/checkpoint-210)... Done. 0.1s


loss: 1.2489608526229858
loss: 1.3882920742034912
loss: 1.2560653686523438
loss: 1.3426451683044434
loss: 1.340842604637146
loss: 1.0278961658477783
loss: 1.2816932201385498
loss: 1.3585125207901
loss: 1.2695024013519287
loss: 1.6138813495635986
loss: 1.1490800380706787
loss: 1.438361406326294
loss: 1.4037914276123047
loss: 1.2251728773117065
loss: 1.499828577041626
loss: 1.2162772417068481
loss: 1.037192702293396
loss: 1.3652489185333252
loss: 1.2712469100952148
loss: 1.4005746841430664
loss: 1.1482646465301514
loss: 1.7921806573867798
loss: 1.5412708520889282
loss: 1.3287235498428345
loss: 1.2043921947479248
loss: 1.2070438861846924
loss: 1.2296417951583862
loss: 1.3688018321990967
loss: 1.164633870124817
loss: 1.2830424308776855


[34m[1mwandb[0m: Adding directory to artifact (./output/checkpoint-240)... Done. 0.1s


loss: 1.2932978868484497
loss: 1.5303694009780884
loss: 1.1588759422302246
loss: 1.1433179378509521
loss: 1.3306584358215332
loss: 1.5463696718215942
loss: 1.1200969219207764
loss: 1.5588798522949219
loss: 1.3857637643814087
loss: 1.184910535812378
loss: 1.3523457050323486
loss: 1.1781623363494873
loss: 1.4579153060913086
loss: 1.1812975406646729
loss: 1.386177659034729
loss: 1.3562568426132202
loss: 1.357041358947754
loss: 1.3821086883544922
loss: 1.1160554885864258
loss: 1.355372428894043
loss: 1.1186069250106812
loss: 1.4952116012573242
loss: 1.5756583213806152
loss: 1.4077600240707397
loss: 1.3564746379852295
loss: 1.287419080734253
loss: 0.977207362651825
loss: 1.2331740856170654
loss: 1.0541685819625854
loss: 1.1090376377105713


[34m[1mwandb[0m: Adding directory to artifact (./output/checkpoint-270)... Done. 0.1s


loss: 1.2255356311798096
loss: 1.485255479812622
loss: 1.5330393314361572
loss: 1.405729055404663
loss: 1.2353010177612305
loss: 1.0970823764801025
loss: 1.197939395904541
loss: 1.2121648788452148
loss: 1.2722973823547363
loss: 1.401254415512085
loss: 1.2748920917510986
loss: 1.403408408164978
loss: 1.3178088665008545
loss: 1.3204121589660645
loss: 0.840777575969696
loss: 1.4157785177230835
loss: 1.3840306997299194
loss: 1.2564376592636108
loss: 1.0341506004333496
loss: 1.5140401124954224
loss: 1.4137365818023682
loss: 1.1858272552490234
loss: 1.601366639137268
loss: 1.2591127157211304
loss: 1.639735460281372
loss: 1.4820870161056519
loss: 1.2423185110092163
loss: 1.1229467391967773
loss: 1.2313339710235596
loss: 1.4425073862075806


[34m[1mwandb[0m: Adding directory to artifact (./output/checkpoint-300)... Done. 0.1s


loss: 1.2167044878005981
loss: 1.2655994892120361
loss: 1.3200900554656982
loss: 1.1080559492111206
loss: 1.3875036239624023
loss: 0.9904330968856812
loss: 1.5384423732757568
loss: 1.0551705360412598
loss: 1.323557734489441
loss: 1.4126754999160767
loss: 1.1702759265899658
loss: 1.0763219594955444
loss: 1.2157859802246094
loss: 1.0524976253509521
loss: 1.0688531398773193
loss: 1.4630272388458252
loss: 0.9977818727493286
loss: 1.1882622241973877
loss: 1.6547510623931885
loss: 1.3630783557891846
loss: 1.251989722251892
loss: 1.0355912446975708
loss: 1.2228929996490479
loss: 1.1592460870742798
loss: 1.2463374137878418
loss: 1.5762180089950562
loss: 1.2553198337554932
loss: 1.3659486770629883
loss: 1.4778517484664917
loss: 0.9169584512710571


[34m[1mwandb[0m: Adding directory to artifact (./output/checkpoint-330)... Done. 0.1s


loss: 1.2171567678451538
loss: 1.5070229768753052
loss: 1.2995326519012451
loss: 1.1166563034057617
loss: 0.9717451930046082
loss: 1.1504528522491455
loss: 1.1127631664276123
loss: 1.0810024738311768
loss: 1.257811427116394
loss: 1.1507009267807007
loss: 1.4620906114578247
loss: 1.3160854578018188
loss: 1.1874713897705078
loss: 1.0018173456192017
loss: 1.3026978969573975
loss: 1.3095232248306274
loss: 1.7155450582504272
loss: 0.9098175168037415
loss: 1.3259315490722656
loss: 1.2014105319976807
loss: 1.2683601379394531
loss: 1.3020808696746826
loss: 1.384795904159546
loss: 0.7870007753372192
loss: 1.1773266792297363
loss: 1.0081617832183838
loss: 1.1740458011627197
loss: 1.3428308963775635
loss: 1.2419321537017822
loss: 1.6122764348983765


[34m[1mwandb[0m: Adding directory to artifact (./output/checkpoint-360)... Done. 0.1s


loss: 0.9130253791809082
loss: 0.9938253164291382
loss: 1.298465609550476
loss: 1.2983261346817017
loss: 1.2891838550567627
loss: 1.3131202459335327
loss: 1.0284757614135742
loss: 1.2098588943481445
loss: 1.2394075393676758
loss: 1.2930876016616821
loss: 1.3262794017791748
loss: 1.2181966304779053
loss: 1.4018042087554932
loss: 1.021735429763794
loss: 1.419717788696289
loss: 1.2115991115570068
loss: 1.0768835544586182
loss: 1.3012974262237549
loss: 1.283567190170288
loss: 1.2145522832870483
loss: 1.287083387374878
loss: 1.4992549419403076
loss: 1.3665863275527954
loss: 1.033308744430542
loss: 1.4221487045288086
loss: 1.4118983745574951
loss: 0.8146430253982544
loss: 1.2543503046035767
loss: 1.0681672096252441
loss: 1.1536076068878174


[34m[1mwandb[0m: Adding directory to artifact (./output/checkpoint-390)... Done. 0.1s


loss: 1.2570703029632568
loss: 1.0581260919570923
loss: 1.133479356765747
loss: 1.9782979488372803
loss: 1.115085244178772
loss: 1.4811277389526367
loss: 1.236822485923767
loss: 1.2022241353988647
loss: 0.9972896575927734
loss: 1.028801679611206
loss: 1.1812961101531982
loss: 0.9952371716499329
loss: 1.3264214992523193
loss: 0.948111355304718
loss: 1.1870858669281006
loss: 1.1740909814834595
loss: 1.368833065032959
loss: 1.0970739126205444
loss: 1.03931725025177
loss: 1.2132916450500488
loss: 1.168705701828003
loss: 1.155510663986206
loss: 1.1161658763885498
loss: 1.2723946571350098
loss: 1.218064546585083
loss: 1.0684953927993774
loss: 1.1471171379089355
loss: 1.1041361093521118
loss: 1.3451071977615356
loss: 1.3313448429107666


[34m[1mwandb[0m: Adding directory to artifact (./output/checkpoint-420)... Done. 0.1s


loss: 0.7009390592575073
loss: 1.1290242671966553
loss: 1.3014620542526245
loss: 0.9246603846549988
loss: 1.3739144802093506
loss: 0.9016541242599487
loss: 1.3425333499908447
loss: 1.46071195602417
loss: 0.9994630813598633
loss: 1.2905302047729492
loss: 1.4158521890640259
loss: 0.9053484201431274
loss: 1.2345062494277954
loss: 1.2523436546325684
loss: 0.6538457274436951
loss: 1.3577978610992432
loss: 1.1587527990341187
loss: 1.4426392316818237
loss: 1.2762820720672607
loss: 1.145410180091858
loss: 1.165711760520935
loss: 1.2070653438568115
loss: 1.0327802896499634
loss: 1.0582290887832642
loss: 1.0655419826507568
loss: 1.1617660522460938
loss: 1.336775541305542
loss: 0.978219211101532
loss: 0.8672996759414673
loss: 1.240646243095398


[34m[1mwandb[0m: Adding directory to artifact (./output/checkpoint-450)... Done. 0.1s


loss: 1.4658536911010742
loss: 1.203673005104065
loss: 1.4494884014129639
loss: 1.027348279953003
loss: 1.1190015077590942
loss: 1.3385226726531982
loss: 1.042979121208191
loss: 1.278126835823059
loss: 1.1118630170822144
loss: 1.148184061050415
loss: 0.6533803939819336
loss: 0.9667600393295288
loss: 1.4290817975997925
loss: 0.9374706149101257
loss: 1.083430528640747
loss: 0.7252233028411865
loss: 0.7898951768875122
loss: 0.7516877055168152
loss: 1.3353493213653564
loss: 0.7318121194839478
loss: 0.9968441128730774
loss: 1.139695167541504
loss: 1.0005924701690674
loss: 0.8962483406066895
loss: 1.1485344171524048
loss: 1.0321325063705444
loss: 1.1707372665405273
loss: 1.190604329109192
loss: 0.9333917498588562
loss: 0.8876611590385437


[34m[1mwandb[0m: Adding directory to artifact (./output/checkpoint-480)... Done. 0.1s


loss: 1.1981168985366821
loss: 1.353952169418335
loss: 1.1836737394332886
loss: 1.2224276065826416
loss: 1.0472791194915771
loss: 1.440375566482544
loss: 0.5040952563285828
loss: 1.0755130052566528
loss: 1.1382251977920532
loss: 0.8817170858383179
loss: 1.2560704946517944
loss: 1.1586233377456665
loss: 1.108872413635254
loss: 1.0250771045684814
loss: 1.1456632614135742
loss: 0.5137445330619812
loss: 1.039598822593689
loss: 1.3472381830215454
loss: 0.5156819224357605
loss: 0.7693295478820801
loss: 1.1278061866760254
loss: 1.2189667224884033
loss: 0.6347337961196899
loss: 1.325278639793396
loss: 0.9901682138442993
loss: 1.6139274835586548
loss: 1.0296047925949097
loss: 1.5619665384292603
loss: 1.2991729974746704
loss: 1.1623789072036743


[34m[1mwandb[0m: Adding directory to artifact (./output/checkpoint-510)... Done. 0.1s


loss: 0.915573000907898
loss: 1.0621488094329834
loss: 0.568259596824646
loss: 1.0039385557174683
loss: 1.5495221614837646
loss: 0.808304488658905
loss: 0.8966638445854187
loss: 1.461424708366394
loss: 0.8632076978683472
loss: 1.1056636571884155
loss: 1.2270158529281616
loss: 1.1671993732452393
loss: 1.4086534976959229
loss: 0.9017994999885559
loss: 1.1966776847839355
loss: 1.0039966106414795
loss: 0.5040138363838196
loss: 1.1384741067886353
loss: 0.9682450294494629
loss: 1.1268895864486694
loss: 1.1682476997375488
loss: 0.6553233861923218
loss: 0.9788468480110168
loss: 1.432861089706421
loss: 1.1047316789627075
loss: 1.4292280673980713
loss: 1.0639679431915283
loss: 1.314711332321167
loss: 1.1387732028961182
loss: 0.9076946377754211


[34m[1mwandb[0m: Adding directory to artifact (./output/checkpoint-540)... Done. 0.1s


loss: 1.0718388557434082
loss: 1.1434110403060913
loss: 1.3055915832519531
loss: 1.4119517803192139
loss: 0.9312423467636108
loss: 1.0152558088302612
loss: 1.0735257863998413
loss: 1.372573971748352
loss: 1.802408218383789
loss: 1.0181957483291626
loss: 1.2307301759719849
loss: 1.4079511165618896
loss: 1.3117271661758423
loss: 0.7410002946853638
loss: 1.0804834365844727
loss: 1.0688730478286743
loss: 0.7643769979476929
loss: 1.1714988946914673
loss: 1.0572620630264282
loss: 0.7590543627738953
loss: 1.000207781791687
loss: 1.0667221546173096
loss: 0.9009736776351929
loss: 1.1513206958770752
loss: 0.8520674705505371
loss: 1.1383442878723145
loss: 0.4973684847354889
loss: 0.5250109434127808
loss: 1.322300672531128
loss: 1.0158640146255493


[34m[1mwandb[0m: Adding directory to artifact (./output/checkpoint-570)... Done. 0.1s


loss: 1.4151735305786133
loss: 1.0683871507644653
loss: 1.0795347690582275
loss: 0.41758179664611816
loss: 0.9592647552490234
loss: 0.9871221780776978
loss: 0.9444060325622559
loss: 1.2268221378326416
loss: 1.2342634201049805
loss: 0.9276931881904602
loss: 0.8267227411270142
loss: 1.0657122135162354
loss: 1.4903643131256104
loss: 0.994742214679718
loss: 1.3268892765045166
loss: 0.8921924829483032
loss: 0.681280255317688
loss: 0.7833734750747681
loss: 0.8520833849906921
loss: 0.9271188974380493
loss: 1.1256167888641357
loss: 1.177803874015808
loss: 1.1318061351776123
loss: 0.9795125126838684
loss: 1.0315179824829102
loss: 1.0089852809906006
loss: 1.082793951034546
loss: 1.124626874923706
loss: 1.1450412273406982
loss: 1.0502818822860718


[34m[1mwandb[0m: Adding directory to artifact (./output/checkpoint-600)... Done. 0.1s


loss: 0.9412093162536621
loss: 0.5731238126754761
loss: 0.8653100728988647
loss: 0.7528806924819946
loss: 1.4873688220977783
loss: 1.1957752704620361
loss: 1.210485577583313
loss: 0.5115973353385925
loss: 1.4191620349884033
loss: 1.2939410209655762
loss: 1.0407485961914062
loss: 1.0388273000717163
loss: 0.9881659746170044
loss: 1.1531010866165161
loss: 1.1390186548233032
loss: 0.6879371404647827
loss: 0.941256046295166
loss: 0.8219164609909058
loss: 1.0330843925476074
loss: 1.4605991840362549
loss: 0.9684935808181763
loss: 0.7869359254837036
loss: 0.7634395360946655
loss: 1.4356307983398438
loss: 0.8981353044509888
loss: 1.3471314907073975
loss: 0.8444185256958008
loss: 1.16758131980896
loss: 1.0653916597366333
loss: 1.4115127325057983


[34m[1mwandb[0m: Adding directory to artifact (./output/checkpoint-630)... Done. 0.1s


loss: 1.1045902967453003
loss: 0.6811848282814026
loss: 1.3195405006408691
loss: 1.4227796792984009
loss: 1.3097736835479736
loss: 1.189655065536499
loss: 0.7927678823471069
loss: 1.0725414752960205
loss: 1.12399423122406
loss: 0.9421235918998718
loss: 1.1214193105697632
loss: 0.5471965670585632
loss: 1.0168378353118896
loss: 1.0064048767089844
loss: 1.2622578144073486
loss: 1.4214293956756592
loss: 0.7325305938720703
loss: 0.5851391553878784
loss: 0.8879314661026001
loss: 1.295487642288208
loss: 1.0714359283447266
loss: 1.3861486911773682
loss: 1.672156572341919
loss: 0.8953787684440613
loss: 1.0753248929977417
loss: 0.8523876070976257
loss: 0.9133445024490356
loss: 0.7218830585479736
loss: 0.917022705078125
loss: 0.37262824177742004


[34m[1mwandb[0m: Adding directory to artifact (./output/checkpoint-660)... Done. 0.1s


loss: 1.1389015913009644
loss: 1.0010850429534912
loss: 0.979453444480896
loss: 1.327564001083374
loss: 1.3737492561340332
loss: 0.8309088945388794
loss: 0.9021388292312622
loss: 0.7555676698684692
loss: 1.505940556526184
loss: 0.7641984820365906
loss: 1.0194554328918457
loss: 1.0291211605072021
loss: 1.2309038639068604
loss: 1.161354422569275
loss: 1.2664234638214111
loss: 0.8899141550064087
loss: 0.8433625102043152
loss: 0.6637033820152283
loss: 1.5029016733169556
loss: 0.9202810525894165
loss: 1.27507746219635
loss: 1.0091817378997803
loss: 1.4120290279388428
loss: 1.2150981426239014
loss: 1.08895742893219
loss: 1.5368897914886475
loss: 0.8644159436225891
loss: 1.2992531061172485
loss: 1.2041574716567993
loss: 1.083563208580017


[34m[1mwandb[0m: Adding directory to artifact (./output/checkpoint-690)... Done. 0.1s


loss: 1.1855573654174805
loss: 1.6176490783691406
loss: 0.8770340085029602
loss: 1.1582508087158203
loss: 0.9835742115974426
loss: 0.8948556184768677
loss: 1.0096728801727295
loss: 0.8167999982833862
loss: 0.9757403135299683
loss: 1.2257893085479736
loss: 0.9090744853019714
loss: 1.309306025505066
loss: 0.9661598205566406
loss: 0.905725359916687
loss: 0.8954842686653137
loss: 0.923910915851593
loss: 0.8583892583847046
loss: 0.7124377489089966
loss: 0.8245041966438293
loss: 1.409900426864624
loss: 1.1877232789993286
loss: 1.0040465593338013
loss: 1.5494202375411987
loss: 1.4149481058120728
loss: 0.9336031079292297
loss: 1.0185092687606812
loss: 0.8238010406494141
loss: 1.0332125425338745
loss: 0.9576268196105957
loss: 1.9521137475967407


[34m[1mwandb[0m: Adding directory to artifact (./output/checkpoint-720)... Done. 0.1s


loss: 1.2844038009643555
loss: 0.9548861384391785
loss: 0.9944073557853699
loss: 0.9198211431503296
loss: 1.2257719039916992
loss: 0.9218772053718567
loss: 1.0933246612548828
loss: 0.6186025142669678
loss: 0.7070668935775757
loss: 1.0710359811782837
loss: 0.7640336751937866
loss: 1.0956863164901733
loss: 1.0939624309539795
loss: 0.8403739929199219
loss: 0.8747515678405762
loss: 1.2782937288284302
loss: 0.742195725440979
loss: 0.8704301118850708
loss: 0.8489144444465637
loss: 0.7293382883071899
loss: 0.452164888381958
loss: 1.038073182106018
loss: 1.0800076723098755
loss: 1.028207540512085
loss: 0.8782180547714233
loss: 0.9359514117240906
loss: 0.8947831988334656
loss: 1.1038949489593506
loss: 0.7281067371368408
loss: 1.2265386581420898


[34m[1mwandb[0m: Adding directory to artifact (./output/checkpoint-750)... Done. 0.1s


loss: 0.5572206377983093
loss: 0.7795447111129761
loss: 0.7992662191390991
loss: 1.462123155593872
loss: 1.112917184829712
loss: 1.6110022068023682
loss: 0.860112190246582
loss: 0.8586694002151489
loss: 1.0937672853469849
loss: 1.0215548276901245
loss: 1.006108045578003
loss: 1.0640918016433716
loss: 1.0566524267196655
loss: 1.1734707355499268
loss: 0.7359733581542969
loss: 0.9050506353378296
loss: 0.9638080596923828
loss: 1.162158727645874
loss: 0.7069923281669617
loss: 1.1629191637039185
loss: 0.9148150682449341
loss: 1.0401872396469116
loss: 0.9782432317733765
loss: 1.0496795177459717
loss: 0.7712960839271545
loss: 0.7896177172660828
loss: 0.46046990156173706
loss: 1.4157459735870361
loss: 1.3838685750961304
loss: 0.4573649764060974


[34m[1mwandb[0m: Adding directory to artifact (./output/checkpoint-780)... Done. 0.1s


loss: 1.5040498971939087
loss: 1.6830841302871704
loss: 1.2574043273925781
loss: 1.1339070796966553
loss: 1.6291887760162354
loss: 1.0272672176361084
loss: 1.1465948820114136
loss: 0.9603737592697144
loss: 0.9313439130783081
loss: 1.281110167503357
loss: 1.0244967937469482
loss: 1.0543431043624878
loss: 1.2704188823699951
loss: 1.0062493085861206
loss: 0.8259531259536743
loss: 1.023712396621704
loss: 0.796432375907898
loss: 1.0162265300750732
loss: 1.8898260593414307
loss: 0.4497625231742859
loss: 1.4662566184997559
loss: 0.874554455280304
loss: 1.1128747463226318
loss: 0.9755358695983887
loss: 1.6760133504867554
loss: 0.547770619392395
loss: 1.124381184577942
loss: 0.9802995920181274
loss: 0.7026059627532959
loss: 1.0261178016662598


[34m[1mwandb[0m: Adding directory to artifact (./output/checkpoint-810)... Done. 0.1s


loss: 1.0562458038330078
loss: 0.6800273656845093
loss: 0.9203351736068726
loss: 1.154648780822754
loss: 0.9011041522026062
loss: 0.8261197805404663
loss: 0.7609407305717468
loss: 0.996455192565918
loss: 0.9838223457336426
loss: 1.1728990077972412
loss: 1.8817850351333618
loss: 1.02897310256958
loss: 1.114203691482544
loss: 1.2700555324554443
loss: 1.3607136011123657
loss: 0.7964667081832886
loss: 0.7471634149551392
loss: 1.028788447380066
loss: 0.8187958002090454
loss: 1.156801700592041
loss: 0.9461399912834167
loss: 0.943000316619873
loss: 1.1519365310668945
loss: 0.9225489497184753
loss: 1.1790342330932617
loss: 1.0619986057281494
loss: 1.0097581148147583
loss: 0.8583358526229858
loss: 0.7923102378845215
loss: 1.448556661605835


[34m[1mwandb[0m: Adding directory to artifact (./output/checkpoint-840)... Done. 0.1s


loss: 0.7976875901222229
loss: 0.8082131147384644
loss: 1.3968759775161743
loss: 1.0763789415359497
loss: 0.8025921583175659
loss: 0.9928539991378784
loss: 0.6503741145133972
loss: 0.9815096855163574
loss: 1.1379200220108032
loss: 1.1427561044692993
loss: 0.877886176109314
loss: 0.8649905323982239
loss: 0.6715602278709412
loss: 0.7721848487854004
loss: 1.099636197090149
loss: 0.5823196172714233
loss: 0.9397844076156616
loss: 1.0513367652893066
loss: 0.6853055357933044
loss: 0.7290955185890198
loss: 0.9166972041130066
loss: 0.632881760597229
loss: 0.8324154615402222
loss: 1.9033197164535522
loss: 0.6669576168060303
loss: 1.1221449375152588
loss: 1.2976062297821045
loss: 0.5142797231674194
loss: 1.0759981870651245
loss: 0.8856493234634399


[34m[1mwandb[0m: Adding directory to artifact (./output/checkpoint-870)... Done. 0.1s


loss: 0.7735857963562012
loss: 1.7212311029434204
loss: 0.881303608417511
loss: 0.4596692621707916
loss: 1.0489166975021362
loss: 1.141600489616394
loss: 0.9605436325073242
loss: 1.0988366603851318
loss: 0.5161194801330566
loss: 0.25360107421875
loss: 0.8137141466140747
loss: 0.7046750783920288
loss: 0.3866201341152191
loss: 0.8508655428886414
loss: 0.9334490895271301
loss: 1.1724858283996582
loss: 0.5024094581604004
loss: 0.874512791633606
loss: 0.4473083019256592
loss: 1.371951699256897
loss: 1.066149353981018
loss: 1.105224847793579
loss: 1.2521488666534424
loss: 0.5673980116844177
loss: 1.2730928659439087
loss: 1.3637902736663818
loss: 1.6158803701400757
loss: 1.5437211990356445
loss: 1.2578872442245483
loss: 1.128680944442749


[34m[1mwandb[0m: Adding directory to artifact (./output/checkpoint-900)... Done. 0.1s


loss: 0.7063612341880798
loss: 1.374328851699829
loss: 1.076373815536499
loss: 0.9027762413024902
loss: 0.38892102241516113
loss: 0.8857256174087524
loss: 1.3693790435791016
loss: 0.8955772519111633
loss: 1.075865387916565
loss: 0.9252927899360657
loss: 0.2615390717983246
loss: 0.7410931587219238
loss: 0.7875353693962097
loss: 0.852269172668457
loss: 1.3718106746673584
loss: 1.4287455081939697
loss: 0.6574356555938721
loss: 1.4120427370071411
loss: 0.4901684820652008
loss: 0.8935451507568359
loss: 1.0536129474639893
loss: 0.6427475214004517
loss: 0.5678585171699524
loss: 0.6042290925979614
loss: 0.6905237436294556
loss: 1.0135974884033203
loss: 1.529565691947937
loss: 1.085613489151001
loss: 1.482613444328308
loss: 0.39598268270492554


[34m[1mwandb[0m: Adding directory to artifact (./output/checkpoint-930)... Done. 0.1s


loss: 0.8567681312561035
loss: 0.5602818727493286
loss: 1.259456992149353
loss: 0.3017203211784363
loss: 0.7938013672828674
loss: 0.7035360336303711
loss: 0.8280556797981262
loss: 1.7363885641098022
loss: 0.7760562896728516
loss: 0.8345940709114075
loss: 1.127901315689087
skipping batch item


ValueError: ignored

In [None]:
# Play: roll out one episode with the trained model and render it inline.
#
# NOTE(review): `prepare_observation_array` and `get_action` are not defined or
# imported in the visible cells - import them from the project before running
# this cell on a fresh kernel. `gym` and `model` come from earlier cells.
import matplotlib.pyplot as plt
import torch
from IPython.display import display as ipy_display, clear_output
#import gymnasium as gym

max_ep_len = 1000
device = 'cpu'
model = model.to(device)
# The Trainer leaves the model in training mode; switch to inference mode so
# layers such as dropout behave deterministically during the rollout.
model.eval()
scale = 1000.0  # normalization for rewards/returns
TARGET_RETURN = 900 / scale  # evaluation is conditioned on a raw return of 900, scaled accordingly

# build the environment (rgb_array so frames can be drawn with matplotlib)
env = gym.make('CarRacing-v2', render_mode='rgb_array', continuous=False)

state_dim = 96*96*3  # flattened observation size expected by the reshape calls below
act_dim = 1          # one discrete action index per step

# Rollout buffers: each grows by one entry per environment step.
episode_return, episode_length = 0, 0
state, _ = env.reset()
state = prepare_observation_array(state)
target_return = torch.tensor(TARGET_RETURN, device=device, dtype=torch.float32).reshape(1, 1)
states = torch.from_numpy(state).reshape(1, state_dim).to(device=device, dtype=torch.float32)
actions = torch.zeros((0, act_dim), device=device, dtype=torch.long)
rewards = torch.zeros(0, device=device, dtype=torch.float32)
timesteps = torch.tensor(0, device=device, dtype=torch.long).reshape(1, 1)
print_every = 10  # render one frame every N steps

for t in range(max_ep_len):
    # Append placeholders for the action/reward of the current step; they are
    # overwritten once the model has chosen and the environment has answered.
    actions = torch.cat([actions, torch.zeros((1, act_dim), dtype=torch.long, device=device)], dim=0)
    rewards = torch.cat([rewards, torch.zeros(1, device=device)])

    # No gradients are needed for inference.
    with torch.no_grad():
        action = get_action(
            model,
            states,
            actions,
            rewards,
            target_return,
            timesteps,
        )

    # Pick the index of the highest-scoring action.
    action = torch.argmax(action).item()
    actions[-1] = torch.tensor(action, dtype=torch.long)

    # Gymnasium returns separate `terminated` and `truncated` flags; the
    # episode is over when either one is set.
    state, reward, terminated, truncated, _ = env.step(action)
    done = terminated or truncated

    if (t + 1) % print_every == 0:
        image = env.render()
        clear_output(wait=True)
        # A fresh figure per frame, closed right after display, avoids
        # stacking images on one figure and leaking figures across steps.
        fig, ax = plt.subplots()
        ax.imshow(image)
        ax.axis('off')  # Hide the axis
        ipy_display(fig)
        plt.close(fig)

    state = prepare_observation_array(state)
    # Cast to float32 so the history keeps the dtype of the initial state.
    cur_state = torch.from_numpy(state).to(device=device, dtype=torch.float32).reshape(1, state_dim)
    states = torch.cat([states, cur_state], dim=0)
    rewards[-1] = reward

    # Return-to-go for the next step: subtract the (scaled) reward just received.
    pred_return = target_return[0, -1] - (reward / scale)
    target_return = torch.cat([target_return, pred_return.reshape(1, 1)], dim=1)
    timesteps = torch.cat([timesteps, torch.ones((1, 1), device=device, dtype=torch.long) * (t + 1)], dim=1)

    episode_return += reward
    episode_length += 1

    if done:
        break

env.close()
print(f"episode_return={episode_return:.1f} episode_length={episode_length}")