In [1]:
import torch

from transformers import AutoModelForCausalLM

In [2]:
# Load the checkpoint through the causal-LM auto class and display the
# resulting module tree as the cell output.
checkpoint_id = "Dahoas/gpt2-rm-static"
model = AutoModelForCausalLM.from_pretrained(checkpoint_id)
model

Some weights of the model checkpoint at Dahoas/gpt2-rm-static were not used when initializing GPT2LMHeadModel: ['score.weight']
- This IS expected if you are initializing GPT2LMHeadModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing GPT2LMHeadModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

In [9]:
# Report two totals, both in millions of elements: first only the parameters
# that require gradients, then every parameter of the model.
n_params = sum(p.numel() if p.requires_grad else 0 for p in model.parameters())
print(f"{n_params/1e6:.1f} M trainable params")
n_params = sum(map(lambda p: p.numel(), model.parameters()))
print(f"{n_params/1e6:.1f} M params")

124.4 M trainable params
124.4 M params


## Check how keys match

In [10]:
# Names of the model's parameters, in the order `named_parameters()` yields them.
tmp_1 = [entry[0] for entry in model.named_parameters()]
# Show how many names there are and the first five.
(len(tmp_1), tmp_1[:5])

(148,
 ['transformer.wte.weight',
  'transformer.wpe.weight',
  'transformer.h.0.ln_1.weight',
  'transformer.h.0.ln_1.bias',
  'transformer.h.0.attn.c_attn.weight'])

In [14]:
# Location of the checkpoint file inside the local Hugging Face cache,
# relative to the notebook's working directory.
weights_path = (
    "../../../.hf_cache/hub/models--Dahoas--gpt2-rm-static/"
    "snapshots/07d6ffc61e51b28864878a2a97c2788c846e63d8/pytorch_model.bin"
)

# `weights_only=True` restricts unpickling to tensors and plain containers, so
# a downloaded checkpoint cannot execute arbitrary code while being loaded.
# `map_location="cpu"` makes the load independent of the device the tensors
# were saved from; only keys and shapes are inspected afterwards.
state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)
# Show the container type and the first five keys.
type(state_dict), list(state_dict.keys())[:5]

(dict,
 ['transformer.wte.weight',
  'transformer.wpe.weight',
  'transformer.h.0.ln_1.weight',
  'transformer.h.0.ln_1.bias',
  'transformer.h.0.attn.bias'])

In [15]:
# Every key stored in the checkpoint, in the state dict's iteration order.
tmp_2 = [*state_dict]
# Show how many keys there are and the first five.
(len(tmp_2), tmp_2[:5])

(173,
 ['transformer.wte.weight',
  'transformer.wpe.weight',
  'transformer.h.0.ln_1.weight',
  'transformer.h.0.ln_1.bias',
  'transformer.h.0.attn.bias'])

In [16]:
# Model parameter names that are absent from the loaded state_dict
# (an empty list means every model key is present in the checkpoint).
list(set(tmp_1).difference(set(tmp_2)))

[]

In [17]:
# Checkpoint keys that have no matching model parameter, sorted by name.
tmp_3 = list(set(tmp_2) - set(tmp_1))
tmp_3.sort()
tmp_3

['score.weight',
 'transformer.h.0.attn.bias',
 'transformer.h.0.attn.masked_bias',
 'transformer.h.1.attn.bias',
 'transformer.h.1.attn.masked_bias',
 'transformer.h.10.attn.bias',
 'transformer.h.10.attn.masked_bias',
 'transformer.h.11.attn.bias',
 'transformer.h.11.attn.masked_bias',
 'transformer.h.2.attn.bias',
 'transformer.h.2.attn.masked_bias',
 'transformer.h.3.attn.bias',
 'transformer.h.3.attn.masked_bias',
 'transformer.h.4.attn.bias',
 'transformer.h.4.attn.masked_bias',
 'transformer.h.5.attn.bias',
 'transformer.h.5.attn.masked_bias',
 'transformer.h.6.attn.bias',
 'transformer.h.6.attn.masked_bias',
 'transformer.h.7.attn.bias',
 'transformer.h.7.attn.masked_bias',
 'transformer.h.8.attn.bias',
 'transformer.h.8.attn.masked_bias',
 'transformer.h.9.attn.bias',
 'transformer.h.9.attn.masked_bias']

In [20]:
# Walk the checkpoint in order; for each tensor whose key is listed in `tmp_3`
# print its name and shape and add its element count to the running total.
count = 0

for name, param in state_dict.items():
    if name not in tmp_3:
        continue  # key also exists as a model parameter
    print(name, param.shape)
    count += param.numel()

count / 1e6  # total elements (in millions) of the tensors that are not in the model

transformer.h.0.attn.bias torch.Size([1, 1, 1024, 1024])
transformer.h.0.attn.masked_bias torch.Size([])
transformer.h.1.attn.bias torch.Size([1, 1, 1024, 1024])
transformer.h.1.attn.masked_bias torch.Size([])
transformer.h.2.attn.bias torch.Size([1, 1, 1024, 1024])
transformer.h.2.attn.masked_bias torch.Size([])
transformer.h.3.attn.bias torch.Size([1, 1, 1024, 1024])
transformer.h.3.attn.masked_bias torch.Size([])
transformer.h.4.attn.bias torch.Size([1, 1, 1024, 1024])
transformer.h.4.attn.masked_bias torch.Size([])
transformer.h.5.attn.bias torch.Size([1, 1, 1024, 1024])
transformer.h.5.attn.masked_bias torch.Size([])
transformer.h.6.attn.bias torch.Size([1, 1, 1024, 1024])
transformer.h.6.attn.masked_bias torch.Size([])
transformer.h.7.attn.bias torch.Size([1, 1, 1024, 1024])
transformer.h.7.attn.masked_bias torch.Size([])
transformer.h.8.attn.bias torch.Size([1, 1, 1024, 1024])
transformer.h.8.attn.masked_bias torch.Size([])
transformer.h.9.attn.bias torch.Size([1, 1, 1024, 1024]

12.583692

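I think the `attn.bias` and `attn.masked_bias` entries are just buffers (the `[1, 1, 1024, 1024]` tensors and the scalars printed above), not learned weights, so they can safely be ignored. `score.weight` is different: it is a weight stored in the checkpoint that `GPT2LMHeadModel` has no module for, so it is dropped when loading with `AutoModelForCausalLM` (see the warning printed at load time).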