In [1]:
import glob
import os

import numpy as np
import torch

# NOTE: local import order is significant because of the two star imports
# (later imports shadow earlier ones); keep it as-is.
from sar_transformer.dataset import *
from config import EnvironmentConfig, TransformerModelConfig
from models.trajectory_transformer import AlgorithmDistillationTransformer
from generation import *
from sar_transformer.trainer import train

pygame 2.3.0 (SDL 2.24.2, Python 3.8.8)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [2]:
# Seed the global RNGs before anything stochastic happens (dataloader sampling,
# weight initialisation, training) so Restart & Run All reproduces the same run.
# 500 matches the seed passed to the environment constructor further down.
SEED = 500
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and all CUDA devices

dataset = HistoryDataset("histories")
# NOTE(review): the positional args look like (batch size, samples per epoch);
# the training log shows 256 batches per epoch, consistent with 256*128 / 128.
# Confirm against create_history_dataloader's signature.
train_loader = create_history_dataloader(dataset, 128, 256*128)

In [3]:
# Presumably the number of environment steps packed into one training sequence
# (episodes per sequence x steps per episode).
steps_per_seq = dataset.n_episodes_per_seq * dataset.episode_length
# NOTE(review): presumably 3 tokens per step (state, action, reward), minus 2
# trailing tokens that are prediction targets rather than context. TODO confirm
# against the model's token layout.
context_len = 3 * steps_per_seq - 2

# NOTE(review): the meaning of the positional args (12, 2, 12) is not visible
# here. 12 matches the state embedding's in_features in the model repr; confirm
# against DarkKeyDoor's signature.
env = DarkKeyDoor(12, 2, 12, seed=500)

In [4]:
# Choose the device once so both configs agree, and fall back to CPU when CUDA
# is unavailable instead of failing later at model construction or training.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

environment_config = EnvironmentConfig(
    env_id="Graph_DarkKeyDoor",
    env=env,
    device=DEVICE,
)

transformer_model_config = TransformerModelConfig(
    d_model=64,
    n_heads=4,
    d_mlp=256,
    n_layers=4,
    layer_norm=True,
    time_embedding_type="embedding",
    state_embedding_type="linear",
    # Context window sized from the dataset's sequence layout (context_len).
    n_ctx=context_len,
    device=DEVICE,
)

model = AlgorithmDistillationTransformer(environment_config, transformer_model_config)

In [5]:
# Display the module tree to sanity-check embedding sizes and the layer stack.
model

AlgorithmDistillationTransformer(
  (action_embedding): Sequential(
    (0): Embedding(3, 64)
  )
  (time_embedding): Embedding(13, 64)
  (state_embedding): Linear(in_features=12, out_features=64, bias=False)
  (transformer): HookedTransformer(
    (embed): Identity()
    (hook_embed): HookPoint()
    (pos_embed): PosEmbedTokens()
    (hook_pos_embed): HookPoint()
    (blocks): ModuleList(
      (0-3): 4 x TransformerBlock(
        (ln1): LayerNorm(
          (hook_scale): HookPoint()
          (hook_normalized): HookPoint()
        )
        (ln2): LayerNorm(
          (hook_scale): HookPoint()
          (hook_normalized): HookPoint()
        )
        (attn): Attention(
          (hook_k): HookPoint()
          (hook_q): HookPoint()
          (hook_v): HookPoint()
          (hook_z): HookPoint()
          (hook_attn_scores): HookPoint()
          (hook_pattern): HookPoint()
          (hook_result): HookPoint()
        )
        (mlp): MLP(
          (hook_pre): HookPoint()
          

In [6]:
# Train the model on the history dataloader. `train` is imported in the top
# import cell rather than here, so the notebook's dependencies live in one place.
# NOTE(review): the current `model` is passed in and rebound to the result, so
# re-running this cell presumably continues training from the current weights
# rather than starting from a fresh initialisation.
model = train(model, train_loader, environment_config)

Training AD, Epoch 1: 0.5911: 100%|██████████| 256/256 [00:36<00:00,  7.02it/s]
Training AD, Epoch 2: 0.5750: 100%|██████████| 256/256 [00:36<00:00,  6.97it/s]
Training AD, Epoch 3: 0.5477: 100%|██████████| 256/256 [00:38<00:00,  6.68it/s]
Training AD, Epoch 4: 0.5683: 100%|██████████| 256/256 [00:39<00:00,  6.56it/s]
Training AD, Epoch 5: 0.5312: 100%|██████████| 256/256 [00:39<00:00,  6.48it/s]
Evaluating AD, Reward: 0.0: 100%|██████████| 100/100 [00:12<00:00,  8.01it/s]             
Training AD, Epoch 6: 0.5062:  31%|███▏      | 80/256 [00:11<00:27,  6.49it/s]

In [None]:
# Persist only the learned parameters (the state dict), not the module object;
# reload by constructing the model and calling load_state_dict on the result.
torch.save(obj=model.state_dict(), f="checkpoint.pt")