In [1]:
# ! pip install transformers
# ! wget https://raw.githubusercontent.com/callummcdougall/arena-v1/main/w2d2/utils.py
import torch
from torch import nn
from torch.nn import GELU, Softmax
from dataclasses import dataclass
import transformers
import utils
import matplotlib

device = "cuda" if torch.cuda.is_available() else "cpu"


In [2]:
@dataclass(frozen=True)
class TransformerConfig:
    '''Constants used throughout your decoder-only transformer model.'''

    num_layers: int = 12
    # head_size is not in this config, because in our implementation we're assuming num_heads * head_size = hidden_size
    num_heads: int = 12
    vocab_size: int = 50_257
    # hidden_size is also referred to as embedding_dim, or d_\text{model}d model in some material you might have read.
    hidden_size: int = 768
    # max_seq_len is used just to determine the size of the positional encoding matrix.
    max_seq_len: int = 1024
    dropout: float = 0.1
    layer_norm_epsilon: float = 1e-05
    device: str = "cpu"




## define gpt2 model

In [3]:
class GPT2MLP(nn.Module):
    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.dropout = config.dropout
        self.mlp_block = nn.Sequential(
            nn.Linear(self.hidden_size, 4*self.hidden_size),
            GELU(),
            nn.Linear(4*self.hidden_size, self.hidden_size),
            nn.Dropout(self.dropout)
        )
    def forward(self, x: torch.Tensor):
        return self.mlp_block(x)

Q = torch.ones((2,20,4*64))
K = torch.ones((2,10,4*64))
V = torch.ones((2,10,4*64))
num_heads = 4




class GPT2Attention(nn.Module):
    """
    head_size is not in this config, because in our implementation we're assuming num_heads * head_size = hidden_size.
    hidden_size is also referred to as embedding_dim, or d_\text{model}d 
    model in some material you might have read.

    I ignored this for now as it would require changing the masked attention function
    The attention block has two dropout layers: 
    one immediately after the softmax (i.e. before multiplying by V), 
    and one immediately after multiplying with W_O at the very end of the attention block. 
    Note that the dropout layers won't actually affect weight-loading or performance in eval mode 
    (and you should still be able to train your model without them), 
    but all the same it's nice to be able to exactly match GPT's architecture!
    """
    W_QKV: nn.Linear
    W_O: nn.Linear


    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.num_heads = config.num_heads
        self.hidden_size = config.hidden_size
        self.device = config.device
        self.head_size = self.hidden_size // self.num_heads
        self.W_QKV = nn.Linear(self.hidden_size, self.num_heads*self.head_size*3)
        self.dropout1 = nn.Dropout(config.dropout)
        self.W_O = nn.Linear(self.num_heads*self.head_size, self.hidden_size)
        self.dropout2 = nn.Dropout(config.dropout)
        self.softmax = Softmax(dim=3)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        '''
        x: shape (batch, seq, hidden_size)

        Return: shape (batch, seq, hidden_size)
        '''
        # x = x.repeat((1,1,3)) # repeat trice along dim 2
        x = self.W_QKV(x)
        #print(f"{x.shape=} {num_heads=} {self.hidden_size=}")
        Q, K, V = torch.split(x, self.num_heads*self.head_size, 2)
        #print(f"{Q.shape=} {K.shape=} {V.shape=}")
        
        # Z = multihead_masked_attention(Q, K, V, num_heads=self.num_heads, device=self.device)
        batch, target_seq_len = Q.shape[0:2]
        source_seq_len = K.shape[1] 
        head_size = int(Q.shape[-1]/self.num_heads)
        sqrt_d_k = torch.sqrt(torch.tensor(self.head_size))
        # new_shape = (batch, target_seq_len, num_heads, head_size)
        Q = torch.reshape(Q, (batch, target_seq_len, self.num_heads, self.head_size))
        K = torch.reshape(K, (batch, source_seq_len, self.num_heads, self.head_size))
        V = torch.reshape(V, (batch, source_seq_len, self.num_heads, self.head_size))
        # generate mask
        triangular = torch.triu(torch.ones((target_seq_len, source_seq_len), dtype=torch.bool, device=self.device), diagonal=1)
        
        query_key = torch.einsum("abcd,aecd->acbe", Q, K)
        masked_query_key = torch.where(triangular, -torch.inf, query_key)
        masked_query_key = self.softmax((masked_query_key)/sqrt_d_k)
        masked_query_key = self.dropout1(masked_query_key)
        result = torch.einsum("abcd, adbe-> acbe", masked_query_key, V)
        Z = torch.reshape(result, (batch, target_seq_len, self.num_heads * self.head_size))
        Z = self.dropout2(Z)
        #print(f"{Z.shape=}")
        Z = self.W_O(Z)
        return Z

