In [1]:
import numpy as np
import torch
import glob
import os

from sar_transformer.dataset import *
from config import EnvironmentConfig, TransformerModelConfig
from models.trajectory_transformer import AlgorithmDistillationTransformer
from generation import *

pygame 2.3.0 (SDL 2.24.2, Python 3.8.8)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [2]:
# Build the training data from the stored histories. Each sequence handed to
# the model spans `n_episodes_per_seq` consecutive episodes.
history_source = "test_histories"
dataset = HistoryDataset(history_source, n_episodes_per_seq=15)

# NOTE(review): the 2nd/3rd positional arguments look like a batch size and a
# per-epoch sample budget (512 batches of 64 matches the 512-step progress
# bars in the training log) -- confirm against create_history_dataloader.
loader_arg_2 = 64
loader_arg_3 = 512 * 64
train_loader = create_history_dataloader(dataset, loader_arg_2, loader_arg_3)

In [3]:
# Environment instance shared by the config, the model and the random baseline.
env = DarkRoom(12, 2, 12, seed=500)

# Transformer context size: three tokens per timestep across every episode in
# a sequence, minus two.
# NOTE(review): presumably the three tokens are (state, action, reward) and the
# "- 2" drops the trailing tokens of the last step -- confirm with the model.
timesteps_per_seq = dataset.n_episodes_per_seq * dataset.episode_length
context_len = 3 * timesteps_per_seq - 2

In [4]:
# Both configs are placed on the same device.
compute_device = "cuda"

# Environment-side configuration wrapping the `env` built earlier.
environment_config = EnvironmentConfig(
    env_id="Graph_DarkRoom", env=env, device=compute_device
)

# Transformer hyperparameters, gathered in one place so they are easy to tune.
transformer_hparams = dict(
    d_model=64,
    n_heads=4,
    d_mlp=2048,
    n_layers=4,
    layer_norm=True,
    time_embedding_type="embedding",
    state_embedding_type="linear",
    n_ctx=context_len,  # context window computed from the dataset
    device=compute_device,
)
transformer_model_config = TransformerModelConfig(**transformer_hparams)

# Instantiate the (untrained) algorithm-distillation model.
model = AlgorithmDistillationTransformer(
    environment_config,
    transformer_model_config,
)

In [5]:
# Bare expression: the notebook renders the model's repr (its module tree) as
# the cell output, as a quick check of the constructed architecture.
model

