In [8]:
import torch
from transformers import GPT2LMHeadModel

# Load the pretrained GPT-2 model (in raw PyTorch)
model_name = "gpt2"  # Options: 'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'
model = GPT2LMHeadModel.from_pretrained(model_name)

# Save the model's state_dict as a local PyTorch checkpoint
torch.save(model.state_dict(), "gpt2_pytorch_model.pt")

print("Model weights saved as 'gpt2_pytorch_model.pt'.")

Model weights saved as 'gpt2_pytorch_model.pt'.


In [9]:
import torch
import torch.nn as nn

class GPT2(nn.Module):
    def __init__(self, vocab_size, n_embd, n_layer, n_head):
        super(GPT2, self).__init__()
        self.embed = nn.Embedding(vocab_size, n_embd)
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model=n_embd, nhead=n_head)
            for _ in range(n_layer)
        ])
        self.ln_f = nn.LayerNorm(n_embd)
        self.head = nn.Linear(n_embd, vocab_size, bias=False)

    def forward(self, x):
        x = self.embed(x)
        for layer in self.layers:
            x = layer(x)
        x = self.ln_f(x)
        return self.head(x)

# GPT-2 Small Config (as an example)
config = {
    "vocab_size": 50257,
    "n_embd": 768,
    "n_layer": 12,
    "n_head": 12,
}

# Initialize the model
gpt2_model = GPT2(
    vocab_size=config["vocab_size"],
    n_embd=config["n_embd"],
    n_layer=config["n_layer"],
    n_head=config["n_head"],
)

# Load the saved state_dict
state_dict = torch.load("gpt2_pytorch_model.pt")
gpt2_model.load_state_dict(state_dict)

print("Model loaded successfully!")


  state_dict = torch.load("gpt2_pytorch_model.pt")


