In [1]:

import glob
import os
import random

import numpy as np
import torch

from src.config import EnvironmentConfig, TransformerModelConfig
from src.models.trajectory_transformer import ConcatTransformer
from src.generation import *
from src.sar_transformer.trainer import train
from src.sar_transformer.dataset import HistoryDataset, create_history_dataloader

pygame 2.3.0 (SDL 2.24.2, Python 3.8.8)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [2]:
# Experiment configuration.
batch_size = 64
# Batches per training pass: the train dataloader draws batch_size * n_samples sequences.
n_samples = 256
# Selects the history directories data/train_<env_name> and data/test_<env_name>.
# Alternative: "simple_dark_room".
env_name = "dark_room"

# Seed every RNG the notebook may touch (Python, NumPy, PyTorch on all devices)
# so that data sampling, model initialisation and evaluation are reproducible
# across kernel restarts.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

In [3]:
# Training data: sequences presumably built from consecutive episodes of the
# histories stored under data/train_<env_name>.
# NOTE(review): the original comment on n_episodes_per_seq was truncated
# ("There should be ...") - confirm the intended constraint.
train_history_dir = f"data/train_{env_name}"
train_dataset = HistoryDataset(
    history_dir=train_history_dir,
    n_episodes_per_seq=10,  # the model's n_ctx is later derived from this value
)

# Draw n_samples batches of batch_size sequences per pass.
n_train_draws = n_samples * batch_size
train_dataloader = create_history_dataloader(
    dataset=train_dataset,
    batch_size=batch_size,
    n_samples=n_train_draws,
)

24000


In [4]:
# Size of the training dataset; fail fast if nothing was loaded from history_dir.
assert len(train_dataset) > 0, "training dataset is empty - check history_dir"
len(train_dataset)

24000


In [5]:
# Held-out data: same sequence construction as training, read from
# data/test_<env_name>.
# NOTE(review): the original comment on n_episodes_per_seq was truncated
# ("There should be ...") - confirm the intended constraint.
test_history_dir = f"data/test_{env_name}"
test_dataset = HistoryDataset(
    history_dir=test_history_dir,
    n_episodes_per_seq=10,  # should match the training value, since n_ctx is sized from it
)

