In [3]:
# Reference model: Pythia-70M loaded through the *local* GPT-NeoX modelling file
# (modeling_gpt_neox), not through the installed transformers package.
# NOTE(review): the "16 2048 10000" / tensor lines printed below are not emitted by
# this cell's code; presumably they are debug prints inside the local module (they
# appear once per layer) — consider removing them there.
from modeling_gpt_neox import GPTNeoXForCausalLM

HF_CHECKPOINT = "EleutherAI/pythia-70m"
model_hf = GPTNeoXForCausalLM.from_pretrained(HF_CHECKPOINT)

16 2048 10000
tensor([1.0000e+00, 3.1623e-01, 1.0000e-01, 3.1623e-02, 1.0000e-02, 3.1623e-03,
        1.0000e-03, 3.1623e-04])
16 2048 10000
tensor([1.0000e+00, 3.1623e-01, 1.0000e-01, 3.1623e-02, 1.0000e-02, 3.1623e-03,
        1.0000e-03, 3.1623e-04])
16 2048 10000
tensor([1.0000e+00, 3.1623e-01, 1.0000e-01, 3.1623e-02, 1.0000e-02, 3.1623e-03,
        1.0000e-03, 3.1623e-04])
16 2048 10000
tensor([1.0000e+00, 3.1623e-01, 1.0000e-01, 3.1623e-02, 1.0000e-02, 3.1623e-03,
        1.0000e-03, 3.1623e-04])
16 2048 10000
tensor([1.0000e+00, 3.1623e-01, 1.0000e-01, 3.1623e-02, 1.0000e-02, 3.1623e-03,
        1.0000e-03, 3.1623e-04])
16 2048 10000
tensor([1.0000e+00, 3.1623e-01, 1.0000e-01, 3.1623e-02, 1.0000e-02, 3.1623e-03,
        1.0000e-03, 3.1623e-04])


In [4]:
# Keep a handle on the HF model's configuration (displayed below); it is the input
# to convert_config in the following cells.
gpt_neox_config = model_hf.config
gpt_neox_config

GPTNeoXConfig {
  "_name_or_path": "EleutherAI/pythia-70m",
  "architectures": [
    "GPTNeoXForCausalLM"
  ],
  "bos_token_id": 0,
  "eos_token_id": 0,
  "hidden_act": "gelu",
  "hidden_size": 512,
  "initializer_range": 0.02,
  "intermediate_size": 2048,
  "layer_norm_eps": 1e-05,
  "max_position_embeddings": 2048,
  "model_type": "gpt_neox",
  "num_attention_heads": 8,
  "num_hidden_layers": 6,
  "rotary_emb_base": 10000,
  "rotary_pct": 0.25,
  "tie_word_embeddings": false,
  "torch_dtype": "float16",
  "transformers_version": "4.25.1",
  "use_cache": true,
  "use_parallel_residual": true,
  "vocab_size": 50304
}

In [5]:
from model import GPTConfig, GPT

def convert_config(gpt_neox_config):
    """Translate a Hugging Face ``GPTNeoXConfig`` into a nanoGPT ``GPTConfig``.

    Carries over width, head count, depth, vocabulary size, context length, MLP
    width and rotary fraction. ``bias`` is always set to True and ``dropout`` to 0.0.

    NOTE(review): ``use_parallel_residual``, ``rotary_emb_base``, ``hidden_act``,
    ``layer_norm_eps`` and ``tie_word_embeddings`` are *not* forwarded, so ``GPT``
    must already behave like the source model for those settings — TODO confirm
    in model.py. An ``assert gpt_neox_config.use_parallel_residual`` guard was
    previously present but commented out.

    Args:
        gpt_neox_config: the HF config object, e.g. ``model_hf.config``.

    Returns:
        The corresponding ``GPTConfig``.
    """
    src = gpt_neox_config
    gpt_kwargs = {
        'n_embd': src.hidden_size,
        'n_head': src.num_attention_heads,
        'n_layer': src.num_hidden_layers,
        'vocab_size': src.vocab_size,
        'block_size': src.max_position_embeddings,  # maximum sequence length
        'bias': True,
        'dropout': 0.0,
        'n_embd_proj': src.intermediate_size,       # MLP hidden width
        'rotary_pct': src.rotary_pct,               # presumably the fraction of each head dim that is rotated
    }
    return GPTConfig(**gpt_kwargs)


gpt_config = convert_config(gpt_neox_config)

In [6]:
# Build the nanoGPT model from the converted hyper-parameters (presumably randomly
# initialised by GPT.__init__) and grab its state dict. The state dict's tensors share
# storage with the model's parameters, so copying HF weights into `sd` later updates
# `model` in place.
# NOTE(review): recomputes the same config already stored in `gpt_config` above.
config = convert_config(gpt_neox_config)
model = GPT(config)
sd = model.state_dict()


