In [1]:
import torch
import einops

import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

In [2]:
# Turn autograd off globally: the visible cells only run forward passes, so no
# gradients need to be tracked.
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x1fd404765f0>

In [3]:
# Switch read by the conversion cell further down: when True, the raw checkpoint
# is converted and loaded into a fresh HookedTransformer.
LOAD_AND_CONVERT_CHECKPOINT = True

In [88]:

# Read the raw training checkpoint from disk.
# NOTE(review): torch.load unpickles the file, so only point this at a
# checkpoint from a trusted source.
checkpoint = torch.load('models/ckpt_3487k_iters_pre_dropout.pt')

# Print the keys of the checkpoint dictionary
print(checkpoint.keys())
# The 'model' entry is the name -> tensor mapping that the conversion cell consumes.
model_state = checkpoint['model']
# Uncomment to list every parameter name with its shape:
# for key, value in model_state.items():
#     print(key, value.shape)


dict_keys(['model', 'optimizer', 'model_args', 'iter_num', 'best_val_loss', 'config'])


In [5]:
# checkpoint = torch.load('main_linear_probe.pth')
# print(checkpoint.shape)

In [6]:
# import transformer_lens.utils as utils
# cfg = HookedTransformerConfig(
#     n_layers = 16,
#     d_model = 512,
#     d_head = 64,
#     n_heads = 8,
#     d_mlp = 2048,
#     d_vocab = 61,
#     n_ctx = 59,
#     act_fn="gelu",
#     normalization_type="LNPre"
# )
# model = HookedTransformer(cfg)

In [7]:
# model.load_state_dict(model_state)

In [8]:
# model.load_state_dict(model_state)

In [101]:
def convert_to_transformer_lens_format(in_sd, n_layers=16, n_heads=8):
    """Re-key a `_orig_mod.`-prefixed GPT state dict into TransformerLens names.

    Args:
        in_sd: Mapping from parameter name to tensor. Must contain the
            `_orig_mod.transformer.*` and `_orig_mod.lm_head.weight` entries
            read below (embeddings, final layer norm, and per-layer `ln_1`,
            `ln_2`, `attn.c_attn`, `attn.c_proj`, `mlp.c_fc`, `mlp.c_proj`
            weights).
        n_layers: Number of transformer blocks to convert.
        n_heads: Number of attention heads; the fused attention weights are
            split into this many per-head matrices.

    Returns:
        A new dict keyed by TransformerLens parameter names (`embed.W_E`,
        `blocks.{l}.attn.W_Q`, ...). Layer-norm biases are emitted as zeros
        shaped like the corresponding layer-norm weight.

    Raises:
        KeyError: If an expected key is missing from `in_sd`.

    Note:
        Attention and MLP bias entries are not emitted.
        NOTE(review): presumably the source checkpoint was trained without
        biases and the loader fills in defaults for the missing keys - confirm.
    """
    out_sd = {}
    out_sd["pos_embed.W_pos"] = in_sd["_orig_mod.transformer.wpe.weight"]
    out_sd["embed.W_E"] = in_sd["_orig_mod.transformer.wte.weight"]

    out_sd["ln_final.w"] = in_sd["_orig_mod.transformer.ln_f.weight"]
    out_sd["ln_final.b"] = torch.zeros_like(in_sd["_orig_mod.transformer.ln_f.weight"])
    # lm_head stores (d_vocab, d_model); W_U is its transpose.
    out_sd["unembed.W_U"] = in_sd["_orig_mod.lm_head.weight"].T

    for layer in range(n_layers):
        layer_key = f"_orig_mod.transformer.h.{layer}"

        # Layer norms: only a weight is read from the input; the bias is set to zero.
        out_sd[f"blocks.{layer}.ln1.w"] = in_sd[f"{layer_key}.ln_1.weight"]
        out_sd[f"blocks.{layer}.ln1.b"] = torch.zeros_like(in_sd[f"{layer_key}.ln_1.weight"])
        out_sd[f"blocks.{layer}.ln2.w"] = in_sd[f"{layer_key}.ln_2.weight"]
        out_sd[f"blocks.{layer}.ln2.b"] = torch.zeros_like(in_sd[f"{layer_key}.ln_2.weight"])

        # c_attn stacks the Q, K and V projections along dim 0; split them, then
        # turn each (n_heads * d_head, d_model) matrix into (n_heads, d_model, d_head).
        # The head count comes from the `n_heads` argument, not from any global config.
        W = in_sd[f"{layer_key}.attn.c_attn.weight"]
        W_Q, W_K, W_V = torch.tensor_split(W, 3, dim=0)
        W_Q = einops.rearrange(W_Q, "(i h) m->i m h", i=n_heads)
        W_K = einops.rearrange(W_K, "(i h) m->i m h", i=n_heads)
        W_V = einops.rearrange(W_V, "(i h) m->i m h", i=n_heads)
        out_sd[f"blocks.{layer}.attn.W_Q"] = W_Q
        out_sd[f"blocks.{layer}.attn.W_K"] = W_K
        out_sd[f"blocks.{layer}.attn.W_V"] = W_V

        # Output projection: (d_model, n_heads * d_head) -> (n_heads, d_head, d_model).
        W_O = in_sd[f"{layer_key}.attn.c_proj.weight"]
        W_O = einops.rearrange(W_O, "m (i h)->i h m", i=n_heads)
        out_sd[f"blocks.{layer}.attn.W_O"] = W_O

        # MLP weights are stored (out_features, in_features); transpose both.
        out_sd[f"blocks.{layer}.mlp.W_in"] = in_sd[f"{layer_key}.mlp.c_fc.weight"].T
        out_sd[f"blocks.{layer}.mlp.W_out"] = in_sd[f"{layer_key}.mlp.c_proj.weight"].T

    return out_sd