RuntimeError: Error(s) in loading state_dict for GPT2:
	Missing key(s) in state_dict: "embed.weight", "layers.0.self_attn.in_proj_weight", "layers.0.self_attn.in_proj_bias", "layers.0.self_attn.out_proj.weight", "layers.0.self_attn.out_proj.bias", "layers.0.linear1.weight", "layers.0.linear1.bias", "layers.0.linear2.weight", "layers.0.linear2.bias", "layers.0.norm1.weight", "layers.0.norm1.bias", "layers.0.norm2.weight", "layers.0.norm2.bias", "layers.1.self_attn.in_proj_weight", "layers.1.self_attn.in_proj_bias", "layers.1.self_attn.out_proj.weight", "layers.1.self_attn.out_proj.bias", "layers.1.linear1.weight", "layers.1.linear1.bias", "layers.1.linear2.weight", "layers.1.linear2.bias", "layers.1.norm1.weight", "layers.1.norm1.bias", "layers.1.norm2.weight", "layers.1.norm2.bias", "layers.2.self_attn.in_proj_weight", "layers.2.self_attn.in_proj_bias", "layers.2.self_attn.out_proj.weight", "layers.2.self_attn.out_proj.bias", "layers.2.linear1.weight", "layers.2.linear1.bias", "layers.2.linear2.weight", "layers.2.linear2.bias", "layers.2.norm1.weight", "layers.2.norm1.bias", "layers.2.norm2.weight", "layers.2.norm2.bias", "layers.3.self_attn.in_proj_weight", "layers.3.self_attn.in_proj_bias", "layers.3.self_attn.out_proj.weight", "layers.3.self_attn.out_proj.bias", "layers.3.linear1.weight", "layers.3.linear1.bias", "layers.3.linear2.weight", "layers.3.linear2.bias", "layers.3.norm1.weight", "layers.3.norm1.bias", "layers.3.norm2.weight", "layers.3.norm2.bias", "layers.4.self_attn.in_proj_weight", "layers.4.self_attn.in_proj_bias", "layers.4.self_attn.out_proj.weight", "layers.4.self_attn.out_proj.bias", "layers.4.linear1.weight", "layers.4.linear1.bias", "layers.4.linear2.weight", "layers.4.linear2.bias", "layers.4.norm1.weight", "layers.4.norm1.bias", "layers.4.norm2.weight", "layers.4.norm2.bias", "layers.5.self_attn.in_proj_weight", "layers.5.self_attn.in_proj_bias", "layers.5.self_attn.out_proj.weight", "layers.5.self_attn.out_proj.bias", "layers.5.linear1.weight", "layers.5.linear1.bias", "layers.5.linear2.weight", "layers.5.linear2.bias", "layers.5.norm1.weight", "layers.5.norm1.bias", "layers.5.norm2.weight", "layers.5.norm2.bias", "layers.6.self_attn.in_proj_weight", "layers.6.self_attn.in_proj_bias", "layers.6.self_attn.out_proj.weight", "layers.6.self_attn.out_proj.bias", "layers.6.linear1.weight", "layers.6.linear1.bias", "layers.6.linear2.weight", "layers.6.linear2.bias", "layers.6.norm1.weight", "layers.6.norm1.bias", "layers.6.norm2.weight", "layers.6.norm2.bias", "layers.7.self_attn.in_proj_weight", "layers.7.self_attn.in_proj_bias", "layers.7.self_attn.out_proj.weight", "layers.7.self_attn.out_proj.bias", "layers.7.linear1.weight", "layers.7.linear1.bias", "layers.7.linear2.weight", "layers.7.linear2.bias", "layers.7.norm1.weight", "layers.7.norm1.bias", "layers.7.norm2.weight", "layers.7.norm2.bias", "layers.8.self_attn.in_proj_weight", "layers.8.self_attn.in_proj_bias", "layers.8.self_attn.out_proj.weight", "layers.8.self_attn.out_proj.bias", "layers.8.linear1.weight", "layers.8.linear1.bias", "layers.8.linear2.weight", "layers.8.linear2.bias", "layers.8.norm1.weight", "layers.8.norm1.bias", "layers.8.norm2.weight", "layers.8.norm2.bias", "layers.9.self_attn.in_proj_weight", "layers.9.self_attn.in_proj_bias", "layers.9.self_attn.out_proj.weight", "layers.9.self_attn.out_proj.bias", "layers.9.linear1.weight", "layers.9.linear1.bias", "layers.9.linear2.weight", "layers.9.linear2.bias", "layers.9.norm1.weight", "layers.9.norm1.bias", "layers.9.norm2.weight", "layers.9.norm2.bias", "layers.10.self_attn.in_proj_weight", "layers.10.self_attn.in_proj_bias", "layers.10.self_attn.out_proj.weight", "layers.10.self_attn.out_proj.bias", "layers.10.linear1.weight", "layers.10.linear1.bias", "layers.10.linear2.weight", "layers.10.linear2.bias", "layers.10.norm1.weight", "layers.10.norm1.bias", "layers.10.norm2.weight", "layers.10.norm2.bias", "layers.11.self_attn.in_proj_weight", "layers.11.self_attn.in_proj_bias", "layers.11.self_attn.out_proj.weight", "layers.11.self_attn.out_proj.bias", "layers.11.linear1.weight", "layers.11.linear1.bias", "layers.11.linear2.weight", "layers.11.linear2.bias", "layers.11.norm1.weight", "layers.11.norm1.bias", "layers.11.norm2.weight", "layers.11.norm2.bias", "ln_f.weight", "ln_f.bias", "head.weight". 
	Unexpected key(s) in state_dict: "transformer.wte.weight", "transformer.wpe.weight", "transformer.h.0.ln_1.weight", "transformer.h.0.ln_1.bias", "transformer.h.0.attn.c_attn.weight", "transformer.h.0.attn.c_attn.bias", "transformer.h.0.attn.c_proj.weight", "transformer.h.0.attn.c_proj.bias", "transformer.h.0.ln_2.weight", "transformer.h.0.ln_2.bias", "transformer.h.0.mlp.c_fc.weight", "transformer.h.0.mlp.c_fc.bias", "transformer.h.0.mlp.c_proj.weight", "transformer.h.0.mlp.c_proj.bias", "transformer.h.1.ln_1.weight", "transformer.h.1.ln_1.bias", "transformer.h.1.attn.c_attn.weight", "transformer.h.1.attn.c_attn.bias", "transformer.h.1.attn.c_proj.weight", "transformer.h.1.attn.c_proj.bias", "transformer.h.1.ln_2.weight", "transformer.h.1.ln_2.bias", "transformer.h.1.mlp.c_fc.weight", "transformer.h.1.mlp.c_fc.bias", "transformer.h.1.mlp.c_proj.weight", "transformer.h.1.mlp.c_proj.bias", "transformer.h.2.ln_1.weight", "transformer.h.2.ln_1.bias", "transformer.h.2.attn.c_attn.weight", "transformer.h.2.attn.c_attn.bias", "transformer.h.2.attn.c_proj.weight", "transformer.h.2.attn.c_proj.bias", "transformer.h.2.ln_2.weight", "transformer.h.2.ln_2.bias", "transformer.h.2.mlp.c_fc.weight", "transformer.h.2.mlp.c_fc.bias", "transformer.h.2.mlp.c_proj.weight", "transformer.h.2.mlp.c_proj.bias", "transformer.h.3.ln_1.weight", "transformer.h.3.ln_1.bias", "transformer.h.3.attn.c_attn.weight", "transformer.h.3.attn.c_attn.bias", "transformer.h.3.attn.c_proj.weight", "transformer.h.3.attn.c_proj.bias", "transformer.h.3.ln_2.weight", "transformer.h.3.ln_2.bias", "transformer.h.3.mlp.c_fc.weight", "transformer.h.3.mlp.c_fc.bias", "transformer.h.3.mlp.c_proj.weight", "transformer.h.3.mlp.c_proj.bias", "transformer.h.4.ln_1.weight", "transformer.h.4.ln_1.bias", "transformer.h.4.attn.c_attn.weight", "transformer.h.4.attn.c_attn.bias", "transformer.h.4.attn.c_proj.weight", "transformer.h.4.attn.c_proj.bias", "transformer.h.4.ln_2.weight", "transformer.h.4.ln_2.bias", "transformer.h.4.mlp.c_fc.weight", "transformer.h.4.mlp.c_fc.bias", "transformer.h.4.mlp.c_proj.weight", "transformer.h.4.mlp.c_proj.bias", "transformer.h.5.ln_1.weight", "transformer.h.5.ln_1.bias", "transformer.h.5.attn.c_attn.weight", "transformer.h.5.attn.c_attn.bias", "transformer.h.5.attn.c_proj.weight", "transformer.h.5.attn.c_proj.bias", "transformer.h.5.ln_2.weight", "transformer.h.5.ln_2.bias", "transformer.h.5.mlp.c_fc.weight", "transformer.h.5.mlp.c_fc.bias", "transformer.h.5.mlp.c_proj.weight", "transformer.h.5.mlp.c_proj.bias", "transformer.h.6.ln_1.weight", "transformer.h.6.ln_1.bias", "transformer.h.6.attn.c_attn.weight", "transformer.h.6.attn.c_attn.bias", "transformer.h.6.attn.c_proj.weight", "transformer.h.6.attn.c_proj.bias", "transformer.h.6.ln_2.weight", "transformer.h.6.ln_2.bias", "transformer.h.6.mlp.c_fc.weight", "transformer.h.6.mlp.c_fc.bias", "transformer.h.6.mlp.c_proj.weight", "transformer.h.6.mlp.c_proj.bias", "transformer.h.7.ln_1.weight", "transformer.h.7.ln_1.bias", "transformer.h.7.attn.c_attn.weight", "transformer.h.7.attn.c_attn.bias", "transformer.h.7.attn.c_proj.weight", "transformer.h.7.attn.c_proj.bias", "transformer.h.7.ln_2.weight", "transformer.h.7.ln_2.bias", "transformer.h.7.mlp.c_fc.weight", "transformer.h.7.mlp.c_fc.bias", "transformer.h.7.mlp.c_proj.weight", "transformer.h.7.mlp.c_proj.bias", "transformer.h.8.ln_1.weight", "transformer.h.8.ln_1.bias", "transformer.h.8.attn.c_attn.weight", "transformer.h.8.attn.c_attn.bias", "transformer.h.8.attn.c_proj.weight", "transformer.h.8.attn.c_proj.bias", "transformer.h.8.ln_2.weight", "transformer.h.8.ln_2.bias", "transformer.h.8.mlp.c_fc.weight", "transformer.h.8.mlp.c_fc.bias", "transformer.h.8.mlp.c_proj.weight", "transformer.h.8.mlp.c_proj.bias", "transformer.h.9.ln_1.weight", "transformer.h.9.ln_1.bias", "transformer.h.9.attn.c_attn.weight", "transformer.h.9.attn.c_attn.bias", "transformer.h.9.attn.c_proj.weight", "transformer.h.9.attn.c_proj.bias", "transformer.h.9.ln_2.weight", "transformer.h.9.ln_2.bias", "transformer.h.9.mlp.c_fc.weight", "transformer.h.9.mlp.c_fc.bias", "transformer.h.9.mlp.c_proj.weight", "transformer.h.9.mlp.c_proj.bias", "transformer.h.10.ln_1.weight", "transformer.h.10.ln_1.bias", "transformer.h.10.attn.c_attn.weight", "transformer.h.10.attn.c_attn.bias", "transformer.h.10.attn.c_proj.weight", "transformer.h.10.attn.c_proj.bias", "transformer.h.10.ln_2.weight", "transformer.h.10.ln_2.bias", "transformer.h.10.mlp.c_fc.weight", "transformer.h.10.mlp.c_fc.bias", "transformer.h.10.mlp.c_proj.weight", "transformer.h.10.mlp.c_proj.bias", "transformer.h.11.ln_1.weight", "transformer.h.11.ln_1.bias", "transformer.h.11.attn.c_attn.weight", "transformer.h.11.attn.c_attn.bias", "transformer.h.11.attn.c_proj.weight", "transformer.h.11.attn.c_proj.bias", "transformer.h.11.ln_2.weight", "transformer.h.11.ln_2.bias", "transformer.h.11.mlp.c_fc.weight", "transformer.h.11.mlp.c_fc.bias", "transformer.h.11.mlp.c_proj.weight", "transformer.h.11.mlp.c_proj.bias", "transformer.ln_f.weight", "transformer.ln_f.bias", "lm_head.weight". 

