In [1]:
from modeling_gpt_neox import GPTNeoXForCausalLM
# Load the Hugging Face GPT-NeoX checkpoint that serves as the source of the weights.
model_hf = GPTNeoXForCausalLM.from_pretrained("trl-internal-testing/tiny-random-GPTNeoXForCausalLM")

In [2]:
# Keep a handle on the Hugging Face config; the bare expression on the last
# line makes the notebook display it as the cell output.
gpt_neox_config = model_hf.config
gpt_neox_config

GPTNeoXConfig {
  "_name_or_path": "trl-internal-testing/tiny-random-GPTNeoXForCausalLM",
  "architectures": [
    "GPTNeoXForCausalLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 0,
  "eos_token_id": 0,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 32,
  "initializer_range": 0.02,
  "intermediate_size": 37,
  "is_decoder": true,
  "layer_norm_eps": 1e-05,
  "max_position_embeddings": 512,
  "model_type": "gpt_neox",
  "num_attention_heads": 4,
  "num_hidden_layers": 5,
  "rotary_emb_base": 10000,
  "rotary_pct": 0.25,
  "tie_word_embeddings": false,
  "torch_dtype": "float32",
  "transformers_version": "4.25.1",
  "type_vocab_size": 16,
  "use_cache": true,
  "use_parallel_residual": true,
  "vocab_size": 1024
}

In [3]:
from model import GPTConfig, GPT

def convert_config(gpt_neox_config):
    """Translate a Hugging Face GPT-NeoX config into a local ``GPTConfig``.

    Args:
        gpt_neox_config: object exposing the GPT-NeoX config attributes read
            below (``hidden_size``, ``num_attention_heads``,
            ``num_hidden_layers``, ``vocab_size``, ``max_position_embeddings``,
            ``intermediate_size``, ``rotary_pct``, ``use_parallel_residual``).

    Returns:
        The ``GPTConfig`` built from those attributes; ``bias`` is fixed to
        ``True`` and ``dropout`` to ``0.0``.

    Raises:
        AssertionError: if ``use_parallel_residual`` is falsy.
    """
    hf = gpt_neox_config

    # Only the parallel-residual variant is accepted by this conversion.
    assert hf.use_parallel_residual

    kwargs = dict(
        # Fields taken one-to-one from the Hugging Face config.
        n_embd=hf.hidden_size,
        n_head=hf.num_attention_heads,
        n_layer=hf.num_hidden_layers,
        vocab_size=hf.vocab_size,
        block_size=hf.max_position_embeddings,
        # Fixed values, independent of the source config.
        bias=True,
        dropout=0.0,
        # GPT-NeoX specific settings forwarded to the local config.
        n_embd_proj=hf.intermediate_size,
        rotary_pct=hf.rotary_pct,
        use_parallel_residual=hf.use_parallel_residual,
    )
    return GPTConfig(**kwargs)
    
# Convert the Hugging Face config into the local GPTConfig.
# NOTE(review): `gpt_config` is not used by the visible cells; the next cell
# repeats this conversion under the name `config`.
gpt_config = convert_config(gpt_neox_config)

In [4]:
# Build the target model from the converted config and take its state dict,
# which is the destination the Hugging Face weights are copied into later.
config = convert_config(gpt_neox_config)
model = GPT(config)
sd = model.state_dict()


number of parameters: 0.10M


In [5]:
# Names of the target model's entries, minus the '.attn.bias' entries
# (those are the mask / buffer, not a param).
sd_keys = [name for name in sd.keys() if not name.endswith('.attn.bias')]
sd_keys

['transformer.wte.weight',
 'transformer.h.0.ln_1.weight',
 'transformer.h.0.ln_1.bias',
 'transformer.h.0.attn.c_attn.weight',
 'transformer.h.0.attn.c_attn.bias',
 'transformer.h.0.attn.c_proj.weight',
 'transformer.h.0.attn.c_proj.bias',
 'transformer.h.0.ln_2.weight',
 'transformer.h.0.ln_2.bias',
 'transformer.h.0.mlp.c_fc.weight',
 'transformer.h.0.mlp.c_fc.bias',
 'transformer.h.0.mlp.c_proj.weight',
 'transformer.h.0.mlp.c_proj.bias',
 'transformer.h.1.ln_1.weight',
 'transformer.h.1.ln_1.bias',
 'transformer.h.1.attn.c_attn.weight',
 'transformer.h.1.attn.c_attn.bias',
 'transformer.h.1.attn.c_proj.weight',
 'transformer.h.1.attn.c_proj.bias',
 'transformer.h.1.ln_2.weight',
 'transformer.h.1.ln_2.bias',
 'transformer.h.1.mlp.c_fc.weight',
 'transformer.h.1.mlp.c_fc.bias',
 'transformer.h.1.mlp.c_proj.weight',
 'transformer.h.1.mlp.c_proj.bias',
 'transformer.h.2.ln_1.weight',
 'transformer.h.2.ln_1.bias',
 'transformer.h.2.attn.c_attn.weight',
 'transformer.h.2.attn.c_attn.bi

In [6]:

# grab the state dict of the already-loaded huggingface/transformers model
sd_hf = model_hf.state_dict()


In [16]:

# Collect the Hugging Face state-dict keys that take part in the copy,
# skipping entries that are buffers rather than parameters.
sd_keys_hf = [
    name
    for name in sd_hf.keys()
    if not name.endswith((
        '.attention.masked_bias',  # ignore these, just a buffer
        '.attention.bias',         # the attention mask (buffer)
        '.rotary_emb.inv_freq',    # rotary-embedding inv_freq buffer (not a mask)
    ))
]


In [18]:
# Display the filtered Hugging Face key names.
sd_keys_hf

['gpt_neox.embed_in.weight',
 'gpt_neox.layers.0.input_layernorm.weight',
 'gpt_neox.layers.0.input_layernorm.bias',
 'gpt_neox.layers.0.post_attention_layernorm.weight',
 'gpt_neox.layers.0.post_attention_layernorm.bias',
 'gpt_neox.layers.0.attention.query_key_value.weight',
 'gpt_neox.layers.0.attention.query_key_value.bias',
 'gpt_neox.layers.0.attention.dense.weight',
 'gpt_neox.layers.0.attention.dense.bias',
 'gpt_neox.layers.0.mlp.dense_h_to_4h.weight',
 'gpt_neox.layers.0.mlp.dense_h_to_4h.bias',
 'gpt_neox.layers.0.mlp.dense_4h_to_h.weight',
 'gpt_neox.layers.0.mlp.dense_4h_to_h.bias',
 'gpt_neox.layers.1.input_layernorm.weight',
 'gpt_neox.layers.1.input_layernorm.bias',
 'gpt_neox.layers.1.post_attention_layernorm.weight',
 'gpt_neox.layers.1.post_attention_layernorm.bias',
 'gpt_neox.layers.1.attention.query_key_value.weight',
 'gpt_neox.layers.1.attention.query_key_value.bias',
 'gpt_neox.layers.1.attention.dense.weight',
 'gpt_neox.layers.1.attention.dense.bias',
 'gpt_n

In [28]:
def rename_key(k):
    """Map a Hugging Face GPT-NeoX state-dict key to the local model's naming.

    Each (old, new) pair below is applied in order with ``str.replace``;
    fragments that do not occur in ``k`` leave it unchanged.

    Args:
        k: state-dict key as named by the Hugging Face model.

    Returns:
        The key with every matching fragment substituted.
    """
    substitutions = (
        ('gpt_neox.layers', 'transformer.h'),
        # Attention
        ('.attention.dense.', '.attn.c_proj.'),
        ('.attention.query_key_value.', '.attn.c_attn.'),
        # MLP
        ('.mlp.dense_h_to_4h.', '.mlp.c_fc.'),
        ('.mlp.dense_4h_to_h.', '.mlp.c_proj.'),
        # LayerNorm
        ('.input_layernorm.', '.ln_1.'),
        ('.post_attention_layernorm.', '.ln_2.'),
        # Embedding
        ('gpt_neox.embed_in.', 'transformer.wte.'),
        ('gpt_neox.final_layer_norm.', 'transformer.ln_f.'),
        ('embed_out.', 'lm_head.'),
    )
    renamed = k
    for old, new in substitutions:
        renamed = renamed.replace(old, new)
    return renamed


In [29]:
import torch

# Both key lists have had their buffer entries filtered out, so they are
# expected to pair up one-to-one.
assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}"

# Copy every Hugging Face tensor into the matching target tensor.
# This is pure data movement, so a single no_grad context wraps the whole loop.
with torch.no_grad():
    for k in sd_keys_hf:
        new_k = rename_key(k)
        print(new_k, k)
        src, dst = sd_hf[k], sd[new_k]
        # Name the offending pair and both shapes so a failure is diagnosable;
        # a bare assert only produces an empty "AssertionError:".
        assert src.shape == dst.shape, (
            f"shape mismatch for {k} -> {new_k}: "
            f"{tuple(src.shape)} != {tuple(dst.shape)}"
        )
        dst.copy_(src)



transformer.wte.weight gpt_neox.embed_in.weight
transformer.h.0.ln_1.weight gpt_neox.layers.0.input_layernorm.weight
transformer.h.0.ln_1.bias gpt_neox.layers.0.input_layernorm.bias
transformer.h.0.ln_2.weight gpt_neox.layers.0.post_attention_layernorm.weight
transformer.h.0.ln_2.bias gpt_neox.layers.0.post_attention_layernorm.bias
transformer.h.0.attn.c_attn.weight gpt_neox.layers.0.attention.query_key_value.weight
transformer.h.0.attn.c_attn.bias gpt_neox.layers.0.attention.query_key_value.bias
transformer.h.0.attn.c_proj.weight gpt_neox.layers.0.attention.dense.weight
transformer.h.0.attn.c_proj.bias gpt_neox.layers.0.attention.dense.bias
transformer.h.0.mlp.c_fc.weight gpt_neox.layers.0.mlp.dense_h_to_4h.weight


AssertionError: 