# Setup

In [1]:
import torch.optim as optim
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau
import wandb
from utilities import *
from config import *
from dataloading import *
from tqdm import tqdm
from transformer import *
import os

Loading data...


In [2]:
print("Logging in...")
wandb.login()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


Logging in...


[34m[1mwandb[0m: Currently logged in as: [33mmidataur[0m ([33mknot-theory[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

In [3]:
# assumes you're using the transformer
# if you're using the MLP, you'll need to change the data pipeline and the final dimension
# also you can modify the transformer config in the transformer.py file

# setup the model
model = BigramLanguageModel()

# cuda? (gpu)
if torch.cuda.is_available():
  device = "cuda:0"
else:
  device = "cpu"

# send to gpu (maybe)
model = nn.DataParallel(model)
model = model.to(device)

# load the model
filename = PATH + "/model/" + MODELNAME + ".pth"
if os.path.isfile(filename):
    model.load_state_dict(torch.load(filename, map_location=torch.device(device)))
else:
   raise Exception("Model not found")

print(model)

DataParallel(
  (module): BigramLanguageModel(
    (token_embedding_table): Embedding(16, 384)
    (position_embedding): Embedding(6, 384)
    (sa_heads): MultiHeadAttention(
      (heads): ModuleList(
        (0-5): 6 x Head(
          (key): Linear(in_features=384, out_features=64, bias=False)
          (query): Linear(in_features=384, out_features=64, bias=False)
          (value): Linear(in_features=384, out_features=64, bias=False)
          (dropout): Dropout(p=0, inplace=False)
        )
      )
      (proj): Linear(in_features=384, out_features=384, bias=True)
      (dropout): Dropout(p=0, inplace=False)
    )
    (blocks): Sequential(
      (0): Block(
        (sa): MultiHeadAttention(
          (heads): ModuleList(
            (0-5): 6 x Head(
              (key): Linear(in_features=384, out_features=64, bias=False)
              (query): Linear(in_features=384, out_features=64, bias=False)
              (value): Linear(in_features=384, out_features=64, bias=False)
          

In [4]:
import plotly.io as pio
pio.renderers.default = "notebook_connected"
print(f"Using renderer: {pio.renderers.default}")

Using renderer: notebook_connected


In [5]:
import circuitsvis as cv
# Testing that the library works
cv.examples.hello("World")

In [6]:
import einops
from fancy_einsum import einsum
import tqdm.auto as tqdm
import plotly.express as px

from jaxtyping import Float
from functools import partial

In [14]:
# import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, FactoredMatrix, HookedTransformerConfig

In [8]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x2d38fa2d0>

In [15]:
def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.line(utils.to_numpy(tensor), labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)

In [54]:
from collections import OrderedDict
import re

# have to convert the state dict to the HookedTransformer format
def translate(key, value):
    direct_swaps = {
        "module.token_embedding_table.weight": "embed.W_E",
        "module.position_embedding.weight": "pos_embed.W_pos",
        "module.lm_head.weight": "unembed.W_U",
        "module.lm_head.bias": "unembed.b_U",
    }

    key = direct_swaps[key] if key in direct_swaps.keys() else key

    # need to reshape the unembedding matrix
    if key == "unembed.W_U":
        value = value.reshape(-1, 1)

    # attention heads
    matches = re.match(r"/module\.sa_heads\.heads\.(\d)\.(.+)", key)
    if matches:
        block = matches.group(1)
        subkey = matches.group(2)

        subkey_swaps = {
            "key.weight": "W_K",
            "query.weight": "W_Q",
            "value.weight": "W_V",
            
        }

    return (key, value)

state_dict = torch.load(filename, map_location=torch.device(device))
state_dict.keys()
state_dict = OrderedDict(translate(k, v) for k, v in state_dict.items())
state_dict.keys()

0
0
0
1
1
1
2
2
2
3
3
3
4
4
4
5
5
5


odict_keys(['embed.W_E', 'pos_embed.W_pos', 'module.sa_heads.heads.0.key.weight', 'module.sa_heads.heads.0.query.weight', 'module.sa_heads.heads.0.value.weight', 'module.sa_heads.heads.1.key.weight', 'module.sa_heads.heads.1.query.weight', 'module.sa_heads.heads.1.value.weight', 'module.sa_heads.heads.2.key.weight', 'module.sa_heads.heads.2.query.weight', 'module.sa_heads.heads.2.value.weight', 'module.sa_heads.heads.3.key.weight', 'module.sa_heads.heads.3.query.weight', 'module.sa_heads.heads.3.value.weight', 'module.sa_heads.heads.4.key.weight', 'module.sa_heads.heads.4.query.weight', 'module.sa_heads.heads.4.value.weight', 'module.sa_heads.heads.5.key.weight', 'module.sa_heads.heads.5.query.weight', 'module.sa_heads.heads.5.value.weight', 'module.sa_heads.proj.weight', 'module.sa_heads.proj.bias', 'module.blocks.0.sa.heads.0.key.weight', 'module.blocks.0.sa.heads.0.query.weight', 'module.blocks.0.sa.heads.0.value.weight', 'module.blocks.0.sa.heads.1.key.weight', 'module.blocks.0.sa.

In [47]:
state_dict["unembed.W_U"].shape

torch.Size([384, 1])

In [48]:
cfg = HookedTransformerConfig(
    n_layers=n_blocks,
    d_model=n_embed,
    d_head=n_embed // n_head,
    n_ctx=MAX_LENGTH,
    act_fn="relu",
    d_vocab=vocab_size,
    d_vocab_out=1,
)

hooked = HookedTransformer(cfg)

hooked.load_and_process_state_dict(state_dict)



wowza torch.Size([384, 1]) 



In [16]:
logits, cache = model.run_with_cache([12,4,1,8,7,0])

AttributeError: 'DataParallel' object has no attribute 'run_with_cache'