In [1]:
import glob
import os

import numpy as np
import torch

from sar_transformer.dataset import *
from config import EnvironmentConfig, TransformerModelConfig
from models.trajectory_transformer import AlgorithmDistillationTransformer
from generation import *
# Kept after the wildcard imports so this explicit binding of `train` is the one in scope.
from sar_transformer.trainer import train

pygame 2.3.0 (SDL 2.24.2, Python 3.8.8)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [2]:
# Build the dataset from the stored "histories" source.
dataset = HistoryDataset("histories")

# NOTE(review): the two positional arguments look like (batch size, samples per
# epoch) -- the training log shows 512 iterations per epoch -- confirm against
# create_history_dataloader's signature.
loader_batch_size = 128
batches_per_epoch = 512
train_loader = create_history_dataloader(
    dataset,
    loader_batch_size,
    batches_per_epoch * loader_batch_size,
)

In [3]:
# Total environment steps covered by one training sequence.
steps_per_seq = dataset.n_episodes_per_seq * dataset.episode_length

# Transformer context length in tokens.
# NOTE(review): the "* 3 - 2" presumably means three tokens per step (state,
# action, reward) with the last step contributing only its first token --
# TODO confirm against the dataset / model code.
context_len = steps_per_seq * 3 - 2

# NOTE(review): meaning of the positional arguments (12, 2, 12) is not visible
# here; verify against DarkKeyDoor's constructor.
env = DarkKeyDoor(12, 2, 12, seed=500)

In [4]:
# Use the GPU when one is available and fall back to the CPU otherwise, so the
# notebook still runs on machines without CUDA. Both configs share this single
# value so the environment config and the model config cannot drift apart.
device = "cuda" if torch.cuda.is_available() else "cpu"

environment_config = EnvironmentConfig(
    env_id="Graph_DarkKeyDoor",
    env=env,
    device=device,
)

# Model hyperparameters; n_ctx is the context length computed in the cell above.
transformer_model_config = TransformerModelConfig(
    d_model=64,
    n_heads=4,
    d_mlp=256,
    n_layers=4,
    layer_norm=True,
    time_embedding_type="embedding",
    state_embedding_type="linear",
    n_ctx=context_len,
    device=device,
)

model = AlgorithmDistillationTransformer(environment_config, transformer_model_config)

In [5]:
# Display the model's repr to inspect the module structure that was built.
model

AlgorithmDistillationTransformer(
  (action_embedding): Sequential(
    (0): Embedding(3, 64)
  )
  (time_embedding): Embedding(13, 64)
  (state_embedding): Linear(in_features=12, out_features=64, bias=False)
  (transformer): HookedTransformer(
    (embed): Identity()
    (hook_embed): HookPoint()
    (pos_embed): PosEmbedTokens()
    (hook_pos_embed): HookPoint()
    (blocks): ModuleList(
      (0-3): 4 x TransformerBlock(
        (ln1): LayerNorm(
          (hook_scale): HookPoint()
          (hook_normalized): HookPoint()
        )
        (ln2): LayerNorm(
          (hook_scale): HookPoint()
          (hook_normalized): HookPoint()
        )
        (attn): Attention(
          (hook_k): HookPoint()
          (hook_q): HookPoint()
          (hook_v): HookPoint()
          (hook_z): HookPoint()
          (hook_attn_scores): HookPoint()
          (hook_pattern): HookPoint()
          (hook_result): HookPoint()
        )
        (mlp): MLP(
          (hook_pre): HookPoint()
          

In [6]:
# Run training on the history batches; progress and the per-epoch loss are
# reported in the cell output. `train` comes from the notebook's imports cell.
train(model, train_loader, env)

Training AD, Epoch 1: 0.5623: 100%|██████████| 512/512 [01:17<00:00,  6.57it/s]
Training AD, Epoch 2: 0.5368: 100%|██████████| 512/512 [01:21<00:00,  6.26it/s]
Training AD, Epoch 3: 0.5267: 100%|██████████| 512/512 [01:21<00:00,  6.31it/s]
Training AD, Epoch 4: 0.5222: 100%|██████████| 512/512 [01:22<00:00,  6.17it/s]
Training AD, Epoch 5: 0.5574: 100%|██████████| 512/512 [01:23<00:00,  6.10it/s]
Training AD, Epoch 6: 0.5407: 100%|██████████| 512/512 [01:24<00:00,  6.09it/s]
Training AD, Epoch 7: 0.5246: 100%|██████████| 512/512 [01:24<00:00,  6.07it/s]
Training AD, Epoch 8: 0.5254: 100%|██████████| 512/512 [01:24<00:00,  6.05it/s]
Training AD, Epoch 9: 0.4671: 100%|██████████| 512/512 [01:28<00:00,  5.77it/s]
Training AD, Epoch 10: 0.4350: 100%|██████████| 512/512 [01:31<00:00,  5.61it/s]
Training AD, Epoch 11: 0.4706: 100%|██████████| 512/512 [01:27<00:00,  5.85it/s]
Training AD, Epoch 12: 0.4111: 100%|██████████| 512/512 [01:27<00:00,  5.87it/s]
Training AD, Epoch 13: 0.4002: 100%|█