### V 17:54

### Install initial environment in Google Colab

In [1]:
import sys
import os

if 'google.colab' in sys.modules:
  if not os.path.exists('/content/.already_installed'):
    !git clone https://github.com/FlutterbaseDotCom/hdt
    !apt-get install -y swig
    !pip install -r hdt/requirements.txt
    with open('/content/.already_installed', 'w') as f:
        f.write('done')
  %cd /content/hdt

### Imports

In [2]:
import os
import random
from dataclasses import dataclass

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import wandb
from datasets import Dataset
from stable_baselines3 import DQN
from stable_baselines3.common.torch_layers import NatureCNN
from stable_baselines3.common.vec_env import DummyVecEnv, VecTransposeImage
from torch.utils.data import Subset
from transformers import Trainer, TrainingArguments

from dt.configuration_decision_transformer import DecisionTransformerConfig
from dt.modeling_decision_transformer import DecisionTransformerModel
from extract_cnn import prepare_observation_array
from dt.trainable_dt import DecisionTransformerGymDataCollator, TrainableDT

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Dataset names that have been used with this notebook:
#   car_racing_15_100
#   offline_car_racing_150_1000
from utils import storage

# Load the stored features for the selected dataset name.
features = storage.load_dataset('car_racing_15_100')

# Sanity check: report how many entries the "observations" field contains.
print(len(features["observations"]))

15


### Persist Dataset

### Set up the environment and W&B flag

In [4]:
from utils.config import CONFIG, WANDB_ID, WNDB_NAME
import os

# Set the WANDB_DISABLED flag to "false", i.e. do not ask for W&B to be disabled.
os.environ.update({"WANDB_DISABLED": "false"})

# CarRacing environment created with continuous=False.
# Optional extra argument for a live window: render_mode='human'
env = gym.make(
    'CarRacing-v2',
    continuous=False,
)


In [5]:
# Weights & Biases session setup for this notebook.
os.environ['WANDB_NOTEBOOK_NAME'] = 'DT_train.ipynb'
os.environ['WANDB_MODE'] = 'online'
os.environ["WANDB_LOG_MODEL"] = "checkpoint"

# SECURITY: a W&B API key used to be hardcoded in this call. Treat that key as
# leaked and revoke it in the W&B account settings. The key is now read from the
# WANDB_API_KEY environment variable; when it is unset or empty, key=None leaves
# credential lookup to wandb.login() itself (stored credentials or a prompt).
wandb.login(key=os.environ.get("WANDB_API_KEY") or None)

wandb.init(
    # NOTE(review): WANDB_ID is passed as `resume`. If it holds a run id, passing it
    # as `id=` together with resume="allow" may be what is intended — confirm
    # against utils.config before changing.
    resume=WANDB_ID,
    name=WNDB_NAME,
    mode="online",
    entity="yakiv",
    project="CarRacingDT",
    config=CONFIG,
)