class GPT2BlockSimon(nn.Module):
    
    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.ln1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        self.attn = GPT2Attention(config)
        self.ln2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        self.mlp = GPT2MLP(config)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x


class GPT2Model(nn.Module):

    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.config = config
        
        self.text_embedding = nn.Embedding(
            num_embeddings=self.config.vocab_size,
            embedding_dim=self.config.hidden_size)
        self.position_embedding = nn.Embedding(
            num_embeddings=self.config.max_seq_len,
            embedding_dim=self.config.hidden_size
        )
        list_decoder_blocks = [GPT2BlockSimon(config = self.config) 
                                    for _ in range(self.config.num_layers)]
        self.decoder_blocks = nn.Sequential(*list_decoder_blocks)
        self.final_layer_norm = nn.LayerNorm(normalized_shape=self.config.hidden_size,eps=self.config.layer_norm_epsilon)
        # self.unembed = nn.Linear(self.config.hidden_size, config.vocab_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if len(x.shape) == 1:
            x = torch.unsqueeze(x, dim=0)
        position = torch.arange(x.shape[1], device=self.config.device)
        x = self.text_embedding(x) + self.position_embedding(position)
        # print(f"x.shape={x.shape}")
        x = self.decoder_blocks(x)
        x = self.final_layer_norm(x)
        # x = self.unembed(x) # ,dim=2)
        x = x @ self.text_embedding.weight.T
        return x




In [4]:
config = TransformerConfig()
model = GPT2Model(config)

In [5]:
gpt2 = transformers.AutoModelForCausalLM.from_pretrained("gpt2")


### Adapting some tools from week 0

In [6]:

import pandas as pd
from itertools import zip_longest
def compare_models(my_model, realmodel):
    
    pretraineddict = dict(realmodel.named_parameters())
    my_state = dict(my_model.named_parameters())


    keys_to_iterate = []

    for pretrainedkey, pretrainedvalue in pretraineddict.items():
        remove_these= [] #["attn.masked_bias", ".attn.bias", "lm_head.weight"]
        for suffix in remove_these:
            if pretrainedkey.endswith(suffix):
                break
        else:
            keys_to_iterate.append(pretrainedkey)

    # match keys
    remaining_keys = list(my_state.keys())
    matched_keys = []
    for key in keys_to_iterate:
        for mykey in remaining_keys:
            if key.endswith("weight") and not (key.endswith("wte.weight") or key.endswith("wpe.weight")):
                # rotate shape
                shape = pretraineddict[key].T.shape
            else:
                shape = pretraineddict[key].shape

            myshape = my_state[mykey].shape
            if shape == myshape:
                matched_keys.append(mykey)
                remaining_keys.remove(mykey)
                break
        else: 
            print(f"No match found for key {key}")
            
    print(f"remaining keys = {remaining_keys}")

    print(f"len(pretraineddict)={len(keys_to_iterate)}\tlen(my_state)={len(my_state)}")
    utils.print_param_count(my_model, realmodel)
    
    df = pd.DataFrame.from_records(
        [(tk, tuple(pretraineddict[tk].shape), mk, tuple(my_state[mk].shape)) for (tk, mk) in zip(keys_to_iterate, matched_keys)],
        columns=["their name", "their shape", "your name", "your shape"],
    )
    # if len(pretraineddict)!= len(my_state):
        # for tk, tv in pretraineddict.items():
        #     print(f"{tk}\t{tuple(tv.shape)}")
    
        # for tk, tv in my_state.items():
        #     print(f"{tk}\t{tuple(tv.shape)}")

    with pd.option_context("display.max_rows", None):  # type: ignore
        display(df)

compare_models(model,gpt2)

remaining keys = []
len(pretraineddict)=148	len(my_state)=148
Model 1, total params = 124439808


  shape = pretraineddict[key].T.shape


