<table style="width:100%">
<tr>
<td style="vertical-align:middle; text-align:left;">
<font size="2">
Supplementary code for the <a href="http://mng.bz/orYv">Build a Large Language Model From Scratch</a> book by <a href="https://sebastianraschka.com">Sebastian Raschka</a><br>
<br>Code repository: <a href="https://github.com/rasbt/LLMs-from-scratch">https://github.com/rasbt/LLMs-from-scratch</a>
</font>
</td>
<td style="vertical-align:middle; text-align:left;">
<a href="http://mng.bz/orYv"><img src="https://sebastianraschka.com/images/LLMs-from-scratch-images/cover-small.webp" width="100px"></a>
</td>
</tr>
</table>

# Bonus Code for Chapter 5

## Alternative Weight Loading from PyTorch state dicts

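- In the main chapter, the GPT model weights were loaded directly from OpenAI
- This notebook provides alternative weight loading code that loads the model weights from PyTorch state dict files uploaded to the [Hugging Face Model Hub](https://huggingface.co/docs/hub/en/models-the-hub). These state dict files were created from the original TensorFlow files and are available at [https://huggingface.co/rasbt/gpt2-from-scratch-pytorch](https://huggingface.co/rasbt/gpt2-from-scratch-pytorch)
- Conceptually, this is the same as loading the weights of a PyTorch model via the state dict method described in chapter 5: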

```python
state_dict = torch.load("model_state_dict.pth")
model.load_state_dict(state_dict) 
```
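
- The two lines above can be tried end to end without downloading anything. The following is a minimal sketch under stated assumptions: a tiny `nn.Linear` layer stands in for the GPT model, and the file is written to a temporary directory; neither is part of the original notebook.

```python
import os
import tempfile

import torch
import torch.nn as nn

torch.manual_seed(123)
source = nn.Linear(4, 2)  # stand-in for a trained model (assumption)
target = nn.Linear(4, 2)  # freshly initialized, so its weights differ

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "model_state_dict.pth")
    torch.save(source.state_dict(), path)             # write only the tensors
    state_dict = torch.load(path, weights_only=True)  # read them back

target.load_state_dict(state_dict)  # copy the saved tensors into the target

# After loading, both layers hold identical parameters
print(torch.equal(source.weight, target.weight))
print(torch.equal(source.bias, target.bias))
```

- The state dict is an ordered mapping from parameter names to tensors, which is why the target model must define parameters with the same names and shapes.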

### Choose model

In [1]:
from importlib.metadata import version

pkgs = ["torch"]
for p in pkgs:
    print(f"{p} version: {version(p)}")

torch version: 2.2.0+cu118


In [2]:
BASE_CONFIG = {
    "vocab_size": 50257,    # Vocabulary size
    "context_length": 1024, # Context length
    "drop_rate": 0.0,       # Dropout rate
    "qkv_bias": True        # Query-key-value bias
}

model_configs = {
    "gpt2-small (124M)": {"emb_dim": 768, "n_layers": 12, "n_heads": 12},
    "gpt2-medium (355M)": {"emb_dim": 1024, "n_layers": 24, "n_heads": 16},
    "gpt2-large (774M)": {"emb_dim": 1280, "n_layers": 36, "n_heads": 20},
    "gpt2-xl (1558M)": {"emb_dim": 1600, "n_layers": 48, "n_heads": 25},
}


CHOOSE_MODEL = "gpt2-small (124M)"
BASE_CONFIG.update(model_configs[CHOOSE_MODEL])
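
- To see what the merged configuration looks like for each model size, the sketch below builds a merged copy per entry instead of updating `BASE_CONFIG` in place. It also prints the per-head dimension, assuming the usual multi-head attention convention that `emb_dim` is split evenly across `n_heads`; the `head_dims` dictionary is a name introduced here for illustration.

```python
BASE_CONFIG = {
    "vocab_size": 50257,
    "context_length": 1024,
    "drop_rate": 0.0,
    "qkv_bias": True,
}

model_configs = {
    "gpt2-small (124M)": {"emb_dim": 768, "n_layers": 12, "n_heads": 12},
    "gpt2-medium (355M)": {"emb_dim": 1024, "n_layers": 24, "n_heads": 16},
    "gpt2-large (774M)": {"emb_dim": 1280, "n_layers": 36, "n_heads": 20},
    "gpt2-xl (1558M)": {"emb_dim": 1600, "n_layers": 48, "n_heads": 25},
}

head_dims = {}  # hypothetical helper dict, not part of the notebook
for name, size_cfg in model_configs.items():
    config = {**BASE_CONFIG, **size_cfg}  # merged copy; BASE_CONFIG is unchanged
    # The embedding dimension must divide evenly across the attention heads
    assert config["emb_dim"] % config["n_heads"] == 0
    head_dims[name] = config["emb_dim"] // config["n_heads"]
    print(f"{name}: emb_dim={config['emb_dim']}, per-head dim={head_dims[name]}")
```

- The chosen configuration has to match the downloaded checkpoint file, since the parameter shapes in the state dict depend on `emb_dim`, `n_layers`, and `n_heads`.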

### Download file

In [None]:
file_name = "gpt2-small-124M.pth"
# file_name = "gpt2-medium-355M.pth"
# file_name = "gpt2-large-774M.pth"
# file_name = "gpt2-xl-1558M.pth"

In [None]:
import os
import requests

url = f"https://huggingface.co/rasbt/gpt2-from-scratch-pytorch/resolve/main/{file_name}"

if not os.path.exists(file_name):
    response = requests.get(url, timeout=60)
    response.raise_for_status()
    with open(file_name, "wb") as f:
        f.write(response.content)
    print(f"Downloaded to {file_name}")
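
- The download cell above holds the whole response in memory before writing it to disk, which can be a lot for the larger checkpoints. As a hedged alternative, the sketch below streams the body to disk in chunks using only the standard library. The function name `download_streaming`, the `.part` suffix, and the default chunk size are choices made for this illustration, not part of the original notebook.

```python
import shutil
import urllib.request
from pathlib import Path


def download_streaming(url, dest, chunk_size=1024 * 1024):
    """Copy the response body to `dest` in chunks instead of buffering it all."""
    dest = Path(dest)
    if dest.exists():
        return dest  # skip the download if the file is already present

    # Write to a temporary name first so an interrupted download
    # is not mistaken for a complete file on the next run
    tmp = dest.with_name(dest.name + ".part")
    with urllib.request.urlopen(url, timeout=60) as response, open(tmp, "wb") as f:
        shutil.copyfileobj(response, f, length=chunk_size)
    tmp.replace(dest)
    return dest
```

- With the variables from the cells above, a call would look like `download_streaming(url, file_name)`.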

### Load weights

In [None]:
import torch
from llms_from_scratch.ch04 import GPTModel
# For llms_from_scratch installation instructions, see:
# https://github.com/rasbt/LLMs-from-scratch/tree/main/pkg


gpt = GPTModel(BASE_CONFIG)
gpt.load_state_dict(torch.load(file_name, weights_only=True))
gpt.eval()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
gpt.to(device);
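
- If the selected configuration and the downloaded file do not belong to the same model size, `load_state_dict` raises an error because the tensor shapes differ. The sketch below shows one way to inspect such mismatches before loading. `find_mismatches` is a hypothetical helper written for this illustration, and two small `nn.Linear` layers stand in for the model and the checkpoint.

```python
import torch.nn as nn


def find_mismatches(model, state_dict):
    """Return keys that are missing, unexpected, or differ in shape."""
    expected = model.state_dict()
    missing = sorted(set(expected) - set(state_dict))
    unexpected = sorted(set(state_dict) - set(expected))
    wrong_shape = [
        (key, tuple(expected[key].shape), tuple(state_dict[key].shape))
        for key in expected
        if key in state_dict and expected[key].shape != state_dict[key].shape
    ]
    return missing, unexpected, wrong_shape


# Tiny stand-ins (assumption): the "checkpoint" comes from a wider layer
model = nn.Linear(4, 2)
checkpoint = nn.Linear(4, 3).state_dict()

missing, unexpected, wrong_shape = find_mismatches(model, checkpoint)
for key, model_shape, file_shape in wrong_shape:
    print(f"{key}: model expects {model_shape}, file has {file_shape}")
```

- For the notebook's variables, the equivalent call would be `find_mismatches(gpt, torch.load(file_name, weights_only=True))`.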

### Generate text

In [None]:
import tiktoken
from llms_from_scratch.ch05 import generate, text_to_token_ids, token_ids_to_text


torch.manual_seed(123)

tokenizer = tiktoken.get_encoding("gpt2")

token_ids = generate(
    model=gpt.to(device),
    idx=text_to_token_ids("Every effort moves", tokenizer).to(device),
    max_new_tokens=30,
    context_size=BASE_CONFIG["context_length"],
    top_k=1,
    temperature=1.0
)

print("Output text:\n", token_ids_to_text(token_ids, tokenizer))
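
- The call above uses `top_k=1`, which makes the output deterministic even though a sampling step is involved. The sketch below illustrates why, assuming the common top-k scheme in which all logits below the k-th largest are masked with negative infinity before the softmax. `sample_top_k` is a hypothetical function for this illustration and is not the `generate` implementation from the book.

```python
import torch


def sample_top_k(logits, top_k, temperature=1.0):
    """Sample one token id per row after top-k filtering (illustrative only)."""
    top_vals, _ = torch.topk(logits, top_k)
    threshold = top_vals[:, -1:]  # k-th largest logit in each row
    filtered = torch.where(
        logits < threshold,
        torch.full_like(logits, float("-inf")),  # masked entries get probability 0
        logits,
    )
    probs = torch.softmax(filtered / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1)


torch.manual_seed(123)
logits = torch.randn(2, 10)  # stand-in for model outputs: (batch, vocab)

sampled = sample_top_k(logits, top_k=1)
greedy = logits.argmax(dim=-1, keepdim=True)
print(torch.equal(sampled, greedy))
```

- With `top_k=1`, only the largest logit survives the mask, so all probability mass falls on a single token and sampling reduces to greedy decoding regardless of the temperature.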

## Alternative safetensors file

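- In addition, the [https://huggingface.co/rasbt/gpt2-from-scratch-pytorch](https://huggingface.co/rasbt/gpt2-from-scratch-pytorch) repository contains `.safetensors` versions of the state dicts
- The appeal of `.safetensors` files lies in their secure design: they store only tensor data, so loading them avoids executing potentially malicious code
- In newer versions of PyTorch (e.g., 2.0 and newer), the `weights_only=True` argument can be passed to `torch.load` (e.g., `torch.load("model_state_dict.pth", weights_only=True)`) to improve safety by skipping code execution and loading only the weights (this is enabled by default in PyTorch 2.6 and newer); in that case, loading the weights from state dict files should no longer be a concern
- Nevertheless, the code cells below briefly show how to load the model from these `.safetensors` files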

In [None]:
file_name = "gpt2-small-124M.safetensors"
# file_name = "gpt2-medium-355M.safetensors"
# file_name = "gpt2-large-774M.safetensors"
# file_name = "gpt2-xl-1558M.safetensors"

In [None]:
import os
import urllib.request

url = f"https://huggingface.co/rasbt/gpt2-from-scratch-pytorch/resolve/main/{file_name}"

if not os.path.exists(file_name):
    urllib.request.urlretrieve(url, file_name)
    print(f"Downloaded to {file_name}")

In [None]:
# Load file

from safetensors.torch import load_file

gpt = GPTModel(BASE_CONFIG)
gpt.load_state_dict(load_file(file_name))
gpt.eval();

In [None]:
token_ids = generate(
    model=gpt.to(device),
    idx=text_to_token_ids("Every effort moves", tokenizer).to(device),
    max_new_tokens=30,
    context_size=BASE_CONFIG["context_length"],
    top_k=1,
    temperature=1.0
)

print("Output text:\n", token_ids_to_text(token_ids, tokenizer))