[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/DalasNoin/arena/blob/main/w2/gpt2.ipynb)

In [15]:
# ! pip install transformers
# ! wget https://raw.githubusercontent.com/callummcdougall/arena-v1/main/w2d2/utils.py
# ! wget https://www.gutenberg.org/files/100/100-0.txt
import torch
from torch import nn
from torch.nn import GELU, Softmax
from dataclasses import dataclass
import transformers
import utils
import matplotlib


device = "cuda" if torch.cuda.is_available() else "cpu"


In [16]:
from gpt import GPT2Attention, GPT2BlockSimon, GPT2MLP, GPT2Model, TransformerConfig

## define gpt2 model

In [17]:
config = TransformerConfig()
model = GPT2Model(config)

In [18]:
gpt2 = transformers.AutoModelForCausalLM.from_pretrained("gpt2")


### Adapting some tools from week 0

In [19]:

import pandas as pd
from itertools import zip_longest
def compare_models(my_model, realmodel):
    
    pretraineddict = dict(realmodel.named_parameters())
    my_state = dict(my_model.named_parameters())


    keys_to_iterate = []

    for pretrainedkey, pretrainedvalue in pretraineddict.items():
        remove_these= [] #["attn.masked_bias", ".attn.bias", "lm_head.weight"]
        for suffix in remove_these:
            if pretrainedkey.endswith(suffix):
                break
        else:
            keys_to_iterate.append(pretrainedkey)

    # match keys
    remaining_keys = list(my_state.keys())
    matched_keys = []
    for key in keys_to_iterate:
        for mykey in remaining_keys:
            if key.endswith("weight") and not (key.endswith("wte.weight") or key.endswith("wpe.weight")):
                # rotate shape
                shape = pretraineddict[key].T.shape
            else:
                shape = pretraineddict[key].shape

            myshape = my_state[mykey].shape
            if shape == myshape:
                matched_keys.append(mykey)
                remaining_keys.remove(mykey)
                break
        else: 
            print(f"No match found for key {key}")
            
    print(f"remaining keys = {remaining_keys}")

    print(f"len(pretraineddict)={len(keys_to_iterate)}\tlen(my_state)={len(my_state)}")
    utils.print_param_count(my_model, realmodel)
    
    df = pd.DataFrame.from_records(
        [(tk, tuple(pretraineddict[tk].shape), mk, tuple(my_state[mk].shape)) for (tk, mk) in zip(keys_to_iterate, matched_keys)],
        columns=["their name", "their shape", "your name", "your shape"],
    )
    # if len(pretraineddict)!= len(my_state):
        # for tk, tv in pretraineddict.items():
        #     print(f"{tk}\t{tuple(tv.shape)}")
    
        # for tk, tv in my_state.items():
        #     print(f"{tk}\t{tuple(tv.shape)}")

    with pd.option_context("display.max_rows", None):  # type: ignore
        display(df)

#  compare_models(model,gpt2)

In [20]:
def copy_weights_simon(my_model: GPT2Model, pretrained_model: nn.Module) -> GPT2Model:
    '''Copy over the weights from gpt to your implementation of gpt.

    gpt should be imported using: 
        gpt = transformers.AutoModelForCausalLM.from_pretrained("gpt2")

    Returns your gpt model, with weights loaded in.'''

    pretraineddict = dict(pretrained_model.named_parameters())
    my_state = dict(my_model.named_parameters())

    keys_to_iterate = []

    for pretrainedkey, pretrainedvalue in pretraineddict.items():
        remove_these= ["attn.masked_bias", ".attn.bias", "lm_head.weight"]
        for suffix in remove_these:
            if pretrainedkey.endswith(suffix):
                break
        else:
            keys_to_iterate.append(pretrainedkey)

    # match keys
    remaining_keys = list(my_state.keys())
    matched_keys = {}
    for key in keys_to_iterate:
        for mykey in remaining_keys:
            if key.endswith("weight") and not (key.endswith("wte.weight") or key.endswith("wpe.weight")):
                # transpose shape
                pretrained_values = pretraineddict[key].T
                print(f"Copied params.T: {key} -> {mykey}")
            else:
                pretrained_values = pretraineddict[key]
                print(f"Copied params: {key} -> {mykey}")

            myshape = my_state[mykey].shape
            if pretrained_values.shape == myshape:
                matched_keys[mykey] = pretrained_values
                remaining_keys.remove(mykey)
                break
        else: 
            print(f"No match found for key {key}")


    # Check the number of params/buffers is correct
    assert len(matched_keys) == len(keys_to_iterate), "Number of layers is wrong. Have you done the prev step correctly?"

    # Initialise an empty dictionary to store the correct key-value pairs
    # state_dict_to_load = {}

    # for mykey, pretrainedkey in zip(matched_keys, keys_to_iterate):
    #     pretrainedvalue = pretraineddict[pretrainedkey]
    #     state_dict_to_load[mykey] = pretrainedvalue

    my_model.load_state_dict(matched_keys)

    return my_model

my_gpt = copy_weights_simon(model, gpt2)