number of parameters: 70.43M


In [7]:
# nanoGPT-side tensors that are buffers rather than learned weights: the attention
# mask and the rotary-embedding tables. They have no HF counterpart to copy from,
# so they are excluded from the list of names to fill.
_NANOGPT_BUFFER_SUFFIXES = (
    '.attn.bias',                  # attention mask buffer
    '.attn.rotary_emb.freqs',      # rotary-embedding buffer
    '.attn.rotary_emb.scale',      # rotary-embedding buffer
    '.attn.rotary_emb.inv_freq',   # rotary-embedding buffer
)
sd_keys = [name for name in sd.keys() if not name.endswith(_NANOGPT_BUFFER_SUFFIXES)]

In [8]:

# Snapshot the already-loaded HF model's weights as a name -> tensor mapping.
# Nothing is initialised here; model_hf was loaded from the checkpoint earlier.
sd_hf = model_hf.state_dict()


In [9]:

# HF-side tensors that are buffers rather than learned weights. They have no
# nanoGPT counterpart, so drop them before the name-by-name copy, which checks
# that the remaining names and shapes line up.
_HF_BUFFER_SUFFIXES = (
    '.attention.masked_bias',   # ignore these, just a buffer
    '.attention.bias',          # attention mask buffer
    '.rotary_emb.inv_freq',     # rotary inverse-frequency buffer (not a mask)
)
sd_keys_hf = [name for name in sd_hf.keys() if not name.endswith(_HF_BUFFER_SUFFIXES)]


In [10]:
# Ordered (HF fragment, nanoGPT fragment) substitutions; applied first to last.
_HF_TO_NANOGPT_RENAMES = (
    # transformer blocks
    ('gpt_neox.layers', 'transformer.h'),
    # attention
    ('.attention.dense.', '.attn.c_proj.'),
    ('.attention.query_key_value.', '.attn.c_attn.'),
    # MLP
    ('.mlp.dense_h_to_4h.', '.mlp.c_fc.'),
    ('.mlp.dense_4h_to_h.', '.mlp.c_proj.'),
    # layer norms
    ('.input_layernorm.', '.ln_1.'),
    ('.post_attention_layernorm.', '.ln_2.'),
    # embeddings, final norm and output head
    ('gpt_neox.embed_in.', 'transformer.wte.'),
    ('gpt_neox.final_layer_norm.', 'transformer.ln_f.'),
    ('embed_out.', 'lm_head.'),
)


def rename_key(k):
    """Map a Hugging Face GPT-NeoX state-dict key to the nanoGPT naming scheme.

    Each substring substitution in ``_HF_TO_NANOGPT_RENAMES`` is applied in order.
    Keys matching none of them are returned unchanged.
    """
    for hf_fragment, nanogpt_fragment in _HF_TO_NANOGPT_RENAMES:
        k = k.replace(hf_fragment, nanogpt_fragment)
    return k


In [11]:
import torch

# Pair every HF parameter name with its nanoGPT name.
hf_to_nanogpt = {k: rename_key(k) for k in sd_keys_hf}

# The two parameter sets must correspond one-to-one. Comparing the renamed key
# *set*, not just the counts, also catches a rename_key that maps two HF names onto
# one nanoGPT name. A length check alone would let that through while leaving some
# nanoGPT weights at their initial values.
assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}"
renamed_keys = set(hf_to_nanogpt.values())
assert len(renamed_keys) == len(hf_to_nanogpt), "rename_key maps several HF keys onto the same nanoGPT key"
unknown_keys = sorted(renamed_keys - set(sd_keys))
uncovered_keys = sorted(set(sd_keys) - renamed_keys)
assert not unknown_keys and not uncovered_keys, (
    f"renamed HF keys missing from nanoGPT: {unknown_keys}; "
    f"nanoGPT keys with no HF source: {uncovered_keys}"
)

# Copy tensor by tensor. `sd` holds tensors sharing storage with model's parameters,
# so the in-place copy_ updates `model` itself. Gradients are not needed for this.
with torch.no_grad():
    for k, new_k in hf_to_nanogpt.items():
        assert sd_hf[k].shape == sd[new_k].shape, (
            f"shape mismatch {k} -> {new_k}: {tuple(sd_hf[k].shape)} vs {tuple(sd[new_k].shape)}"
        )
        sd[new_k].copy_(sd_hf[k])


In [12]:
# Random token batch used to compare the two implementations end to end.
# Seeding makes the inputs, and therefore the logits displayed below, reproducible
# on Restart & Run All.
SEED = 0
TOKEN_ID_HIGH = 1000          # ids drawn from [0, 1000), a subset of the vocabulary
BATCH_SIZE, SEQ_LEN = 1, 512  # keep SEQ_LEN <= block_size (max_position_embeddings = 2048)

