<table style="width:100%">
<tr>
<td style="vertical-align:middle; text-align:left;">
<font size="2">
Supplementary code for the <a href="http://mng.bz/orYv">Build a Large Language Model From Scratch</a> book by <a href="https://sebastianraschka.com">Sebastian Raschka</a><br>
<br>Code repository: <a href="https://github.com/rasbt/LLMs-from-scratch">https://github.com/rasbt/LLMs-from-scratch</a>
</font>
</td>
<td style="vertical-align:middle; text-align:left;">
<a href="http://mng.bz/orYv"><img src="https://sebastianraschka.com/images/LLMs-from-scratch-images/cover-small.webp" width="100px"></a>
</td>
</tr>
</table>

# Load And Use Finetuned Model
# 加载和使用微调模型

This notebook contains the minimal code needed to load and use the finetuned spam classifier that was created and saved in chapter 6 via [ch06.ipynb](ch06.ipynb).
本笔记本包含加载和使用第6章中通过[ch06.ipynb](ch06.ipynb)创建并保存的微调垃圾短信分类模型所需的最少代码。

In [1]:
from importlib.metadata import version

# Report the installed versions of the packages this notebook relies on
pkgs = ["tiktoken", "torch"]  # tokenizer, deep learning library

for pkg in pkgs:
    pkg_version = version(pkg)
    print(f"{pkg} version: {pkg_version}")

tiktoken version: 0.7.0
torch version: 2.4.0


In [2]:
# 导入Path类用于处理文件路径
from pathlib import Path

# 定义微调模型的保存路径
# Location of the finetuned classifier weights written by ch06.ipynb
finetuned_model_path = Path("review_classifier.pth")

# Only print a hint when the file is absent; the notebook is not stopped here
if not finetuned_model_path.exists():
    missing_msg = (
        f"Could not find '{finetuned_model_path}'.\n"
        "Run the `ch06.ipynb` notebook to finetune and save the finetuned model."
    )
    print(missing_msg)

In [3]:
# 从前面章节导入GPT模型
from previous_chapters import GPTModel


# 基础配置字典
# Settings shared by every GPT-2 size
BASE_CONFIG = {
    "vocab_size": 50257,     # vocabulary size
    "context_length": 1024,  # maximum context length
    "drop_rate": 0.0,        # dropout rate
    "qkv_bias": True,        # query-key-value bias terms
}

# Size-specific architecture settings for the GPT-2 family
model_configs = {
    "gpt2-small (124M)":  dict(emb_dim=768,  n_layers=12, n_heads=12),   # 124M parameters
    "gpt2-medium (355M)": dict(emb_dim=1024, n_layers=24, n_heads=16),   # 355M parameters
    "gpt2-large (774M)":  dict(emb_dim=1280, n_layers=36, n_heads=20),   # 774M parameters
    "gpt2-xl (1558M)":    dict(emb_dim=1600, n_layers=48, n_heads=25),   # 1558M parameters
}

# Size whose architecture settings are merged into the shared base settings
CHOOSE_MODEL = "gpt2-small (124M)"
BASE_CONFIG = {**BASE_CONFIG, **model_configs[CHOOSE_MODEL]}

# Build the model from the merged configuration; weights are loaded in a later cell
model = GPTModel(BASE_CONFIG)

In [4]:
# 导入PyTorch库
import torch

# 将模型转换为分类器(参考ch06.ipynb第6.5节)
# Convert the model into a classifier (see section 6.5 in ch06.ipynb)
num_classes = 2  # two labels: "not spam" (0) and "spam" (1)
# Replace the output head: emb_dim input features, one logit per class
model.out_head = torch.nn.Linear(in_features=BASE_CONFIG["emb_dim"], out_features=num_classes)

# Load the finetuned classifier weights saved by ch06.ipynb
# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Reuse finetuned_model_path so the existence check and the actual load
# always refer to the same file (previously the filename was hard-coded twice)
model.load_state_dict(torch.load(finetuned_model_path, map_location=device, weights_only=True))
# Move the model to the selected device
model.to(device)
# Switch to evaluation mode; the trailing ';' suppresses the module repr in the notebook
model.eval();

In [5]:
# 导入tiktoken库用于分词
import tiktoken

# 获取GPT-2的分词器
# Name of the tiktoken encoding for GPT-2, and the tokenizer built from it
TOKENIZER_ENCODING = "gpt2"
tokenizer = tiktoken.get_encoding(TOKENIZER_ENCODING)

In [6]:
# 此函数在ch06.ipynb中实现
def classify_review(text, model, tokenizer, device, max_length=None, pad_token_id=50256):
    """Classify ``text`` as "spam" or "not spam" with a finetuned GPT classifier.

    The text is tokenized, truncated to an effective target length, right-padded
    with ``pad_token_id`` up to that same length, and fed to ``model``. The
    logits at index -1 of the second axis determine the label.

    Args:
        text: Input string to classify.
        model: Classifier module. Its ``pos_emb.weight.shape[0]`` is used as the
            supported context length; it is assumed to return a 3-D logits
            tensor of shape (batch, position, class).
        tokenizer: Object with an ``encode(str) -> list[int]`` method.
        device: Torch device on which the input tensor is created.
        max_length: Target sequence length. ``None`` (the default) means "use the
            model's full supported context length". Values larger than the
            supported context length are clamped to it.
        pad_token_id: Token id used for right-padding; defaults to 50256, the
            last id of the 50257-token GPT-2 vocabulary.

    Returns:
        ``"spam"`` if the predicted class index is 1, otherwise ``"not spam"``.

    Raises:
        ValueError: if the effective target length is not positive.
    """
    model.eval()

    # Prepare the model input
    input_ids = tokenizer.encode(text)
    supported_context_length = model.pos_emb.weight.shape[0]

    # One effective length drives both truncation and padding. Previously
    # max_length=None crashed inside min(), and max_length > context length
    # padded past what the positional embedding can index.
    if max_length is None:
        target_length = supported_context_length
    else:
        target_length = min(max_length, supported_context_length)
    if target_length <= 0:
        raise ValueError(f"max_length must be positive, got {max_length!r}")

    # Truncate overly long sequences, then right-pad to the target length
    input_ids = input_ids[:target_length]
    input_ids = input_ids + [pad_token_id] * (target_length - len(input_ids))

    # Already created on `device`, so no extra .to(device) is needed
    input_tensor = torch.tensor(input_ids, device=device).unsqueeze(0)  # add batch dimension

    # Inference only: no gradient tracking
    with torch.no_grad():
        logits = model(input_tensor)[:, -1, :]  # logits of the last position
    predicted_label = torch.argmax(logits, dim=-1).item()

    # Class 1 is "spam", class 0 is "not spam"
    return "spam" if predicted_label == 1 else "not spam"

In [7]:
# 定义一个垃圾短信样本文本
# Sample message expected to be classified as spam
text_1 = (
    "You are a winner you have been specially"
    " selected to receive $1000 cash or a $2000 award."
)

# Classify with sequences truncated/padded to 120 tokens and show the label
label_1 = classify_review(text_1, model, tokenizer, device, max_length=120)
print(label_1)

spam


In [8]:
# 定义一个正常短信样本文本
# Sample message expected to be classified as not spam
text_2 = (
    "Hey, just wanted to check if we're still on"
    " for dinner tonight? Let me know!"
)

# Classify with sequences truncated/padded to 120 tokens and show the label
label_2 = classify_review(text_2, model, tokenizer, device, max_length=120)
print(label_2)

not spam