Copied params: transformer.wte.weight -> text_embedding.weight
Copied params: transformer.wpe.weight -> position_embedding.weight
Copied params.T: transformer.h.0.ln_1.weight -> decoder_blocks.0.ln1.weight
Copied params: transformer.h.0.ln_1.bias -> decoder_blocks.0.ln1.bias
Copied params.T: transformer.h.0.attn.c_attn.weight -> decoder_blocks.0.attn.W_QKV.weight
Copied params: transformer.h.0.attn.c_attn.bias -> decoder_blocks.0.attn.W_QKV.bias
Copied params.T: transformer.h.0.attn.c_proj.weight -> decoder_blocks.0.attn.W_O.weight
Copied params: transformer.h.0.attn.c_proj.bias -> decoder_blocks.0.attn.W_O.bias
Copied params.T: transformer.h.0.ln_2.weight -> decoder_blocks.0.ln2.weight
Copied params: transformer.h.0.ln_2.bias -> decoder_blocks.0.ln2.bias
Copied params.T: transformer.h.0.mlp.c_fc.weight -> decoder_blocks.0.mlp.mlp_block.0.weight
Copied params: transformer.h.0.mlp.c_fc.bias -> decoder_blocks.0.mlp.mlp_block.0.bias
Copied params.T: transformer.h.0.mlp.c_proj.weight -> de

In [21]:
import torch as t
def test_load_pretrained_weights(model, tokenizer):
    model.eval()
    device = next(model.parameters()).device
    
    def encode(text: str) -> t.Tensor:
        """Return a Tensor of shape (batch=1, seq)."""
        return tokenizer(text, return_tensors="pt")["input_ids"].to(device)

    prompt = "Former President of the United States of America, George"
    input_ids = encode(prompt)
    with t.inference_mode():
        output = model(input_ids)
        logits = output[0, -1] if isinstance(output, t.Tensor) else output.logits[0, -1]
    topk = t.topk(logits, k=10).indices
    next_tokens = tokenizer.batch_decode(topk.reshape(-1, 1))
    print("Prompt: ", prompt)
    print("Your model's top 10 predictions: ", next_tokens)
    assert " Washington" in next_tokens
    assert " Bush" in next_tokens

tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
test_load_pretrained_weights(my_gpt, tokenizer)
test_load_pretrained_weights(gpt2, tokenizer)

Prompt:  Former President of the United States of America, George
Your model's top 10 predictions:  [' W', ' H', ' Bush', ' Washington', ' HW', ' Herbert', ' Pat', ' Soros', ' S', ' Wallace']
Prompt:  Former President of the United States of America, George
Your model's top 10 predictions:  [' W', ' H', ' Bush', ' Washington', ' HW', ' Herbert', ' Pat', ' S', ' Soros', ' Wallace']


In [22]:
import sampling
import shakespeare

In [23]:
sampling.sample_tokens(
    model=my_gpt,
    tokenizer=tokenizer,
    initial_text="a group of unicorns",
    max_tokens_generated=70)

"a group of unicorns passing their tabbot into a plantation held in the rear yard of a second estate of Firestone S. Butler and his brothers. Butler realized, however, that taking a long standing shamrooming tradition to its logical end could only get them anything more than a scrap of land they needed to replace bolt-in slaves and others she didn't want"

In [24]:
dataset = shakespeare.ShakespeareDataset(config, use_word_tokenizer=False)

In [42]:
from torch.optim import AdamW
from torch.nn import CrossEntropyLoss
from transformers import get_linear_schedule_with_warmup
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
import wandb
import os

os.environ["WANDB_NOTEBOOK_NAME"] = "arena week 2"
wandb_key = ""
keyfile = "keystore.yaml"
if not wandb_key and os.path.exists(keyfile):
    import yaml
    keys = yaml.safe_load(open(keyfile,"r"))
    wandb_key = keys["wandb"]
os.environ["WANDB_API_KEY"] = wandb_key

def collate(batch: list):
    # print(batch)
    device = batch[0][0].device
    max_len = max([len(text) for (text, label) in batch])
    batch_size = len(batch)
    new_text = torch.zeros((batch_size, max_len)).long().to(device)
    new_label = torch.zeros((batch_size, max_len)).long().to(device)
    for i, (text, label) in enumerate(batch):
        new_text[i,:len(text)]=text
        new_label[i,:len(label)] = label
    return new_text, new_label


wandb.init(config=config.__dict__)
def train(config: TransformerConfig, dataset: Dataset, model: GPT2Model):
    model.train()
    model.to(config.device)
    wandb.watch(model,log_freq=100)
    dataloader = DataLoader(dataset, batch_size=16, shuffle=True, collate_fn=collate)
    model.train()
    optimizer = AdamW(params=model.parameters(), lr=2e-5)
    scheduler = get_linear_schedule_with_warmup(optimizer=optimizer,num_warmup_steps=200, num_training_steps=-1
    )
    # criterion = CrossEntropyLoss()
    criterion = CrossEntropyLoss()
    for epoch_idx in range(1):
        for i, (text, target_label) in enumerate(dataloader):
            # print(i, batch["text"].shape, batch["label"].shape)
            label = model.forward(text)
            # target_label=batch["label"]
            # print(f"{label.shape=} {target_label.shape=}")
            loss = criterion(label.transpose(1,2), target_label)
           
            loss.backward()
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            if i % 100 == 0:
               wandb.log({"loss": loss})
        # print(f"torch.mean(label)={torch.mean(label)} \t torch.mean(target_label)={torch.mean(target_label)}")
        loss_numpy = loss.detach().cpu().numpy()
        print(loss_numpy)
        if loss_numpy < 1.0:
            break
    return model



model = train(config, dataset, my_gpt)



VBox(children=(Label(value='0.001 MB of 0.007 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.090230…

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016727997916670272, max=1.0…

KeyboardInterrupt: 

In [40]:
train(model=my_gpt, dataset=dataset, tokenizer=tokenizer)

Training epoch 0
0


1it [00:00, 163.08it/s]


AttributeError: 'list' object has no attribute 'size'