if LOAD_AND_CONVERT_CHECKPOINT:

    synthetic_checkpoint = model_state
    # Show the non-layer parameters plus layer 0 only; the other layers repeat
    # the same names under a different index.
    for name, param in synthetic_checkpoint.items():
        if name.startswith("_orig_mod.transformer.h.0") or not name.startswith("_orig_mod.transformer.h"):
            print(name, param.shape)

    n_heads = 8
    n_layers = 16

    # Architecture of the checkpoint; d_vocab and n_ctx match the leading
    # dimensions of the wte and wpe weights printed above.
    cfg = HookedTransformerConfig(
        n_layers = n_layers,
        d_model = 512,
        d_head = 64,
        n_heads = n_heads,
        d_mlp = 2048,
        d_vocab = 32,
        n_ctx = 1023,
        act_fn="gelu",
        normalization_type="LNPre"
    )
    model = HookedTransformer(cfg)


    model.load_and_process_state_dict(convert_to_transformer_lens_format(synthetic_checkpoint, n_layers=n_layers, n_heads=n_heads))

# Sanity check: an example input
sample_input = torch.tensor([[15, 6, 4, 27, 9, 0, 25, 10, 0, 7, 4, 19]])
# The argmax of the output (ie the most likely next move from each position)
model_output = model(sample_input).argmax(dim=-1)
# Expected argmax for the input above. Place it on whatever device the model
# produced its output on instead of assuming CUDA is available.
sample_output = torch.tensor([[ 6,  4, 27,  9,  0, 27, 10,  0,  7,  4, 19, 28]]).to(model_output.device)
print(model_output)
print(sample_output == model_output)

_orig_mod.transformer.wte.weight torch.Size([32, 512])
_orig_mod.transformer.wpe.weight torch.Size([1023, 512])
_orig_mod.transformer.h.0.ln_1.weight torch.Size([512])
_orig_mod.transformer.h.0.attn.c_attn.weight torch.Size([1536, 512])
_orig_mod.transformer.h.0.attn.c_proj.weight torch.Size([512, 512])
_orig_mod.transformer.h.0.ln_2.weight torch.Size([512])
_orig_mod.transformer.h.0.mlp.c_fc.weight torch.Size([2048, 512])
_orig_mod.transformer.h.0.mlp.c_proj.weight torch.Size([512, 2048])
_orig_mod.transformer.ln_f.weight torch.Size([512])
_orig_mod.lm_head.weight torch.Size([32, 512])
torch.Size([512, 512])
torch.Size([512, 512])
torch.Size([512, 512])
torch.Size([8, 512, 64])
torch.Size([8, 512, 64])
torch.Size([8, 512, 64])
torch.Size([512, 512])
torch.Size([8, 64, 512])
tensor([[ 6,  4, 27,  9,  0, 27, 10,  0,  7,  4, 19, 28]], device='cuda:0')
tensor([[True, True, True, True, True, True, True, True, True, True, True, True]],
       device='cuda:0')


In [102]:
# A short two-token example input
sample_input = torch.tensor([[15, 6]])
# sample_input = torch.tensor([[15, 6, 4, 27, 9]])
# NOTE(review): sample_output is not used by any visible cell after this point,
# and it holds ids up to 59 while the model config uses d_vocab = 32 - it looks
# like a leftover from a different model; confirm before relying on it.
sample_output = torch.tensor([[21, 41, 40, 34, 40, 41,  3, 11, 21, 43, 40, 21, 28, 50, 33, 50, 33,  5,
         33,  5, 52, 46, 14, 46, 14, 47, 38, 57, 36, 50, 38, 15, 28, 26, 28, 59,
         50, 28, 14, 28, 28, 28, 28, 45, 28, 35, 15, 14, 30, 59, 49, 59, 15, 15,
         14, 15,  8,  7,  8]])
# The argmax of the output (ie the most likely next token from each position)
print(model(sample_input).argmax(dim=-1))
# print(model.forward(sample_input))

tensor([[6, 4]], device='cuda:0')


In [103]:
# Write the model's state dict to disk; a later cell reloads this file into a
# freshly constructed HookedTransformer.
torch.save(model.state_dict(), 'tf_lens_16.pth')

In [105]:
# Rebuild the architecture from scratch and reload the saved weights, to check
# that the saved file round-trips.
n_heads = 8
n_layers = 16

cfg = HookedTransformerConfig(
    n_layers = n_layers,
    d_model = 512,
    d_head = 64,
    n_heads = n_heads,
    d_mlp = 2048,
    d_vocab = 32,
    n_ctx = 1023,
    act_fn="gelu",
    normalization_type="LNPre"
)
model2 = HookedTransformer(cfg)

# The file may have been written from a CUDA model; deserialize onto CPU so the
# load also works where CUDA is unavailable. load_state_dict then copies the
# values into model2's parameters on whatever device they live on.
model2.load_state_dict(torch.load('tf_lens_16.pth', map_location="cpu"))

# Should reproduce the prediction the original model printed for sample_input.
model_output = model2(sample_input).argmax(dim=-1)
print(model_output)


tensor([[6, 4]], device='cuda:0')
