Make sure you generate the training and test history datasets first by running:

python -m src.run_history_gen --env_id ArmedBandit --n_steps 100 --n_actions 8 --n_states 1 --max_env_len 1 --path bandit_hists_train --n_seeds 40000
python -m src.run_history_gen --env_id ArmedBandit --n_steps 100 --n_actions 8 --n_states 1 --max_env_len 1 --path bandit_hists_test --n_seeds 500 --seed_start 40000

In [None]:

import numpy as np
import torch
import glob
import os

from src.config import EnvironmentConfig, TransformerModelConfig
from src.models.trajectory_transformer import AlgorithmDistillationTransformer
from src.generation import *
from src.sar_transformer.trainer import train
from src.sar_transformer.dataset import HistoryDataset, create_history_dataloader

In [None]:
# The histories were generated with --n_steps 100 and --max_env_len 1 (see the
# commands above), i.e. presumably 100 single-step episodes per seed file.
# n_episodes_per_seq is assumed to match that -- revisit it if the generation
# settings change.
train_dataset = HistoryDataset(
    history_dir="bandit_hists_train",
    n_episodes_per_seq=100,
)

# n_samples = batch_size * 512, i.e. 512 batches' worth of samples
# (assuming n_samples counts individual sequences -- confirm in the loader).
train_dataloader = create_history_dataloader(
    dataset=train_dataset,
    batch_size=64,
    n_samples=64 * 512,
)

In [None]:
# Reuse the training value instead of repeating the literal: the model's
# context length (n_ctx) is computed from train_dataset alone, so test
# sequences must contain the same number of episodes.
test_dataset = HistoryDataset(
    history_dir="bandit_hists_test",
    n_episodes_per_seq=train_dataset.n_episodes_per_seq,
)

# n_samples = batch_size * 256, half the training loader's sample count
# (assuming n_samples counts individual sequences -- confirm in the loader).
test_dataloader = create_history_dataloader(
    dataset=test_dataset,
    batch_size=64,
    n_samples=64 * 256,
)

In [None]:
# Context length in tokens. The factor 3 presumably counts one token each for
# state, action and reward per step (this is the SAR transformer), and the
# "- 2" presumably drops the final step's action and reward -- confirm against
# the model's tokenisation before changing either constant.
context_len = train_dataset.n_episodes_per_seq * train_dataset.episode_length * 3 - 2
print(context_len)

# One device setting shared by both configs below (they must agree); fall back
# to CPU so the notebook is not hard-wired to a GPU machine.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Evaluation environment: 8 arms, matching --n_actions 8 in the generation
# commands. Its seed (50_000) appears to lie outside the seed ranges used for
# the train/test histories.
env = MultiArmedBandit(8, seed=50_000)

# NOTE(review): env_id is "Graph_ArmedBandit" here while the histories were
# generated with --env_id ArmedBandit -- confirm this difference is intended.
environment_config = EnvironmentConfig(
    env_id="Graph_ArmedBandit",
    env=env,
    device=device,
)

# Single-layer transformer (MLP enabled, layer norm on); n_ctx is set to the
# context length computed above.
transformer_model_config = TransformerModelConfig(
    d_model=128,
    n_heads=8,
    d_mlp=128 * 4,
    n_layers=1,
    attn_only=False,
    layer_norm=True,
    time_embedding_type="embedding",
    state_embedding_type="linear",
    n_ctx=context_len,
    device=device,
)

model = AlgorithmDistillationTransformer(environment_config, transformer_model_config)

In [None]:
# Display the model's module structure: a bare expression on the last line of
# a notebook cell is rendered via its repr.
model

In [None]:
# Run training with a learning rate of 4e-3. test_dataloader is passed alongside
# the training loader (presumably for evaluation), and `model` is rebound to
# whatever train() returns.
model = train(
    model,
    train_dataloader,
    test_dataloader,
    environment_config,
    lr=4e-3,
)