# Evaluate on a quarter as many batches as training: (n_samples // 4) batches
# of batch_size sequences each.
n_test_draws = (n_samples // 4) * batch_size
test_dataloader = create_history_dataloader(
    dataset=test_dataset,
    batch_size=batch_size,
    n_samples=n_test_draws,
)

1000


In [6]:
# The transformer context is sized to hold one full training sequence:
# n_episodes_per_seq episodes of episode_length steps each.
context_len = train_dataset.n_episodes_per_seq * train_dataset.episode_length
print(context_len)

# Fall back to CPU so the notebook still runs on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# NOTE(review): the histories are loaded from data/*_<env_name> with
# env_name = "dark_room", but environment_config (used by train and
# evaluate_ad_agent below) wraps a SimpleDarkRoom registered as
# "Graph_DarkRoom" - confirm these describe the same environment.
# The positional arguments (12, 2, 24) are not documented here; check
# SimpleDarkRoom's signature.
env = SimpleDarkRoom(12, 2, 24, seed=50_000)

environment_config = EnvironmentConfig(
    env_id="Graph_DarkRoom",
    env=env,
    device=device,
)

transformer_model_config = TransformerModelConfig(
    d_model=128,
    n_heads=4,
    d_mlp=2048,
    n_layers=6,
    attn_only=False,
    layer_norm=True,
    time_embedding_type="embedding",
    state_embedding_type="linear",
    n_ctx=context_len,
    device=device,
)

model = ConcatTransformer(environment_config, transformer_model_config)

240


In [7]:
# Sanity check: pull a single training batch and run it through the model's
# tokenizer to confirm dtypes and shapes line up before training.
# next(iter(...)) raises StopIteration on an empty loader instead of silently
# skipping the check.
states, actions, rewards, timesteps = next(iter(train_dataloader))
print(states.dtype)

# Drop the last action/reward of each sequence (presumably the prediction
# targets) and keep index 0 of the trailing axis - TODO confirm the layout
# expected by to_tokens.
embeddings = model.to_tokens(states, actions[:, :-1, 0], rewards[:, :-1, 0], timesteps)
print(embeddings.shape)

torch.float64
torch.Size([64, 240, 15])


In [8]:
# Training hyperparameters, collected in one place so they are easy to tune.
# NOTE(review): the meaning of the eval_* arguments is defined by `train`
# in src.sar_transformer.trainer - confirm there.
train_kwargs = dict(
    lr=1e-3,
    eval_frequency=10,
    num_evals=8,
    eval_length=10,
    eval_temp=1.,
)

# NOTE(review): in the saved run this cell was stopped with KeyboardInterrupt,
# so `model` was never reassigned; the evaluation cell below then relies on
# `train` having updated `model` in place - confirm.
model = train(
    model,
    train_dataloader,
    test_dataloader,
    environment_config,
    **train_kwargs,
)

TRAIN - Epoch: 1, Loss: 0.5071, Acc: 73.6134%: 100%|██████████| 256/256 [00:32<00:00,  7.88it/s]
TEST  - Epoch: 1, Loss: 0.4078, Acc: 79.0316%: 100%|██████████| 64/64 [00:03<00:00, 20.92it/s]
TRAIN - Epoch: 2, Loss: 0.3846, Acc: 80.1572%: 100%|██████████| 256/256 [00:33<00:00,  7.53it/s]
TEST  - Epoch: 2, Loss: 0.3456, Acc: 82.4421%: 100%|██████████| 64/64 [00:03<00:00, 21.14it/s]
TRAIN - Epoch: 3, Loss: 0.3259, Acc: 83.4897%: 100%|██████████| 256/256 [00:35<00:00,  7.26it/s]
TEST  - Epoch: 3, Loss: 0.3015, Acc: 84.9323%: 100%|██████████| 64/64 [00:03<00:00, 20.59it/s]
TRAIN - Epoch: 4, Loss: 0.2915, Acc: 85.4412%: 100%|██████████| 256/256 [00:35<00:00,  7.14it/s]
TEST  - Epoch: 4, Loss: 0.2653, Acc: 86.8028%: 100%|██████████| 64/64 [00:03<00:00, 19.15it/s]
TRAIN - Epoch: 5, Loss: 0.2582, Acc: 87.1478%: 100%|██████████| 256/256 [00:34<00:00,  7.33it/s]
TEST  - Epoch: 5, Loss: 0.2399, Acc: 87.9702%: 100%|██████████| 64/64 [00:02<00:00, 21.61it/s]
EVAL  - Random walk score: 0.9987, AD hi

KeyboardInterrupt: 

In [9]:
# Repeat the AD-agent evaluation several times and average the per-run mean
# scores: single runs are noisy (compare the per-run "AD final score" values
# in the output).
# Import only the name used; a star import here would silently shadow names
# defined earlier in the notebook.
from src.sar_transformer.eval import evaluate_ad_agent

n_eval_runs = 10

means = []
for _ in range(n_eval_runs):
    out = evaluate_ad_agent(
        model,
        environment_config,
        n_episodes=10,
        temp=1.
    )
    # NOTE(review): assumes `out` is a sequence of per-episode scores - confirm.
    if len(out) == 0:
        raise ValueError("evaluate_ad_agent returned no episode scores")
    means.append(sum(out) / len(out))

print(sum(means) / len(means))

EVAL  - Random walk score: 2.0234, AD high score: 8.5000, AD final score: 5.1000: 100%|██████████| 10/10 [00:03<00:00,  2.96it/s]
EVAL  - Random walk score: 1.9987, AD high score: 10.5556, AD final score: 10.4000: 100%|██████████| 10/10 [00:03<00:00,  3.10it/s]
EVAL  - Random walk score: 2.0171, AD high score: 10.4000, AD final score: 10.4000: 100%|██████████| 10/10 [00:03<00:00,  3.12it/s]
EVAL  - Random walk score: 2.0148, AD high score: 9.2857, AD final score: 8.7000: 100%|██████████| 10/10 [00:03<00:00,  3.25it/s]
EVAL  - Random walk score: 0.9897, AD high score: 7.6000, AD final score: 7.6000: 100%|██████████| 10/10 [00:03<00:00,  3.17it/s]
EVAL  - Random walk score: 1.9741, AD high score: 10.4000, AD final score: 10.4000: 100%|██████████| 10/10 [00:02<00:00,  3.41it/s]
EVAL  - Random walk score: 2.0179, AD high score: 9.1111, AD final score: 8.2000: 100%|██████████| 10/10 [00:03<00:00,  3.32it/s]
EVAL  - Random walk score: 1.0054, AD high score: 7.8333, AD final score: 6.8000: 10

7.663636363636362