In [11]:
state_dict.keys()

odict_keys(['transformer.wte.weight', 'transformer.wpe.weight', 'transformer.h.0.ln_1.weight', 'transformer.h.0.ln_1.bias', 'transformer.h.0.attn.c_attn.weight', 'transformer.h.0.attn.c_attn.bias', 'transformer.h.0.attn.c_proj.weight', 'transformer.h.0.attn.c_proj.bias', 'transformer.h.0.ln_2.weight', 'transformer.h.0.ln_2.bias', 'transformer.h.0.mlp.c_fc.weight', 'transformer.h.0.mlp.c_fc.bias', 'transformer.h.0.mlp.c_proj.weight', 'transformer.h.0.mlp.c_proj.bias', 'transformer.h.1.ln_1.weight', 'transformer.h.1.ln_1.bias', 'transformer.h.1.attn.c_attn.weight', 'transformer.h.1.attn.c_attn.bias', 'transformer.h.1.attn.c_proj.weight', 'transformer.h.1.attn.c_proj.bias', 'transformer.h.1.ln_2.weight', 'transformer.h.1.ln_2.bias', 'transformer.h.1.mlp.c_fc.weight', 'transformer.h.1.mlp.c_fc.bias', 'transformer.h.1.mlp.c_proj.weight', 'transformer.h.1.mlp.c_proj.bias', 'transformer.h.2.ln_1.weight', 'transformer.h.2.ln_1.bias', 'transformer.h.2.attn.c_attn.weight', 'transformer.h.2.attn.

In [13]:
gpt2_model.state_dict().keys()

odict_keys(['embed.weight', 'layers.0.self_attn.in_proj_weight', 'layers.0.self_attn.in_proj_bias', 'layers.0.self_attn.out_proj.weight', 'layers.0.self_attn.out_proj.bias', 'layers.0.linear1.weight', 'layers.0.linear1.bias', 'layers.0.linear2.weight', 'layers.0.linear2.bias', 'layers.0.norm1.weight', 'layers.0.norm1.bias', 'layers.0.norm2.weight', 'layers.0.norm2.bias', 'layers.1.self_attn.in_proj_weight', 'layers.1.self_attn.in_proj_bias', 'layers.1.self_attn.out_proj.weight', 'layers.1.self_attn.out_proj.bias', 'layers.1.linear1.weight', 'layers.1.linear1.bias', 'layers.1.linear2.weight', 'layers.1.linear2.bias', 'layers.1.norm1.weight', 'layers.1.norm1.bias', 'layers.1.norm2.weight', 'layers.1.norm2.bias', 'layers.2.self_attn.in_proj_weight', 'layers.2.self_attn.in_proj_bias', 'layers.2.self_attn.out_proj.weight', 'layers.2.self_attn.out_proj.bias', 'layers.2.linear1.weight', 'layers.2.linear1.bias', 'layers.2.linear2.weight', 'layers.2.linear2.bias', 'layers.2.norm1.weight', 'laye

In [16]:
from transformers import GPT2LMHeadModel

model = GPT2LMHeadModel.from_pretrained('gpt2')
torch.load('gpt2-pytorch/model.ckpt')


  torch.load('gpt2-pytorch/model.ckpt')


UnpicklingError: invalid load key, '\xef'.