In [2]:
import numpy as np
from stable_baselines3.common.vec_env import DummyVecEnv
import gymnasium as gym
import random

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from Dataloader.SequenceExtractor import SequenceExtractor, collate_fn
from Models.DecisionTransformer import DecisionTransformers

import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint

In [3]:
def seed_worker(worker_id):
    """DataLoader ``worker_init_fn`` that seeds NumPy and Python's ``random``.

    Inside a worker process ``torch.initial_seed()`` returns that worker's
    torch seed; it is folded into 32 bits (the range ``np.random.seed``
    accepts) and reused for both libraries. ``worker_id`` is not used.
    """
    derived_seed = torch.initial_seed() % (1 << 32)
    random.seed(derived_seed)
    np.random.seed(derived_seed)

# Fixed-seed generator handed to the DataLoader(s) for reproducible shuffling.
g = torch.Generator()
g.manual_seed(0)

In [4]:
# A single CarRacing environment (render_mode "rgb_array") wrapped in a
# one-element DummyVecEnv; the vectorised env is what the datasets and the
# model below receive.
env_id = 'CarRacing-v2'
render_mode = "rgb_array"


def make_single_env():
    """Factory used by DummyVecEnv to build the wrapped environment."""
    return gym.make(env_id, render_mode=render_mode)


env = DummyVecEnv([make_single_env])

In [5]:
# Training and validation sequence datasets built from the same env. The
# validation extractor is given starting_num = len(train set), presumably so
# its sample indices continue after the training ones (verify in
# SequenceExtractor).
train_dataset_len = 32768
val_dataset_len = 2048

sequenceExtractorTrain = SequenceExtractor(env, dataset_len=train_dataset_len)
sequenceExtractorVal = SequenceExtractor(
    env,
    dataset_len=val_dataset_len,
    starting_num=len(sequenceExtractorTrain),
)

In [6]:
# Creating the training and validation data loaders.
batch_size = 64
num_workers = 7


def make_dataloader(dataset, shuffle, generator):
    """Build a DataLoader with the settings shared by train and val.

    Args:
        dataset: the sequence dataset to wrap.
        shuffle: whether to reshuffle every epoch (training only).
        generator: seeded ``torch.Generator`` controlling the shuffle order
            and the base seed handed to the worker processes.
    """
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        collate_fn=collate_fn,
        worker_init_fn=seed_worker,
        generator=generator,
    )


dataloader_train = make_dataloader(sequenceExtractorTrain, shuffle=True, generator=g)
# The val loader gets its own generator. Each DataLoader iteration draws its
# worker base seed from its generator, so sharing `g` would let every
# validation pass shift the training shuffle order and worker seeds.
dataloader_val = make_dataloader(
    sequenceExtractorVal, shuffle=False, generator=torch.Generator().manual_seed(1)
)

In [7]:
# TensorBoard logs are written under tb_logs/ with experiment name "DecisionTransformers".
logger = TensorBoardLogger("tb_logs", name="DecisionTransformers")

# Keep the 2 checkpoints with the lowest train_loss (ModelCheckpoint's default
# mode is "min"); val_loss only appears in the checkpoint filename.
# NOTE(review): ranking by train_loss favours the most over-fitted weights. If
# the model logs val_loss, monitor="val_loss" is usually the better criterion.
checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints/",
    filename='{epoch}-{train_loss:.2f}-{val_loss:.2f}',
    monitor="train_loss",
    save_top_k=2,
)

In [None]:
# Seed torch (same as torch.manual_seed(42), so weight initialisation is
# unchanged) and also numpy and Python's `random`, which torch.manual_seed
# alone leaves unseeded.
pl.seed_everything(42)

model = DecisionTransformers(
    d_model=256,
    action_space_dim=env.action_space.shape[0],
    observation_space=env.observation_space,
    max_seq_len=sequenceExtractorTrain.seq_len,  # sequence length used by the training dataset
)
trainer = pl.Trainer(max_epochs=200, logger=logger, callbacks=[checkpoint_callback])
trainer.fit(model, dataloader_train, dataloader_val)

Number of learnable parameters for the CNN: 4710656
Number of learnable parameters for the entire architecture: 11163395


In [None]:
# Save the trainer/model state as it is when this cell runs (after fit), in
# addition to the top-k checkpoints kept by the ModelCheckpoint callback.
final_checkpoint_path = "checkpoints/DecisionTransformers-Overfitting.ckpt"
trainer.save_checkpoint(final_checkpoint_path)

In [None]:
# Path of the best checkpoint according to the callback's monitored metric.
best_checkpoint_path = checkpoint_callback.best_model_path
best_checkpoint_path