Unnamed: 0,name_1,shape_1,num_params_1
0,text_embedding.weight,"(50257, 768)",38597376
1,position_embedding.weight,"(1024, 768)",786432
2,decoder_blocks.0.ln1.weight,"(768,)",768
3,decoder_blocks.0.ln1.bias,"(768,)",768
4,decoder_blocks.0.attn.W_QKV.weight,"(2304, 768)",1769472
...,...,...,...
143,decoder_blocks.11.mlp.mlp_block.0.bias,"(3072,)",3072
144,decoder_blocks.11.mlp.mlp_block.2.weight,"(768, 3072)",2359296
145,decoder_blocks.11.mlp.mlp_block.2.bias,"(768,)",768
146,final_layer_norm.weight,"(768,)",768


Model 2, total params = 124439808


Unnamed: 0,num_params_2,shape_2,name_2
0,38597376,"(50257, 768)",transformer.wte.weight
1,786432,"(1024, 768)",transformer.wpe.weight
2,768,"(768,)",transformer.h.0.ln_1.weight
3,768,"(768,)",transformer.h.0.ln_1.bias
4,1769472,"(768, 2304)",transformer.h.0.attn.c_attn.weight
...,...,...,...
143,3072,"(3072,)",transformer.h.11.mlp.c_fc.bias
144,2359296,"(3072, 768)",transformer.h.11.mlp.c_proj.weight
145,768,"(768,)",transformer.h.11.mlp.c_proj.bias
146,768,"(768,)",transformer.ln_f.weight


All parameter counts match!


Unnamed: 0,name_1,shape_1,num_params_1,num_params_2,shape_2,name_2
0,text_embedding.weight,"(50257, 768)",38597376,38597376,"(50257, 768)",transformer.wte.weight
1,position_embedding.weight,"(1024, 768)",786432,786432,"(1024, 768)",transformer.wpe.weight
2,decoder_blocks.0.ln1.weight,"(768,)",768,768,"(768,)",transformer.h.0.ln_1.weight
3,decoder_blocks.0.ln1.bias,"(768,)",768,768,"(768,)",transformer.h.0.ln_1.bias
4,decoder_blocks.0.attn.W_QKV.weight,"(2304, 768)",1769472,1769472,"(768, 2304)",transformer.h.0.attn.c_attn.weight
5,decoder_blocks.0.attn.W_QKV.bias,"(2304,)",2304,2304,"(2304,)",transformer.h.0.attn.c_attn.bias
6,decoder_blocks.0.attn.W_O.weight,"(768, 768)",589824,589824,"(768, 768)",transformer.h.0.attn.c_proj.weight
7,decoder_blocks.0.attn.W_O.bias,"(768,)",768,768,"(768,)",transformer.h.0.attn.c_proj.bias
8,decoder_blocks.0.ln2.weight,"(768,)",768,768,"(768,)",transformer.h.0.ln_2.weight
9,decoder_blocks.0.ln2.bias,"(768,)",768,768,"(768,)",transformer.h.0.ln_2.bias


Unnamed: 0,their name,their shape,your name,your shape
0,transformer.wte.weight,"(50257, 768)",text_embedding.weight,"(50257, 768)"
1,transformer.wpe.weight,"(1024, 768)",position_embedding.weight,"(1024, 768)"
2,transformer.h.0.ln_1.weight,"(768,)",decoder_blocks.0.ln1.weight,"(768,)"
3,transformer.h.0.ln_1.bias,"(768,)",decoder_blocks.0.ln1.bias,"(768,)"
4,transformer.h.0.attn.c_attn.weight,"(768, 2304)",decoder_blocks.0.attn.W_QKV.weight,"(2304, 768)"
5,transformer.h.0.attn.c_attn.bias,"(2304,)",decoder_blocks.0.attn.W_QKV.bias,"(2304,)"
6,transformer.h.0.attn.c_proj.weight,"(768, 768)",decoder_blocks.0.attn.W_O.weight,"(768, 768)"
7,transformer.h.0.attn.c_proj.bias,"(768,)",decoder_blocks.0.attn.W_O.bias,"(768,)"
8,transformer.h.0.ln_2.weight,"(768,)",decoder_blocks.0.ln2.weight,"(768,)"
9,transformer.h.0.ln_2.bias,"(768,)",decoder_blocks.0.ln2.bias,"(768,)"


