In [1]:
from modeling_gpt_neox import GPTNeoXForCausalLM
model_hf = GPTNeoXForCausalLM.from_pretrained("EleutherAI/pythia-70m")

16 2048 10000
tensor([1.0000e+00, 3.1623e-01, 1.0000e-01, 3.1623e-02, 1.0000e-02, 3.1623e-03,
        1.0000e-03, 3.1623e-04])
16 2048 10000
tensor([1.0000e+00, 3.1623e-01, 1.0000e-01, 3.1623e-02, 1.0000e-02, 3.1623e-03,
        1.0000e-03, 3.1623e-04])
16 2048 10000
tensor([1.0000e+00, 3.1623e-01, 1.0000e-01, 3.1623e-02, 1.0000e-02, 3.1623e-03,
        1.0000e-03, 3.1623e-04])
16 2048 10000
tensor([1.0000e+00, 3.1623e-01, 1.0000e-01, 3.1623e-02, 1.0000e-02, 3.1623e-03,
        1.0000e-03, 3.1623e-04])
16 2048 10000
tensor([1.0000e+00, 3.1623e-01, 1.0000e-01, 3.1623e-02, 1.0000e-02, 3.1623e-03,
        1.0000e-03, 3.1623e-04])
16 2048 10000
tensor([1.0000e+00, 3.1623e-01, 1.0000e-01, 3.1623e-02, 1.0000e-02, 3.1623e-03,
        1.0000e-03, 3.1623e-04])


In [2]:
gpt_neox_config = model_hf.config
gpt_neox_config

GPTNeoXConfig {
  "_name_or_path": "EleutherAI/pythia-70m",
  "architectures": [
    "GPTNeoXForCausalLM"
  ],
  "bos_token_id": 0,
  "eos_token_id": 0,
  "hidden_act": "gelu",
  "hidden_size": 512,
  "initializer_range": 0.02,
  "intermediate_size": 2048,
  "layer_norm_eps": 1e-05,
  "max_position_embeddings": 2048,
  "model_type": "gpt_neox",
  "num_attention_heads": 8,
  "num_hidden_layers": 6,
  "rotary_emb_base": 10000,
  "rotary_pct": 0.25,
  "tie_word_embeddings": false,
  "torch_dtype": "float16",
  "transformers_version": "4.25.1",
  "use_cache": true,
  "use_parallel_residual": true,
  "vocab_size": 50304
}

In [3]:
from model import GPTConfig, GPT

def convert_config(gpt_neox_config):
    # assert gpt_neox_config.use_parallel_residual
    
    return GPTConfig(
        n_embd = gpt_neox_config.hidden_size,
        n_head = gpt_neox_config.num_attention_heads,
        n_layer = gpt_neox_config.num_hidden_layers,
        vocab_size = gpt_neox_config.vocab_size,
        block_size = gpt_neox_config.max_position_embeddings,
        bias = True,
        dropout= 0.0,
        
        n_embd_proj = gpt_neox_config.intermediate_size,
        rotary_pct = gpt_neox_config.rotary_pct,
        
    )
    
gpt_config = convert_config(gpt_neox_config)

In [4]:
config = convert_config(gpt_neox_config)
model = GPT(config)
sd = model.state_dict()


number of parameters: 70.43M


In [5]:
sd_keys = sd.keys()
sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')] # discard this mask / buffer, not a param

sd_keys = [k for k in sd_keys if not k.endswith('.attn.rotary_emb.freqs')] # discard this mask / buffer, not a param
sd_keys = [k for k in sd_keys if not k.endswith('.attn.rotary_emb.scale')] # discard this mask / buffer, not a param
sd_keys = [k for k in sd_keys if not k.endswith('.attn.rotary_emb.inv_freq')] # discard this mask / buffer, not a param

In [6]:

# init a huggingface/transformers model
sd_hf = model_hf.state_dict()


In [7]:

# copy while ensuring all of the parameters are aligned and match in names and shapes
sd_keys_hf = sd_hf.keys()
sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attention.masked_bias')] # ignore these, just a buffer
sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attention.bias')] # same, just the mask (buffer)
sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.rotary_emb.inv_freq')] # same, just the mask (buffer)


