In [31]:
import os; # os.environ['ACCELERATE_DISABLE_RICH'] = "1"
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

import torch
from torch import nn
from torch.nn import functional as F
from torch import nn, optim
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader
from transformer_lens import HookedTransformer, HookedTransformerConfig

from typing import Optional, Union, List, Tuple, Callable, Any

from dataclasses import dataclass, replace
import numpy as np
import einops

from tqdm.notebook import trange

import time
import pandas as pd
from functools import reduce

import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import importlib
import plotly_utils
import utils

importlib.reload(plotly_utils)
importlib.reload(utils)

from utils import *
from plotly_utils import imshow, line, hist, scatter

import matplotlib.pyplot as plt

# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

In [40]:
# Configuration for a small transformer: 2 layers, 2 attention heads of
# width 8, model width 16, a 10-token vocabulary and a 6-position context,
# with ReLU activations and normalization_type=None.
cfg = HookedTransformerConfig(
    n_layers=2,
    d_model=16,
    d_head=8,
    n_heads=2,
    d_mlp=4*16,  # 4 * d_model
    d_vocab=10,
    n_ctx=6,
    act_fn="relu",
    normalization_type=None,
)

# Seed torch's global RNG before the model is constructed, then build the
# model and move it to the selected device.
SEED=1
torch.manual_seed(SEED)
model = HookedTransformer(cfg).to(DEVICE)
# Explicitly (re-)enable autograd globally so that training can backpropagate.
torch.set_grad_enabled(True)

# NOTE(review): MaxDataset is not defined or explicitly imported in this
# notebook; presumably it is provided by `from utils import *` -- confirm.
dataset = MaxDataset(size=10000, config=cfg, device=DEVICE)
# Batches of 10 samples, reshuffled on every pass over the dataset.
dataloader = DataLoader(dataset, batch_size=10, shuffle=True)

Moving model to device:  cuda
tensor([[3, 2, 1, 3, 3, 0],
        [4, 1, 3, 1, 1, 0],
        [3, 2, 4, 1, 1, 0]], device='cuda:0')


In [None]:
# Adam over every parameter of `model`, with a learning rate of 1e-3.
optimizer = optim.Adam(model.parameters(), lr=0.001)

def train(model, dataloader, optimizer, epochs=1):
    """Train ``model`` to predict the largest token of each input sequence.

    For every batch the target is the maximum token id along the last axis,
    and the loss is the cross-entropy between that target and the logits the
    model produces at the final sequence position.

    Batches are moved to the device that holds the model's parameters rather
    than to a module-level global, so the function does not depend on notebook
    state and cannot send data to a device other than the model's.

    Args:
        model: ``torch.nn.Module`` that maps a batch of token ids to logits
            indexable as ``logits[:, -1]`` (one row of class scores per
            sequence at the last position).
        dataloader: Iterable yielding integer tensors of token ids with the
            sequence along the last axis.
        optimizer: Optimizer already bound to ``model``'s parameters.
        epochs: Number of full passes over ``dataloader``.

    Returns:
        None. ``model`` and ``optimizer`` are updated in place.
    """
    # Resolve the target device once; a parameter-free model leaves batches
    # wherever the dataloader produced them.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    for _ in range(epochs):
        for batch in dataloader:
            optimizer.zero_grad()
            if device is not None:
                batch = batch.to(device)
            # Target: the largest token id present in each sequence.
            labels = batch.max(dim=-1).values
            logits = model(batch)
            # Only the prediction at the final position is supervised.
            loss = F.cross_entropy(logits[:, -1], labels)
            loss.backward()
            optimizer.step()

# Run a single pass over the dataloader (epochs=1).
train(model, dataloader, optimizer, epochs=1)

In [48]:
# model(torch.tensor([[1, 2, 3, 3, 3, 0]]).to(DEVICE))[:, -1].argmax(dim=-1)
# List every named parameter tensor of the model with its shape, one per line.
for name, param in model.named_parameters():
    print(f"{name} {param.shape}")

embed.W_E torch.Size([10, 16])
pos_embed.W_pos torch.Size([6, 16])
blocks.0.attn.W_Q torch.Size([2, 16, 8])
blocks.0.attn.W_K torch.Size([2, 16, 8])
blocks.0.attn.W_V torch.Size([2, 16, 8])
blocks.0.attn.W_O torch.Size([2, 8, 16])
blocks.0.attn.b_Q torch.Size([2, 8])
blocks.0.attn.b_K torch.Size([2, 8])
blocks.0.attn.b_V torch.Size([2, 8])
blocks.0.attn.b_O torch.Size([16])
blocks.0.mlp.W_in torch.Size([16, 64])
blocks.0.mlp.b_in torch.Size([64])
blocks.0.mlp.W_out torch.Size([64, 16])
blocks.0.mlp.b_out torch.Size([16])
blocks.1.attn.W_Q torch.Size([2, 16, 8])
blocks.1.attn.W_K torch.Size([2, 16, 8])
blocks.1.attn.W_V torch.Size([2, 16, 8])
blocks.1.attn.W_O torch.Size([2, 8, 16])
blocks.1.attn.b_Q torch.Size([2, 8])
blocks.1.attn.b_K torch.Size([2, 8])
blocks.1.attn.b_V torch.Size([2, 8])
blocks.1.attn.b_O torch.Size([16])
blocks.1.mlp.W_in torch.Size([16, 64])
blocks.1.mlp.b_in torch.Size([64])
blocks.1.mlp.W_out torch.Size([64, 16])
blocks.1.mlp.b_out torch.Size([16])
unembed.W_U t