torch.manual_seed(SEED)
input_ids = torch.randint(0, TOKEN_ID_HIGH, (BATCH_SIZE, SEQ_LEN))
target_ids = torch.randint(0, TOKEN_ID_HIGH, (BATCH_SIZE, SEQ_LEN))

# NOTE(review): both models run in train mode. The converted config sets dropout=0.0;
# confirm the HF model also has no active dropout, otherwise use .eval() for the
# equivalence checks.
model_hf.train()
model.train()

# NOTE(review): HF causal-LM heads typically shift `labels` internally. Confirm
# nanoGPT's `targets` convention before comparing x.loss with the nanoGPT loss.
x = model_hf(input_ids, labels=target_ids)
y = model(input_ids, target_ids)

In [13]:
# Logits from the HF reference model, displayed for visual comparison with y[0].
x.logits

tensor([[[84.8101, 52.7323, 84.5987,  ..., 52.7321, 52.7307, 52.7329],
         [92.6708, 56.9197, 94.8555,  ..., 56.9199, 56.9172, 56.9210],
         [87.4568, 54.4841, 88.2387,  ..., 54.4840, 54.4823, 54.4859],
         ...,
         [95.2659, 60.2413, 95.4781,  ..., 60.2418, 60.2400, 60.2432],
         [89.7589, 57.7161, 89.9661,  ..., 57.7169, 57.7153, 57.7189],
         [92.8971, 58.3084, 93.5553,  ..., 58.3086, 58.3065, 58.3105]]],
       grad_fn=<UnsafeViewBackward0>)

In [14]:
# First element of the nanoGPT output; presumably the logits of a (logits, loss)
# tuple — TODO confirm against model.py.
y[0]

tensor([[[84.8101, 52.7323, 84.5987,  ..., 52.7321, 52.7307, 52.7329],
         [92.6708, 56.9197, 94.8555,  ..., 56.9199, 56.9172, 56.9210],
         [87.4568, 54.4841, 88.2387,  ..., 54.4840, 54.4823, 54.4859],
         ...,
         [95.2659, 60.2413, 95.4781,  ..., 60.2418, 60.2400, 60.2432],
         [89.7589, 57.7161, 89.9661,  ..., 57.7169, 57.7153, 57.7189],
         [92.8971, 58.3084, 93.5553,  ..., 58.3086, 58.3065, 58.3105]]],
       grad_fn=<UnsafeViewBackward0>)

In [15]:
# End-to-end check: the converted nanoGPT model must reproduce the HF logits.
# Assert rather than only display, so a regression fails the notebook run.
# y[0] is presumably nanoGPT's logits — TODO confirm against model.py.
logits_match = torch.allclose(x.logits, y[0])
assert logits_match, f"logits differ; max abs diff = {(x.logits - y[0]).abs().max().item():.3e}"
logits_match

True

In [16]:
# After the weight copy, both token-embedding tables must produce identical embeddings.
tok_emb = model.transformer.wte(input_ids) # token embeddings of shape (b, t, n_embd)
inputs_embeds = model_hf.gpt_neox.embed_in(input_ids)

emb_match = torch.allclose(tok_emb, inputs_embeds)
assert emb_match, f"token embeddings differ; max abs diff = {(tok_emb - inputs_embeds).abs().max().item():.3e}"
emb_match

True

In [18]:

# Per-layer check of the attention sub-layer. Each layer's input goes through
# nanoGPT's ln_1 + attn and HF's input_layernorm + attention, and the outputs are
# compared. `hidden` is advanced with the attention output only (no residual add,
# no MLP). Every layer is therefore exercised on a non-trivial input, but this is
# not a replay of the model's true residual stream.
assert len(model.transformer.h) == len(model_hf.gpt_neox.layers), "layer count mismatch"

attn_matches = []
hidden = tok_emb
with torch.no_grad():  # comparison only; no autograd graph needed
    for i, (block, hf_layer) in enumerate(zip(model.transformer.h, model_hf.gpt_neox.layers)):
        print(i)
        v1 = block.ln_1(hidden)
        v2 = hf_layer.input_layernorm(hidden)

        h1 = block.attn(v1)
        h2 = hf_layer.attention(v2)[0]

        attn_matches.append(torch.allclose(h1, h2))
        print(attn_matches[-1])
        hidden = h1

assert all(attn_matches), (
    f"attention outputs differ in layers {[n for n, ok in enumerate(attn_matches) if not ok]}"
)

0
True
1
True
2
True
3
True
4
True
5
True