In [8]:
def rename_key(k):
    k = k.replace('gpt_neox.layers', 'transformer.h')
    # Attention
    k = k.replace('.attention.dense.', '.attn.c_proj.')
    k = k.replace('.attention.query_key_value.', '.attn.c_attn.')
    # MLP
    k = k.replace('.mlp.dense_h_to_4h.', '.mlp.c_fc.')
    k = k.replace('.mlp.dense_4h_to_h.', '.mlp.c_proj.')
    # LayerNorm
    k = k.replace('.input_layernorm.', '.ln_1.')
    k = k.replace('.post_attention_layernorm.', '.ln_2.')
    # Embedding
    k = k.replace('gpt_neox.embed_in.', 'transformer.wte.')
    k = k.replace('gpt_neox.final_layer_norm.', 'transformer.ln_f.')
    k = k.replace('embed_out.', 'lm_head.')
    return k


In [9]:
import torch

assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}"
for k in sd_keys_hf:
    # vanilla copy over the other parameters
    new_k = rename_key(k)
    # print(new_k, k)
    assert sd_hf[k].shape == sd[new_k].shape
    with torch.no_grad():
        sd[new_k].copy_(sd_hf[k])


In [10]:
input_ids = torch.randint(0, 1000, (1, 512))
target_ids = torch.randint(0, 1000, (1, 512))

model_hf.train()
model.train()

x = model_hf(input_ids, labels=target_ids)
y = model(input_ids, target_ids)

In [11]:
x.logits

tensor([[[ 91.9129,  58.7314,  95.1513,  ...,  58.7315,  58.7294,  58.7327],
         [ 96.3681,  56.4360, 101.9045,  ...,  56.4365,  56.4342,  56.4373],
         [ 93.8861,  58.3406,  96.1099,  ...,  58.3406,  58.3387,  58.3415],
         ...,
         [ 95.4248,  59.3779,  97.0990,  ...,  59.3783,  59.3761,  59.3797],
         [ 94.8364,  57.9552,  94.4135,  ...,  57.9558,  57.9535,  57.9572],
         [ 93.7682,  58.6891,  95.3752,  ...,  58.6896,  58.6872,  58.6908]]],
       grad_fn=<UnsafeViewBackward0>)

In [12]:
y[0]

tensor([[[ 91.9129,  58.7314,  95.1513,  ...,  58.7315,  58.7294,  58.7327],
         [ 95.6873,  56.4097, 101.2286,  ...,  56.4099,  56.4079,  56.4108],
         [ 95.5324,  58.9605,  96.6613,  ...,  58.9607,  58.9585,  58.9614],
         ...,
         [ 89.3815,  55.3356,  87.5559,  ...,  55.3355,  55.3339,  55.3378],
         [ 85.9385,  54.4147,  82.0016,  ...,  54.4150,  54.4132,  54.4167],
         [ 94.9651,  58.7230,  97.6076,  ...,  58.7236,  58.7211,  58.7245]]],
       grad_fn=<UnsafeViewBackward0>)

In [13]:
torch.allclose(x.logits, y[0])

False

In [14]:
tok_emb = model.transformer.wte(input_ids) # token embeddings of shape (b, t, n_embd)
inputs_embeds = model_hf.gpt_neox.embed_in(input_ids)

torch.allclose(tok_emb, inputs_embeds)

True

In [15]:

hidden = tok_emb
for i in range(1):
    print(i)
    # h1 = model.transformer.h[i](hidden)
    # h2 = model_hf.gpt_neox.layers[i](hidden)[0]
    
    v1 = model.transformer.h[i].ln_1(hidden)
    v2 = model_hf.gpt_neox.layers[i].input_layernorm(hidden)
    
    h1 = model.transformer.h[i].attn(v1)
    h2 = model_hf.gpt_neox.layers[i].attention(v2)[0]
    
    # print(h1, h2)
    print(torch.allclose(h1, h2))
    # print(h1 - h2)

0
False