[34m[1mwandb[0m: Currently logged in as: [33myakiv[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /Users/jacob/.netrc


### Train

In [6]:

from utils.dataset_wrappers import DummyDataset
from utils.dataset_wrappers import CarRacingFeatureDataset

# Wrap the loaded features and build the collator that is handed to the Trainer below.
feature_dataset = CarRacingFeatureDataset(src=features)
collator = DecisionTransformerGymDataCollator(
    feature_dataset,
    max_len=CONFIG["max_length"],
    max_ep_len=CONFIG["max_ep_len"],
)

# Model configuration: state/action dimensions come from the collator,
# the sequence limits come from CONFIG.
dt_config = DecisionTransformerConfig(
    state_dim=collator.state_dim,
    act_dim=collator.act_dim,
    max_length=CONFIG["max_length"],
    max_ep_len=CONFIG["max_ep_len"],
)
print(dt_config.to_dict())

model = TrainableDT(dt_config)

# TrainingArguments whose keyword name is identical to the CONFIG key they are read from.
_config_backed = (
    "save_steps",
    "num_train_epochs",
    "per_device_train_batch_size",
    "learning_rate",
    "weight_decay",
    "warmup_ratio",
    "max_grad_norm",
)
training_args = TrainingArguments(
    output_dir="output/",
    report_to="wandb",
    remove_unused_columns=False,
    optim="adamw_torch",
    **{key: CONFIG[key] for key in _config_backed},
    logging_steps=CONFIG["LOG_INTERVAL"],
)

trainer = Trainer(
    model=model,
    args=training_args,
    # Hack: the Trainer only receives a placeholder dataset sized like
    # feature_dataset; the real data is linked in through the collator.
    train_dataset=DummyDataset(len(feature_dataset)),
    data_collator=collator,
)

trainer.train()


ImportError: cannot import name 'CarRacingFeatureDataset' from 'utils.dataset_wrappers' (/Users/jacob/Documents/T/hdt/utils/dataset_wrappers.py)

In [None]:
#play
import matplotlib.pyplot as plt
from IPython.display import display as ipy_display, clear_output
#import gymnasium as gym
# build the environment
# Roll out a single evaluation episode with the trained model, rendering a frame
# every `print_every` steps.
max_ep_len = 1000  # upper bound on environment steps for this rollout
device = 'cpu'
model = model.to(device)
model.eval()  # evaluation only: put the module in inference mode (e.g. disables dropout)
scale = 1000.0  # normalization for rewards/returns
TARGET_RETURN = 900 / scale  # evaluation is conditioned on a return of 900, scaled accordingly

env = gym.make('CarRacing-v2', render_mode='rgb_array', continuous=False)

# NOTE(review): assumes prepare_observation_array returns 96*96*3 values per
# observation; the reshape calls below fail otherwise — confirm against extract_cnn.
state_dim = 96*96*3
act_dim = 1

episode_return, episode_length = 0, 0
print_every = 10  # render one frame every `print_every` steps

try:
    state, _ = env.reset()
    state = prepare_observation_array(state)

    # Rolling histories passed to get_action; each grows by one entry per step.
    target_return = torch.tensor(TARGET_RETURN, device=device, dtype=torch.float32).reshape(1, 1)
    states = torch.from_numpy(state).reshape(1, state_dim).to(device=device, dtype=torch.float32)
    actions = torch.zeros((0, act_dim), device=device, dtype=torch.long)
    rewards = torch.zeros(0, device=device, dtype=torch.float32)
    timesteps = torch.tensor(0, device=device, dtype=torch.long).reshape(1, 1)

    for t in range(max_ep_len):
        # Append placeholders for the step being decided; both are overwritten
        # once the action has been chosen and the reward observed.
        actions = torch.cat([actions, torch.zeros((1, act_dim), dtype=torch.long, device=device)], dim=0)
        rewards = torch.cat([rewards, torch.zeros(1, device=device)])

        # NOTE(review): get_action is not defined or imported in the cells shown
        # here; it must be made available before this cell is run.
        with torch.no_grad():  # no gradients are needed to pick an action
            action = get_action(
                model,
                states,
                actions,
                rewards,
                target_return,
                timesteps,
            )

        # The model output is reduced to a single discrete action index.
        action = torch.argmax(action).item()
        actions[-1] = action

        # env.step reports termination and truncation separately; either one ends the episode.
        state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated

        if (t + 1) % print_every == 0:
            image = env.render()
            clear_output(wait=True)
            # Draw on a fresh figure and close it afterwards so frames do not
            # pile up on one figure for the whole episode.
            fig, ax = plt.subplots()
            ax.imshow(image)
            ax.axis('off')  # Hide the axis
            ipy_display(fig)
            plt.close(fig)

        state = prepare_observation_array(state)
        # Same dtype as `states`, so the concatenation never mixes dtypes.
        cur_state = torch.from_numpy(state).to(device=device, dtype=torch.float32).reshape(1, state_dim)
        states = torch.cat([states, cur_state], dim=0)
        rewards[-1] = reward

        # Remaining target return after collecting this step's (scaled) reward.
        pred_return = target_return[0, -1] - (reward / scale)
        target_return = torch.cat([target_return, pred_return.reshape(1, 1)], dim=1)
        timesteps = torch.cat([timesteps, torch.ones((1, 1), device=device, dtype=torch.long) * (t + 1)], dim=1)

        episode_return += reward
        episode_length += 1

        if done:
            break
finally:
    # Release the environment even if the rollout raised.
    env.close()

print(f"episode_return={episode_return:.1f} episode_length={episode_length}")