AlgorithmDistillationTransformer(
  (action_embedding): Sequential(
    (0): Embedding(3, 64)
  )
  (time_embedding): Embedding(13, 64)
  (state_embedding): Linear(in_features=12, out_features=64, bias=False)
  (transformer): HookedTransformer(
    (embed): Identity()
    (hook_embed): HookPoint()
    (pos_embed): PosEmbedTokens()
    (hook_pos_embed): HookPoint()
    (blocks): ModuleList(
      (0-3): 4 x TransformerBlock(
        (ln1): LayerNorm(
          (hook_scale): HookPoint()
          (hook_normalized): HookPoint()
        )
        (ln2): LayerNorm(
          (hook_scale): HookPoint()
          (hook_normalized): HookPoint()
        )
        (attn): Attention(
          (hook_k): HookPoint()
          (hook_q): HookPoint()
          (hook_v): HookPoint()
          (hook_z): HookPoint()
          (hook_attn_scores): HookPoint()
          (hook_pattern): HookPoint()
          (hook_result): HookPoint()
        )
        (mlp): MLP(
          (hook_pre): HookPoint()
          

In [6]:
from sar_transformer.trainer import train

# Run training on the history batches. Whatever `train` returns is rebound to
# `model`, so every later cell works with the post-training object.
# Caveat: re-running this cell continues from the already-trained `model`
# rather than from a fresh one.
learning_rate = 1e-3
model = train(
    model,
    train_loader,
    environment_config,
    lr=learning_rate,
)

TRAIN - Epoch 1: 0.5171: 100%|██████████| 512/512 [01:46<00:00,  4.83it/s]
TRAIN - Epoch 2: 0.5188: 100%|██████████| 512/512 [01:51<00:00,  4.61it/s]
TRAIN - Epoch 3: 0.4298: 100%|██████████| 512/512 [01:54<00:00,  4.47it/s]
TRAIN - Epoch 4: 0.2515: 100%|██████████| 512/512 [02:00<00:00,  4.24it/s]
TRAIN - Epoch 5: 0.2674: 100%|██████████| 512/512 [02:01<00:00,  4.22it/s]
EVAL - Random walk score: 0.8704, AD high score: 1.0000, AD final score: 0.2667: 100%|██████████| 100/100 [00:13<00:00,  7.62it/s]
TRAIN - Epoch 6: 0.1703: 100%|██████████| 512/512 [01:50<00:00,  4.64it/s]
TRAIN - Epoch 7: 0.1332: 100%|██████████| 512/512 [01:57<00:00,  4.35it/s]
TRAIN - Epoch 8: 0.1388: 100%|██████████| 512/512 [01:59<00:00,  4.28it/s]
TRAIN - Epoch 9: 0.1278: 100%|██████████| 512/512 [02:03<00:00,  4.16it/s]
TRAIN - Epoch 10: 0.1671: 100%|██████████| 512/512 [02:02<00:00,  4.18it/s]
EVAL - Random walk score: 0.8701, AD high score: 0.4000, AD final score: 0.0667: 100%|██████████| 100/100 [00:13<00:00

In [None]:
# Evaluate the trained agent and keep the returned rewards for plotting.
# NOTE(review): 499 matches the length of the evaluation progress bar, so it
# is presumably the number of evaluation iterations -- confirm whether those
# are environment steps or episodes.
eval_iterations = 499
sampling_temp = 0.1
rewards = evaluate_ad_agent(
    model,
    environment_config,
    eval_iterations,
    temp=sampling_temp,
)

Evaluating AD, Reward: 0.25: 100%|██████████| 499/499 [00:58<00:00,  8.48it/s]


In [None]:
import matplotlib
import matplotlib.pyplot as plt

# Smooth the evaluation rewards by averaging consecutive, non-overlapping
# windows of REWARD_WINDOW entries, then save the resulting curve to disk.
REWARD_WINDOW = 10

flat_rewards = np.array(rewards).reshape(-1)
# reshape(-1, REWARD_WINDOW) raises ValueError unless the number of entries is
# an exact multiple of the window (e.g. 499 entries would fail), so drop the
# incomplete trailing window first. When the length already divides evenly,
# this yields exactly the same values as the direct reshape.
n_complete = (flat_rewards.size // REWARD_WINDOW) * REWARD_WINDOW
out = flat_rewards[:n_complete].reshape(-1, REWARD_WINDOW).mean(axis=-1)

plt.cla()  # clear whatever a previous run of this cell left on the axes
plt.plot(out)
# Label the axes so the saved figure can be read on its own.
plt.xlabel(f"Window index ({REWARD_WINDOW} rewards per window)")
plt.ylabel("Mean reward")
plt.savefig("lmao.png")

In [None]:
# Persist only the model's state dict (not the whole module object) so it can
# be loaded later into a freshly constructed model.
checkpoint_path = "checkpoint.pt"
model_state = model.state_dict()
torch.save(model_state, checkpoint_path)

In [None]:
# Measure random agent success: play many episodes with actions drawn from
# `env.action_space.sample()` and report the mean total reward per episode.
# This is the baseline the trained agent's scores are compared against.
# Note: the action sampling is not seeded here, so the printed average can
# vary slightly between runs.
N_RANDOM_EPISODES = 10000
total_reward = []

for i in range(N_RANDOM_EPISODES):
    obs, _ = env.reset()
    episode_return = 0
    done = False
    while not done:
        action = env.action_space.sample()
        # The 5-tuple step API reports two separate end-of-episode flags.
        # The episode is over when EITHER is set; watching only `terminated`
        # would loop forever on an episode that ends by truncation.
        obs, reward, terminated, truncated, info = env.step(action)
        episode_return += reward
        done = terminated or truncated
    total_reward.append(episode_return)

avg_reward = sum(total_reward) / len(total_reward)
print(avg_reward)

0.885
