# Loading BioGPT into TransformerLens

## Setup

In [1]:
import os

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops

from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm

import transformer_lens
import transformer_lens.utils as utils
from transformer_lens import (
    HookedTransformer,
    HookedTransformerConfig,
)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x76a7399c19c0>

## BioGPT

In [4]:
MODEL_PATH = "microsoft/biogpt"

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
hf_model = AutoModelForCausalLM.from_pretrained(MODEL_PATH)

config = hf_model.config
config

BioGptConfig {
  "_name_or_path": "microsoft/biogpt",
  "activation_dropout": 0.0,
  "architectures": [
    "BioGptForCausalLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 0,
  "eos_token_id": 2,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 1024,
  "initializer_range": 0.02,
  "intermediate_size": 4096,
  "layer_norm_eps": 1e-12,
  "layerdrop": 0.0,
  "max_position_embeddings": 1024,
  "model_type": "biogpt",
  "num_attention_heads": 16,
  "num_hidden_layers": 24,
  "pad_token_id": 1,
  "scale_embedding": true,
  "transformers_version": "4.45.2",
  "use_cache": true,
  "vocab_size": 42384
}

In [29]:
print(hf_model)

BioGptForCausalLM(
  (biogpt): BioGptModel(
    (embed_tokens): BioGptScaledWordEmbedding(42384, 1024, padding_idx=1)
    (embed_positions): BioGptLearnedPositionalEmbedding(1026, 1024)
    (layers): ModuleList(
      (0-23): 24 x BioGptDecoderLayer(
        (self_attn): BioGptSdpaAttention(
          (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
          (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
          (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
          (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
        )
        (activation_fn): GELUActivation()
        (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (fc1): Linear(in_features=1024, out_features=4096, bias=True)
        (fc2): Linear(in_features=4096, out_features=1024, bias=True)
        (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      )
    )
    (layer_norm): LayerNorm((

In [5]:
state_dict_path = "BioGPT_state_dict.pth"
torch.save(hf_model.state_dict(), state_dict_path)

## Load into TransformerLens

In [27]:
head_dim = config.hidden_size // config.num_attention_heads

hooked_config = HookedTransformerConfig(
    n_layers=24,
    d_model=1024,
    d_head=64,
    n_heads=16,
    d_mlp=4096,
    d_vocab=42384,
    n_ctx=1024,
    act_fn='gelu',
    normalization_type="LN"
)
model = HookedTransformer(hooked_config)

In [14]:
def biogpt_to_transformer_lens_format(in_sd, n_layers, n_heads):
    out_sd = {}
    out_sd["pos_embed.W_pos"] = in_sd[f"biogpt.embed_positions.weight"]
    out_sd["embed.W_E"] = in_sd[f"biogpt.embed_tokens.weight"]

    out_sd["ln_final.w"] = in_sd[f"biogpt.layer_norm.weight"]
    out_sd["ln_final.b"] = in_sd[f"biogpt.layer_norm.bias"]
    out_sd["unembed.W_U"] = in_sd[f"output_projection.weight"].T

    for layer in range(n_layers):
        out_sd[f"blocks.{layer}.ln1.w"] = in_sd[f"biogpt.layers.{layer}.fc1.weight"]
        out_sd[f"blocks.{layer}.ln1.b"] = in_sd[f"biogpt.layers.{layer}.fc1.bias"]
        out_sd[f"blocks.{layer}.ln2.w"] = in_sd[f"biogpt.layers.{layer}.fc2.weight"]
        out_sd[f"blocks.{layer}.ln2.b"] = in_sd[f"biogpt.layers.{layer}.fc2.bias"]


        out_sd[f"blocks.{layer}.attn.W_Q"] = einops.rearrange(
            in_sd[f"biogpt.layers.{layer}.self_attn.q_proj.weight"],
            "(n_heads d_head) d_model -> n_heads d_model d_head",
            n_heads=n_heads,
        )
        out_sd[f"blocks.{layer}.attn.b_Q"] = einops.rearrange(
            in_sd[f"biogpt.layers.{layer}.self_attn.q_proj.bias"],
            "(n_heads d_head) -> n_heads d_head",
            n_heads=n_heads,
        )
        out_sd[f"blocks.{layer}.attn.W_K"] = einops.rearrange(
            in_sd[f"biogpt.layers.{layer}.self_attn.k_proj.weight"],
            "(n_heads d_head) d_model -> n_heads d_model d_head",
            n_heads=n_heads,
        )
        out_sd[f"blocks.{layer}.attn.b_K"] = einops.rearrange(
            in_sd[f"biogpt.layers.{layer}.self_attn.k_proj.bias"],
            "(n_heads d_head) -> n_heads d_head",
            n_heads=n_heads,
        )
        out_sd[f"blocks.{layer}.attn.W_V"] = einops.rearrange(
            in_sd[f"biogpt.layers.{layer}.self_attn.v_proj.weight"],
            "(n_heads d_head) d_model -> n_heads d_model d_head",
            n_heads=n_heads,
        )
        out_sd[f"blocks.{layer}.attn.b_V"] = einops.rearrange(
            in_sd[f"biogpt.layers.{layer}.self_attn.v_proj.bias"],
            "(n_heads d_head) -> n_heads d_head",
            n_heads=n_heads,
        )
        out_sd[f"blocks.{layer}.attn.W_O"] = einops.rearrange(
            in_sd[f"biogpt.layers.{layer}.self_attn.out_proj.weight"],
            "(d_model n_heads) d_head -> n_heads d_model d_head",
            n_heads=n_heads,
        )
        out_sd[f"blocks.{layer}.attn.b_O"] = einops.rearrange(
            in_sd[f"biogpt.layers.{layer}.self_attn.out_proj.bias"],
            "(d_model n_heads) -> n_heads d_model",
            n_heads=n_heads,
        )

        out_sd[f"blocks.{layer}.mlp.b_in"] = in_sd[f"biogpt.layers.{layer}.fc1.bias"]
        out_sd[f"blocks.{layer}.mlp.W_in"] = in_sd[f"biogpt.layers.{layer}.fc1.weight"].T
        out_sd[f"blocks.{layer}.mlp.b_out"] = in_sd[f"biogpt.layers.{layer}.fc2.bias"]
        out_sd[f"blocks.{layer}.mlp.W_out"] = in_sd[f"biogpt.layers.{layer}.fc2.weight"].T

    return out_sd

In [15]:
state_dict = torch.load(state_dict_path, weights_only=False)

tl_dict = biogpt_to_transformer_lens_format(state_dict, config.num_hidden_layers, config.num_attention_heads)

In [16]:
for key, value in tl_dict.items():
    if key.startswith("blocks.0."):
        print(key, value.shape)

blocks.0.ln1.w torch.Size([4096, 1024])
blocks.0.ln1.b torch.Size([4096])
blocks.0.ln2.w torch.Size([1024, 4096])
blocks.0.ln2.b torch.Size([1024])
blocks.0.attn.W_Q torch.Size([16, 1024, 64])
blocks.0.attn.b_Q torch.Size([16, 64])
blocks.0.attn.W_K torch.Size([16, 1024, 64])
blocks.0.attn.b_K torch.Size([16, 64])
blocks.0.attn.W_V torch.Size([16, 1024, 64])
blocks.0.attn.b_V torch.Size([16, 64])
blocks.0.attn.W_O torch.Size([16, 64, 1024])
blocks.0.attn.b_O torch.Size([16, 64])
blocks.0.mlp.b_in torch.Size([4096])
blocks.0.mlp.W_in torch.Size([1024, 4096])
blocks.0.mlp.b_out torch.Size([1024])
blocks.0.mlp.W_out torch.Size([4096, 1024])


In [23]:
for name, param in state_dict.items():
    if name.startswith("biogpt.layers.0."):
        print(name, param.shape)

biogpt.layers.0.self_attn.k_proj.weight torch.Size([1024, 1024])
biogpt.layers.0.self_attn.k_proj.bias torch.Size([1024])
biogpt.layers.0.self_attn.v_proj.weight torch.Size([1024, 1024])
biogpt.layers.0.self_attn.v_proj.bias torch.Size([1024])
biogpt.layers.0.self_attn.q_proj.weight torch.Size([1024, 1024])
biogpt.layers.0.self_attn.q_proj.bias torch.Size([1024])
biogpt.layers.0.self_attn.out_proj.weight torch.Size([1024, 1024])
biogpt.layers.0.self_attn.out_proj.bias torch.Size([1024])
biogpt.layers.0.self_attn_layer_norm.weight torch.Size([1024])
biogpt.layers.0.self_attn_layer_norm.bias torch.Size([1024])
biogpt.layers.0.fc1.weight torch.Size([4096, 1024])
biogpt.layers.0.fc1.bias torch.Size([4096])
biogpt.layers.0.fc2.weight torch.Size([1024, 4096])
biogpt.layers.0.fc2.bias torch.Size([1024])
biogpt.layers.0.final_layer_norm.weight torch.Size([1024])
biogpt.layers.0.final_layer_norm.bias torch.Size([1024])


In [28]:
for name, param in model.named_parameters():
    if name.startswith("blocks.0."):
        print(name, param.shape)

blocks.0.ln1.w torch.Size([1024])
blocks.0.ln1.b torch.Size([1024])
blocks.0.ln2.w torch.Size([1024])
blocks.0.ln2.b torch.Size([1024])
blocks.0.attn.W_Q torch.Size([16, 1024, 64])
blocks.0.attn.W_O torch.Size([16, 64, 1024])
blocks.0.attn.b_Q torch.Size([16, 64])
blocks.0.attn.b_O torch.Size([1024])
blocks.0.attn.W_K torch.Size([16, 1024, 64])
blocks.0.attn.W_V torch.Size([16, 1024, 64])
blocks.0.attn.b_K torch.Size([16, 64])
blocks.0.attn.b_V torch.Size([16, 64])
blocks.0.mlp.W_in torch.Size([1024, 4096])
blocks.0.mlp.b_in torch.Size([4096])
blocks.0.mlp.W_out torch.Size([4096, 1024])
blocks.0.mlp.b_out torch.Size([1024])