In [7]:
def copy_weights_simon(my_model: GPT2Model, pretrained_model: nn.Module) -> GPT2Model:
    '''Copy over the weights from gpt to your implementation of gpt.

    gpt should be imported using: 
        gpt = transformers.AutoModelForCausalLM.from_pretrained("gpt2")

    Returns your gpt model, with weights loaded in.'''

    pretraineddict = dict(pretrained_model.named_parameters())
    my_state = dict(my_model.named_parameters())

    keys_to_iterate = []

    for pretrainedkey, pretrainedvalue in pretraineddict.items():
        remove_these= ["attn.masked_bias", ".attn.bias", "lm_head.weight"]
        for suffix in remove_these:
            if pretrainedkey.endswith(suffix):
                break
        else:
            keys_to_iterate.append(pretrainedkey)

    # match keys
    remaining_keys = list(my_state.keys())
    matched_keys = {}
    for key in keys_to_iterate:
        for mykey in remaining_keys:
            if key.endswith("weight") and not (key.endswith("wte.weight") or key.endswith("wpe.weight")):
                # transpose shape
                pretrained_values = pretraineddict[key].T
                print(f"Copied params.T: {key} -> {mykey}")
            else:
                pretrained_values = pretraineddict[key]
                print(f"Copied params: {key} -> {mykey}")

            myshape = my_state[mykey].shape
            if pretrained_values.shape == myshape:
                matched_keys[mykey] = pretrained_values
                remaining_keys.remove(mykey)
                break
        else: 
            print(f"No match found for key {key}")


    # Check the number of params/buffers is correct
    assert len(matched_keys) == len(keys_to_iterate), "Number of layers is wrong. Have you done the prev step correctly?"

    # Initialise an empty dictionary to store the correct key-value pairs
    # state_dict_to_load = {}

    # for mykey, pretrainedkey in zip(matched_keys, keys_to_iterate):
    #     pretrainedvalue = pretraineddict[pretrainedkey]
    #     state_dict_to_load[mykey] = pretrainedvalue

    my_model.load_state_dict(matched_keys)

    return my_model

my_gpt = copy_weights_simon(model, gpt2)

Copied params: transformer.wte.weight -> text_embedding.weight
Copied params: transformer.wpe.weight -> position_embedding.weight
Copied params.T: transformer.h.0.ln_1.weight -> decoder_blocks.0.ln1.weight
Copied params: transformer.h.0.ln_1.bias -> decoder_blocks.0.ln1.bias
Copied params.T: transformer.h.0.attn.c_attn.weight -> decoder_blocks.0.attn.W_QKV.weight
Copied params: transformer.h.0.attn.c_attn.bias -> decoder_blocks.0.attn.W_QKV.bias
Copied params.T: transformer.h.0.attn.c_proj.weight -> decoder_blocks.0.attn.W_O.weight
Copied params: transformer.h.0.attn.c_proj.bias -> decoder_blocks.0.attn.W_O.bias
Copied params.T: transformer.h.0.ln_2.weight -> decoder_blocks.0.ln2.weight
Copied params: transformer.h.0.ln_2.bias -> decoder_blocks.0.ln2.bias
Copied params.T: transformer.h.0.mlp.c_fc.weight -> decoder_blocks.0.mlp.mlp_block.0.weight
Copied params: transformer.h.0.mlp.c_fc.bias -> decoder_blocks.0.mlp.mlp_block.0.bias
Copied params.T: transformer.h.0.mlp.c_proj.weight -> de

In [8]:
import torch as t
def test_load_pretrained_weights(model, tokenizer):
    model.eval()
    device = next(model.parameters()).device
    
    def encode(text: str) -> t.Tensor:
        """Return a Tensor of shape (batch=1, seq)."""
        return tokenizer(text, return_tensors="pt")["input_ids"].to(device)

    prompt = "Former President of the United States of America, George"
    input_ids = encode(prompt)
    with t.inference_mode():
        output = model(input_ids)
        logits = output[0, -1] if isinstance(output, t.Tensor) else output.logits[0, -1]
    topk = t.topk(logits, k=10).indices
    next_tokens = tokenizer.batch_decode(topk.reshape(-1, 1))
    print("Prompt: ", prompt)
    print("Your model's top 10 predictions: ", next_tokens)
    assert " Washington" in next_tokens
    assert " Bush" in next_tokens

tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
test_load_pretrained_weights(my_gpt, tokenizer)
test_load_pretrained_weights(gpt2, tokenizer)

Prompt:  Former President of the United States of America, George
Your model's top 10 predictions:  [' W', ' H', ' Bush', ' Washington', ' HW', ' Herbert', ' Pat', ' Soros', ' S', ' Wallace']
Prompt:  Former President of the United States of America, George
Your model's top 10 predictions:  [' W', ' H', ' Bush', ' Washington', ' HW', ' Herbert', ' Pat', ' S', ' Soros', ' Wallace']
