## Setup

In [16]:
import json
import os
import sys
import time
import webbrowser
from functools import partial
from pathlib import Path
from typing import List, Tuple, Union, Optional, Callable, Dict

import circuitsvis as cv
import einops
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import torch as t
from IPython.display import display
from jaxtyping import Float, Int, Bool
from plotly.subplots import make_subplots
from torch import Tensor
from tqdm import tqdm
from transformer_lens import utils, ActivationCache, HookedTransformer, HookedTransformerConfig
from transformer_lens.hook_points import HookPoint
from transformer_lens.components import LayerNorm

from model import create_model
from training import train, TrainArgs
from dataset import MinDataset
from plotly_utils import hist, bar, imshow

# Run on the GPU whenever torch can see one; otherwise fall back to the CPU.
if t.cuda.is_available():
    device = t.device("cuda")
else:
    device = t.device("cpu")

# True only when this code is executed as the top-level script / kernel.
MAIN = (__name__ == "__main__")

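## Dataset

Each `MinDataset` sample is a list of numbers followed by its label, the list's minimum. The preview below uses longer lists (`length=7`, `max_num=100`) than the training run further down.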

In [17]:
dataset = MinDataset(size=5, max_num=100, length=7, seed=23)

# Each row is the input list followed by its label; index the sample once and split it.
sample = dataset[0]
sample_nums, sample_label = sample[:-1], sample[-1]

print("Raw List", sample_nums)
print("Sorted List", sample_nums.sort().values)
print("Sorted List Min", sample_nums.min(dim=-1).values)
print("Label", sample_label)

# Make the label invariant an explicit check rather than something to eyeball in the printout.
assert sample_label.item() == sample_nums.min().item(), "MinDataset label should equal the list minimum"

Raw List tensor([91, 46, 16, 97, 30, 12, 27])
Sorted List tensor([12, 16, 27, 30, 46, 91, 97])
Sorted List Min tensor(12)
Label tensor(12)


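## Transformer

Train a single-layer, single-head transformer on the min task (`seq_len=4`, `max_num=50`), then save its weights.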

In [18]:
import time

print("Starting training")
start = time.time()
args = TrainArgs(
    max_num=50,
    seq_len=4,
    trainset_size=int(60_000),
    valset_size=int(20_000),
    epochs=50,
    batch_size=2048,
    lr=3e-3,
    seed=23,
    d_model=24,
    d_head=24,
    n_layers=1,
    n_heads=1,
    d_mlp=None,
    normalization_type="LN",
    use_wandb=False,
    device=device,
)
model = train(args)

print(f"Training took {time.time() - start:.2f} seconds")

Starting training


Epoch 00, Train loss = 1.9920, Accuracy = 0.609, Val loss = 1.9301, Val Accuracy = 0.637: : 30it [00:01, 23.50it/s]
Epoch 01, Train loss = 0.4379, Accuracy = 0.967, Val loss = 0.4100, Val Accuracy = 0.971: : 30it [00:00, 34.72it/s]
Epoch 02, Train loss = 0.1341, Accuracy = 0.988, Val loss = 0.1445, Val Accuracy = 0.986: : 30it [00:00, 34.02it/s]
Epoch 03, Train loss = 0.0761, Accuracy = 0.990, Val loss = 0.0831, Val Accuracy = 0.989: : 30it [00:00, 34.17it/s]
Epoch 04, Train loss = 0.0544, Accuracy = 0.997, Val loss = 0.0587, Val Accuracy = 0.992: : 30it [00:00, 36.79it/s]
Epoch 05, Train loss = 0.0345, Accuracy = 0.998, Val loss = 0.0375, Val Accuracy = 0.996: : 30it [00:00, 33.76it/s]
Epoch 06, Train loss = 0.0384, Accuracy = 0.995, Val loss = 0.0270, Val Accuracy = 0.997: : 30it [00:00, 37.30it/s]
Epoch 07, Train loss = 0.0241, Accuracy = 0.997, Val loss = 0.0232, Val Accuracy = 0.997: : 30it [00:00, 32.67it/s]
Epoch 08, Train loss = 0.0255, Accuracy = 0.995, Val loss = 0.0189, Val 

Returning best model from epoch 30/50, with accuracy 1.000
Training took 45.14 seconds





In [19]:
# Disabled spot-check: run the trained model on one hand-written example.
# NOTE(review): this example has 2 elements, but the model above was trained with
# seq_len=4 — update the example to 4 numbers before re-enabling this cell.
# NOTE(review): argmax over dim=-1 of the raw logits presumably yields one prediction
# per sequence position; confirm which position carries the answer.
# example = t.tensor([45,39]).to(device)
# print("Example", example)
# print("Label", example.min())
# logits = model(example.unsqueeze(0))
# out = t.argmax(logits, dim=-1)
# print("Prediction", out)

In [20]:
# Save the model
filename =  "models/min_model3.pt"
t.save(model.state_dict(), filename)
