In [None]:
import numpy as np
from stable_baselines3.common.vec_env import DummyVecEnv
import gymnasium as gym
import random

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from codebase.Dataloader.SequenceExtractor import SequenceExtractor, collate_fn
from codebase.Models.DecisionTransformer import DecisionTransformers
from codebase.Models.Resnet import CustomResNet
from codebase.Models.PositionalEncoders import SinusoidalPositionalEncoding
from codebase.Models.Transformers import TransformerArchitecture
from codebase.ModelTester import ModelTester

import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint

In [None]:
def seed_worker(worker_id):
    """Seed Python's and NumPy's global RNGs from torch's current initial seed.

    Passed to the training DataLoader as ``worker_init_fn``. ``worker_id`` is
    accepted to match that callback signature but is not used.
    """
    # Both seeders receive the torch seed reduced modulo 2**32.
    seed32 = torch.initial_seed() % (1 << 32)
    random.seed(seed32)
    np.random.seed(seed32)

# Dedicated torch generator, seeded with 0, that is handed to the training
# DataLoader through its `generator=` argument.
g = torch.Generator()
g = g.manual_seed(0)  # manual_seed returns the generator itself

In [None]:
def checkAction(action):
    """Clamp a CarRacing action to its valid per-component range, in place.

    Components and ranges:
        action[0] -- steering, clamped to [-1, 1]
        action[1] -- gas,      clamped to [0, 1]
        action[2] -- brake,    clamped to [0, 1]

    Values already inside their range are left untouched. The same object that
    was passed in is modified and returned.

    Args:
        action: mutable, indexable sequence with at least three scalar entries
            (steering, gas, brake).

    Returns:
        The same ``action`` object with out-of-range entries replaced by the
        nearest bound.
    """
    # (lower, upper) bound for steering, gas and brake respectively.
    bounds = ((-1, 1), (0, 1), (0, 1))
    for idx, (low, high) in enumerate(bounds):
        if action[idx] > high:
            action[idx] = high
        elif action[idx] < low:
            # Steering below -1 must saturate at -1 (full left); clamping it to
            # +1 would silently flip the command to full right.
            action[idx] = low
    return action

In [None]:
# Sequence length handed to the dataset and the model tester; the transformer
# is sized for three times this value.
max_seq_len = 16

# Environment identity and rendering mode.
render_mode = "rgb_array"
env_id = "CarRacing-v2"

# Single environment wrapped in a DummyVecEnv; the lambda is the factory the
# wrapper uses to create it.
env = DummyVecEnv(
    [lambda: gym.make(env_id, render_mode=render_mode)]
)

In [None]:
# Training dataset built on top of the vectorised environment.
sequenceExtractorTrain = SequenceExtractor(
    env,
    seq_len=max_seq_len,
    dataset_len=16,
)

In [None]:
# Training DataLoader over the sequence dataset.
batch_size = 32
dataloader_train = DataLoader(
    sequenceExtractorTrain,
    batch_size=batch_size,
    collate_fn=collate_fn,
    num_workers=7,
    shuffle=True,
    # Seeded generator plus per-worker seeding hook.
    generator=g,
    worker_init_fn=seed_worker,
)

In [None]:
# TensorBoard logger rooted at "tb_logs", run name "DecisionTransformers".
logger = TensorBoardLogger("tb_logs", name="DecisionTransformers")

# Checkpoint callback: keeps the two best checkpoints ranked by the monitored
# "train_loss" value and writes them under "checkpoints/".
checkpoint_callback = ModelCheckpoint(
    monitor="train_loss",
    save_top_k=2,
    dirpath="checkpoints/",
    filename="{epoch}-{train_loss:.2f}",
)

In [None]:
# Width shared by the embeddings and the transformer.
d_model = 128

# Seed torch's global RNG before the model components are constructed.
torch.manual_seed(42)

# Space descriptions read from the vectorised environment.
observation_space = env.observation_space
action_space_dim = env.action_space.shape[0]

In [None]:
# Display the observation shape that is passed to the observation embedding.
# (No trailing comma: it would wrap the shape in an extra 1-tuple in the output.)
observation_space.shape

In [None]:
# Optimizer class (not an instance) handed to the model.
optimizer = torch.optim.Adam

# One embedding per input type, each configured with d_model as its output /
# feature size. Construction order is kept as-is because it consumes the
# seeded torch RNG.
embedding_reward = nn.Linear(in_features=1, out_features=d_model)
embedding_action = nn.Linear(in_features=action_space_dim, out_features=d_model)
embedding_observation = CustomResNet(observation_space.shape, features_dim=d_model)

# Transformer backbone. The positional-encoding class itself is passed in.
# max_step_len is three times the sequence length -- presumably one position
# each for reward, action and observation per step; confirm against
# TransformerArchitecture.
transformer = TransformerArchitecture(
    d_model=d_model,
    batch_first=True,
    max_step_len=3 * max_seq_len,
    positional_embedding=SinusoidalPositionalEncoding,
)

In [None]:
# Tester for the model; checkAction is supplied as the action-validation hook.
model_tester = ModelTester(
    env_name="CarRacing-v2",
    render_mode="rgb_array",
    seq_len=max_seq_len,
    actionCheck=checkAction,
)

In [None]:
# Assemble the Decision Transformer from the components defined above.
model = DecisionTransformers(
    d_model=d_model,
    action_space_dim=action_space_dim,
    embedding_observation=embedding_observation,
    embedding_action=embedding_action,
    embedding_reward=embedding_reward,
    transformer=transformer,
    optimizer=optimizer,
    modelTester=model_tester,
)

In [None]:
# Train for at most 300 epochs, logging to TensorBoard.
# NOTE(review): `checkpoint_callback` is NOT registered with this Trainer (the
# `callbacks=` argument is commented out), so it saves nothing during fit and
# the `checkpoint_callback.best_model_path` lookup at the end of the notebook
# will not point at a checkpoint from this run. Gradient accumulation is
# likewise disabled.
trainer = pl.Trainer(
    logger=logger,
    max_epochs=300,
    # callbacks=[checkpoint_callback],
    # accumulate_grad_batches=1024 // batch_size,
)
trainer.fit(model, dataloader_train)

In [None]:
# Persist the final training state of the trainer to a fixed checkpoint path.
trainer.save_checkpoint("checkpoints/DecisionTransformers-Last.ckpt")

In [None]:
# Show the path of the best checkpoint tracked by the checkpoint callback.
# NOTE(review): the Trainer is created with its `callbacks=` argument commented
# out, so this callback presumably never ran and the path may be empty -- confirm.
checkpoint_callback.best_model_path