In [2]:
import torch
from tinygrad.tensor import Tensor
import numpy as np
from core.model import GPTConfig
from core.model_tiny import GPT
from tinygrad.state import get_parameters, get_state_dict

In [14]:
config = GPTConfig.from_pretrained('EleutherAI/pythia-70m')
gpt_tiny = GPT(config)

In [15]:
get_state_dict(gpt_tiny).keys()

dict_keys(['wte.weight', 'h.0.ln_1.weight', 'h.0.ln_1.bias', 'h.0.attn.c_attn.weight', 'h.0.attn.c_attn.bias', 'h.0.attn.c_proj.weight', 'h.0.attn.c_proj.bias', 'h.0.ln_2.weight', 'h.0.ln_2.bias', 'h.0.mlp.c_fc.weight', 'h.0.mlp.c_fc.bias', 'h.0.mlp.c_proj.weight', 'h.0.mlp.c_proj.bias', 'h.1.ln_1.weight', 'h.1.ln_1.bias', 'h.1.attn.c_attn.weight', 'h.1.attn.c_attn.bias', 'h.1.attn.c_proj.weight', 'h.1.attn.c_proj.bias', 'h.1.ln_2.weight', 'h.1.ln_2.bias', 'h.1.mlp.c_fc.weight', 'h.1.mlp.c_fc.bias', 'h.1.mlp.c_proj.weight', 'h.1.mlp.c_proj.bias', 'h.2.ln_1.weight', 'h.2.ln_1.bias', 'h.2.attn.c_attn.weight', 'h.2.attn.c_attn.bias', 'h.2.attn.c_proj.weight', 'h.2.attn.c_proj.bias', 'h.2.ln_2.weight', 'h.2.ln_2.bias', 'h.2.mlp.c_fc.weight', 'h.2.mlp.c_fc.bias', 'h.2.mlp.c_proj.weight', 'h.2.mlp.c_proj.bias', 'h.3.ln_1.weight', 'h.3.ln_1.bias', 'h.3.attn.c_attn.weight', 'h.3.attn.c_attn.bias', 'h.3.attn.c_proj.weight', 'h.3.attn.c_proj.bias', 'h.3.ln_2.weight', 'h.3.ln_2.bias', 'h.3.mlp.c_

In [16]:
from transformers import GPTNeoXForCausalLM
sd_hf = GPTNeoXForCausalLM.from_pretrained('EleutherAI/pythia-70m').state_dict()

sd_hf.keys()

odict_keys(['gpt_neox.embed_in.weight', 'gpt_neox.layers.0.input_layernorm.weight', 'gpt_neox.layers.0.input_layernorm.bias', 'gpt_neox.layers.0.post_attention_layernorm.weight', 'gpt_neox.layers.0.post_attention_layernorm.bias', 'gpt_neox.layers.0.attention.bias', 'gpt_neox.layers.0.attention.masked_bias', 'gpt_neox.layers.0.attention.rotary_emb.inv_freq', 'gpt_neox.layers.0.attention.query_key_value.weight', 'gpt_neox.layers.0.attention.query_key_value.bias', 'gpt_neox.layers.0.attention.dense.weight', 'gpt_neox.layers.0.attention.dense.bias', 'gpt_neox.layers.0.mlp.dense_h_to_4h.weight', 'gpt_neox.layers.0.mlp.dense_h_to_4h.bias', 'gpt_neox.layers.0.mlp.dense_4h_to_h.weight', 'gpt_neox.layers.0.mlp.dense_4h_to_h.bias', 'gpt_neox.layers.1.input_layernorm.weight', 'gpt_neox.layers.1.input_layernorm.bias', 'gpt_neox.layers.1.post_attention_layernorm.weight', 'gpt_neox.layers.1.post_attention_layernorm.bias', 'gpt_neox.layers.1.attention.bias', 'gpt_neox.layers.1.attention.masked_bias',

In [25]:

sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attention.masked_bias')] # ignore these, just a buffer
sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attention.bias')] # same, just the mask (buffer)
sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.rotary_emb.inv_freq')] # same, just the mask (buffer)

def rename_key(k):
    # Remove .transformer    
    k = k.replace('gpt_neox.layers', 'h')
    # Attention
    k = k.replace('.attention.dense.', '.attn.c_proj.')
    k = k.replace('.attention.query_key_value.', '.attn.c_attn.')
    # MLP
    k = k.replace('.mlp.dense_h_to_4h.', '.mlp.c_fc.')
    k = k.replace('.mlp.dense_4h_to_h.', '.mlp.c_proj.')
    # LayerNorm
    k = k.replace('.input_layernorm.', '.ln_1.')
    k = k.replace('.post_attention_layernorm.', '.ln_2.')
    # Embedding
    k = k.replace('gpt_neox.embed_in.', 'wte.')
    k = k.replace('gpt_neox.final_layer_norm.', 'ln_f.')
    k = k.replace('embed_out.', 'lm_head.')
    return k

sd_hf_keys_renamed = {rename_key(k):Tensor(v.numpy()) for k,v in sd_hf.items()}
sd_hf_keys_renamed.keys()

dict_keys(['wte.weight', 'h.0.ln_1.weight', 'h.0.ln_1.bias', 'h.0.ln_2.weight', 'h.0.ln_2.bias', 'h.0.attention.bias', 'h.0.attention.masked_bias', 'h.0.attention.rotary_emb.inv_freq', 'h.0.attn.c_attn.weight', 'h.0.attn.c_attn.bias', 'h.0.attn.c_proj.weight', 'h.0.attn.c_proj.bias', 'h.0.mlp.c_fc.weight', 'h.0.mlp.c_fc.bias', 'h.0.mlp.c_proj.weight', 'h.0.mlp.c_proj.bias', 'h.1.ln_1.weight', 'h.1.ln_1.bias', 'h.1.ln_2.weight', 'h.1.ln_2.bias', 'h.1.attention.bias', 'h.1.attention.masked_bias', 'h.1.attention.rotary_emb.inv_freq', 'h.1.attn.c_attn.weight', 'h.1.attn.c_attn.bias', 'h.1.attn.c_proj.weight', 'h.1.attn.c_proj.bias', 'h.1.mlp.c_fc.weight', 'h.1.mlp.c_fc.bias', 'h.1.mlp.c_proj.weight', 'h.1.mlp.c_proj.bias', 'h.2.ln_1.weight', 'h.2.ln_1.bias', 'h.2.ln_2.weight', 'h.2.ln_2.bias', 'h.2.attention.bias', 'h.2.attention.masked_bias', 'h.2.attention.rotary_emb.inv_freq', 'h.2.attn.c_attn.weight', 'h.2.attn.c_attn.bias', 'h.2.attn.c_proj.weight', 'h.2.attn.c_proj.bias', 'h.2.mlp.c_

In [27]:
from tinygrad.state import get_parameters, get_state_dict, load_state_dict, torch_load



load_state_dict(gpt_tiny, sd_hf_keys_renamed, strict=False)

ram used:  0.31 GB, cos                                               : 100%|██████████| 78/78 [00:00<00:00, 2739.84it/s]

loaded weights in 31.08 ms, 0.31 GB loaded at 9.87 GB/s



