## 1. 导入相关库

In [3]:
import datasets
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel
from transformers import AutoModelForCausalLM
from transformers import TrainingArguments, Seq2SeqTrainingArguments
from transformers import Trainer, Seq2SeqTrainer
import transformers
from transformers import DataCollatorWithPadding
from transformers import TextGenerationPipeline
import torch
import numpy as np
import os
import re
from tqdm import tqdm
import torch.nn as nn

## 2. 加载数据集

In [5]:
# Name of the dataset to load.
DATASET_NAME = "rotten_tomatoes"

# Load every split of the dataset, caching the files under the given directory.
raw_datasets = load_dataset(
    DATASET_NAME,
    cache_dir="/Volumes/WD_BLACK/data/rotten_tomatoes",
)

# Pick out the training and validation splits used below.
raw_train_dataset, raw_valid_dataset = raw_datasets["train"], raw_datasets["validation"]

Downloading readme: 100%|██████████| 7.46k/7.46k [00:00<00:00, 6.56MB/s]
Downloading data: 100%|██████████| 699k/699k [00:10<00:00, 67.1kB/s]
Downloading data: 100%|██████████| 90.0k/90.0k [00:03<00:00, 28.9kB/s]
Downloading data: 100%|██████████| 92.2k/92.2k [00:03<00:00, 29.5kB/s]
Generating train split: 100%|██████████| 8530/8530 [00:00<00:00, 382290.42 examples/s]
Generating validation split: 100%|██████████| 1066/1066 [00:00<00:00, 398035.08 examples/s]
Generating test split: 100%|██████████| 1066/1066 [00:00<00:00, 488061.14 examples/s]


In [6]:
# Show the first training example to inspect its fields.
print(raw_train_dataset[0])

{'text': 'the rock is destined to be the 21st century\'s new " conan " and that he\'s going to make a splash even greater than arnold schwarzenegger , jean-claud van damme or steven segal .', 'label': 1}


In [13]:
# Display the dataset summary (feature names and number of rows).
raw_train_dataset

Dataset({
    features: ['text', 'label'],
    num_rows: 8530
})

In [14]:
# List the raw column names; these are the columns dropped after tokenization.
raw_train_dataset.column_names

['text', 'label']

## 3. 加载模型

In [7]:
# Model to fine-tune: a local copy of GPT-2 (the hub name "gpt2" would also work
# in place of the local path).
MODEL_NAME_OR_PATH = "/Volumes/WD_BLACK/models/gpt2"

# Load the causal language model.
# trust_remote_code=True tells transformers to trust code shipped with the checkpoint.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME_OR_PATH,
    trust_remote_code=True,
)

## 4. 加载Tokenizer

In [8]:
# Load the tokenizer that matches the model checkpoint.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH, trust_remote_code=True)

# GPT-2 has no pad token, so one must be chosen explicitly.
# Reuse the EOS token rather than registering a new "[PAD]" token: a newly added
# token gets an id outside the model's embedding matrix (the model is never
# resized here), and forcing pad_token_id = 0 on top of that silently aliases
# padding to an ordinary vocabulary token.
# Padded positions are excluded via attention_mask and -100 labels during
# preprocessing, so the choice of pad id does not affect the training loss.
tokenizer.pad_token = tokenizer.eos_token

In [10]:
# Shared settings used by the rest of the notebook.

# Fix the random seed so that random sequences are reproducible across runs.
transformers.set_seed(42)

# Human-readable class names; the list index is the dataset's integer label.
named_labels = ['neg', 'pos']

# Map every class name to a single token id.
# add_special_tokens=False keeps the tokenizer from adding special tokens, and
# [0] keeps only the first token id produced for each name.
label_ids = [
    tokenizer(label_name, add_special_tokens=False)["input_ids"][0]
    for label_name in named_labels
]

## 5. 处理数据集
转成模型接受的输入格式
   - 拼接输入输出：\<INPUT TOKEN IDS\>\<EOS_TOKEN_ID\>\<OUTPUT TOKEN IDS\>
   - PAD 成相等长度：
     - \<INPUT 1.1\>\<INPUT 1.2\>...\<EOS_TOKEN_ID\>\<OUTPUT TOKEN IDS\>\<PAD\>...\<PAD\>
     - \<INPUT 2.1\>\<INPUT 2.2\>...\<EOS_TOKEN_ID\>\<OUTPUT TOKEN IDS\>\<PAD\>...\<PAD\>
   - 标识出参与 Loss 计算的 Tokens (只有输出 Token 参与 Loss 计算)
     - \<-100\>\<-100\>...\<OUTPUT TOKEN IDS\>\<-100\>...\<-100\>
     - 除了输出其他都标记为-100，是Huggingface预留的标记


In [11]:
MAX_LEN = 32                # maximum sequence length (input + output tokens)
DATA_BODY_KEY = "text"      # dataset column holding the input text
DATA_LABEL_KEY = "label"    # dataset column holding the integer class label


def process_fn(examples):
    """Convert a batch of raw examples into fixed-length model inputs.

    Each example becomes <text token ids><EOS><label token id>, then is either
    left-truncated (keeping the last MAX_LEN tokens) or right-padded to exactly
    MAX_LEN tokens.

    Args:
        examples: a batch mapping column names to lists; reads the
            DATA_BODY_KEY and DATA_LABEL_KEY columns.

    Returns:
        A dict with "input_ids", "attention_mask" and "labels", each a list of
        MAX_LEN-long lists. In "labels" every position is -100 (ignored by the
        loss) except the position of the label token.
    """
    all_input_ids, all_attention_masks, all_labels = [], [], []

    for idx, text in enumerate(examples[DATA_BODY_KEY]):
        text_ids = tokenizer(text, add_special_tokens=False)["input_ids"]
        label_token = label_ids[examples[DATA_LABEL_KEY][idx]]

        sequence = text_ids + [tokenizer.eos_token_id, label_token]
        pad_count = MAX_LEN - len(sequence)

        if pad_count <= 0:
            # Too long (or exactly full): keep the tail so that EOS and the
            # label token always survive truncation.
            sequence = sequence[-MAX_LEN:]
            mask = [1] * MAX_LEN
            targets = [-100] * (MAX_LEN - 1) + [label_token]
        else:
            # Prompt = text tokens plus the EOS separator.
            prompt_len = len(text_ids) + 1
            mask = [1] * len(sequence) + [0] * pad_count
            targets = [-100] * prompt_len + [label_token] + [-100] * pad_count
            sequence = sequence + [tokenizer.pad_token_id] * pad_count

        all_input_ids.append(sequence)
        all_attention_masks.append(mask)
        all_labels.append(targets)

    return {
        "input_ids": all_input_ids,
        "attention_mask": all_attention_masks,
        "labels": all_labels,
    }

In [15]:
def _tokenize_split(dataset, split_name):
    """Run process_fn over one dataset split in batches.

    The original columns are removed so that only the columns produced by
    process_fn remain.
    """
    return dataset.map(
        process_fn,
        batched=True,
        remove_columns=dataset.column_names,
        desc=f"Running tokenizer on {split_name} dataset",
    )


# Preprocess the training split.
tokenized_train_dataset = _tokenize_split(raw_train_dataset, "train")

# Preprocess the validation split.
tokenized_valid_dataset = _tokenize_split(raw_valid_dataset, "validation")

Running tokenizer on train dataset: 100%|██████████| 8530/8530 [00:00<00:00, 11860.61 examples/s]
Running tokenizer on validation dataset: 100%|██████████| 1066/1066 [00:00<00:00, 11383.43 examples/s]


## 6. 定义数据规整器
训练时自动将多条样本组装成一个 Batch（并转成张量）

In [16]:
# Data collator: assembles individual processed samples into one batch and
# returns it as PyTorch tensors.
collater = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="pt")

## 7. 定义训练超参

In [17]:
LR = 2e-5          # learning rate
BATCH_SIZE = 8     # batch size per device (training and evaluation)
INTERVAL = 100     # log / evaluate / save every this many steps

# Training configuration.
# NOTE(review): newer transformers releases appear to rename `evaluation_strategy`
# to `eval_strategy` — confirm against the installed version.
training_args = TrainingArguments(
    # Where checkpoints are written; existing contents may be overwritten.
    output_dir="./output",
    overwrite_output_dir=True,
    # Optimisation schedule.
    num_train_epochs=1,
    learning_rate=LR,
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=1,      # update parameters after every step
    # Evaluation, logging and checkpointing all happen every INTERVAL steps.
    evaluation_strategy="steps",
    eval_steps=INTERVAL,
    per_device_eval_batch_size=BATCH_SIZE,
    logging_steps=INTERVAL,
    save_steps=INTERVAL,
)

## 8. 定义训练器

In [18]:
# Save memory: with gradient checkpointing, intermediate activations are
# recomputed during the backward pass instead of being stored.
model.gradient_checkpointing_enable()

# Build the trainer from the model, the hyper-parameters, the collator and the
# two tokenized splits. A custom metric function could additionally be passed
# via `compute_metrics`.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collater,
    train_dataset=tokenized_train_dataset,
    eval_dataset=tokenized_valid_dataset,
)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


## 9. 开始训练

In [19]:
# Start training; evaluation, logging and checkpointing follow training_args.
trainer.train()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
[34m[1mwandb[0m: Currently logged in as: [33m315680524[0m ([33m550w[0m). Use [1m`wandb login --relogin`[0m to force relogin
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using

  0%|          | 0/1067 [00:00<?, ?it/s]You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...
wandb: Network error (TransientError), entering retry loop.
  9%|▉         | 100/1067 [00:35<04:35,  3.51it/s]

{'loss': 0.0572, 'learning_rate': 1.8125585754451735e-05, 'epoch': 0.09}


                                                  
  9%|▉         | 100/1067 [00:42<04:35,  3.51it/s]

{'eval_loss': 0.021487252786755562, 'eval_runtime': 7.3399, 'eval_samples_per_second': 145.234, 'eval_steps_per_second': 18.256, 'epoch': 0.09}


 19%|█▊        | 200/1067 [01:14<04:28,  3.23it/s]

{'loss': 0.0234, 'learning_rate': 1.6251171508903468e-05, 'epoch': 0.19}


                                                  
 19%|█▊        | 200/1067 [01:21<04:28,  3.23it/s]

{'eval_loss': 0.017513608559966087, 'eval_runtime': 6.4411, 'eval_samples_per_second': 165.5, 'eval_steps_per_second': 20.804, 'epoch': 0.19}


 28%|██▊       | 300/1067 [01:53<03:41,  3.47it/s]

{'loss': 0.0172, 'learning_rate': 1.4376757263355203e-05, 'epoch': 0.28}


                                                  
 28%|██▊       | 300/1067 [02:00<03:41,  3.47it/s]

{'eval_loss': 0.012730359099805355, 'eval_runtime': 6.654, 'eval_samples_per_second': 160.204, 'eval_steps_per_second': 20.138, 'epoch': 0.28}


 37%|███▋      | 400/1067 [02:33<03:22,  3.29it/s]

{'loss': 0.0147, 'learning_rate': 1.2502343017806936e-05, 'epoch': 0.37}


                                                  
 37%|███▋      | 400/1067 [02:40<03:22,  3.29it/s]

{'eval_loss': 0.012395706959068775, 'eval_runtime': 6.9641, 'eval_samples_per_second': 153.07, 'eval_steps_per_second': 19.241, 'epoch': 0.37}


 47%|████▋     | 500/1067 [03:12<02:48,  3.37it/s]

{'loss': 0.0127, 'learning_rate': 1.0627928772258671e-05, 'epoch': 0.47}


                                                  
 47%|████▋     | 500/1067 [03:19<02:48,  3.37it/s]

{'eval_loss': 0.013758053071796894, 'eval_runtime': 6.9672, 'eval_samples_per_second': 153.002, 'eval_steps_per_second': 19.233, 'epoch': 0.47}


 56%|█████▌    | 600/1067 [03:52<02:14,  3.48it/s]

{'loss': 0.0151, 'learning_rate': 8.753514526710405e-06, 'epoch': 0.56}


                                                  
 56%|█████▌    | 600/1067 [03:58<02:14,  3.48it/s]

{'eval_loss': 0.011735978536307812, 'eval_runtime': 6.2788, 'eval_samples_per_second': 169.778, 'eval_steps_per_second': 21.342, 'epoch': 0.56}


 66%|██████▌   | 700/1067 [04:28<01:45,  3.48it/s]

{'loss': 0.0133, 'learning_rate': 6.879100281162138e-06, 'epoch': 0.66}


                                                  
 66%|██████▌   | 700/1067 [04:35<01:45,  3.48it/s]

{'eval_loss': 0.012164085172116756, 'eval_runtime': 6.2543, 'eval_samples_per_second': 170.443, 'eval_steps_per_second': 21.425, 'epoch': 0.66}


 75%|███████▍  | 800/1067 [05:05<01:14,  3.56it/s]

{'loss': 0.0133, 'learning_rate': 5.004686035613872e-06, 'epoch': 0.75}


                                                  
 75%|███████▍  | 800/1067 [05:11<01:14,  3.56it/s]

{'eval_loss': 0.011407976970076561, 'eval_runtime': 6.2414, 'eval_samples_per_second': 170.796, 'eval_steps_per_second': 21.47, 'epoch': 0.75}


 84%|████████▍ | 900/1067 [05:42<00:47,  3.52it/s]

{'loss': 0.0129, 'learning_rate': 3.1302717900656047e-06, 'epoch': 0.84}


                                                  
 84%|████████▍ | 900/1067 [05:48<00:47,  3.52it/s]

{'eval_loss': 0.011636830866336823, 'eval_runtime': 6.2606, 'eval_samples_per_second': 170.27, 'eval_steps_per_second': 21.404, 'epoch': 0.84}


 94%|█████████▎| 1000/1067 [06:20<00:21,  3.13it/s]

{'loss': 0.0121, 'learning_rate': 1.2558575445173386e-06, 'epoch': 0.94}


                                                   
 94%|█████████▎| 1000/1067 [06:28<00:21,  3.13it/s]

{'eval_loss': 0.011485500261187553, 'eval_runtime': 7.5901, 'eval_samples_per_second': 140.446, 'eval_steps_per_second': 17.655, 'epoch': 0.94}


100%|██████████| 1067/1067 [06:50<00:00,  2.60it/s]

{'train_runtime': 418.9792, 'train_samples_per_second': 20.359, 'train_steps_per_second': 2.547, 'train_loss': 0.018753886334563152, 'epoch': 1.0}





TrainOutput(global_step=1067, training_loss=0.018753886334563152, metrics={'train_runtime': 418.9792, 'train_samples_per_second': 20.359, 'train_steps_per_second': 2.547, 'train_loss': 0.018753886334563152, 'epoch': 1.0})

## 10. 加载训练后的模型进行推理（参考）

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM

# Longest prompt (text tokens + EOS) the model saw during training: preprocessing
# keeps the last MAX_LEN (32) tokens of <text><EOS><label>, i.e. at most 31
# prompt tokens in front of the label token.
MAX_PROMPT_LEN = 31

# Load the fine-tuned checkpoint.
model = AutoModelForCausalLM.from_pretrained("output/checkpoint-1000")

# Put the model into evaluation mode.
model.eval()

# Load the tokenizer of the base model.
tokenizer = AutoTokenizer.from_pretrained("/Volumes/WD_BLACK/models/gpt2")

# Text to classify.
text = "This is a good movie!"

# Text -> token ids. The input is terminated with EOS, exactly as in training.
inputs = tokenizer(text+tokenizer.eos_token, return_tensors="pt")

# Mirror the training-time truncation: keep only the trailing tokens so that
# long inputs are seen with the same context window as during fine-tuning.
# Shorter inputs are left unchanged by the slice.
inputs = {key: value[:, -MAX_PROMPT_LEN:] for key, value in inputs.items()}

# Greedy-decode exactly one token: the predicted label token.
# pad_token_id is passed explicitly because this tokenizer defines no pad token.
output = model.generate(
    **inputs,
    do_sample=False,
    max_new_tokens=1,
    pad_token_id=tokenizer.eos_token_id,
)

# Decode the generated label token back into its text form.
tokenizer.decode(output[0][-1])

## 11. 加载 checkpoint 并继续训练（可选）

In [None]:
# Resume training from a saved checkpoint.
# NOTE: "/path/to/checkpoint" is a placeholder — replace it with a real
# checkpoint directory (e.g. one written under the trainer's output_dir).
trainer.train(resume_from_checkpoint="/path/to/checkpoint")

### 总结上述过程

1. 加载数据集
2. 数据预处理：
   - 将输入输出按特定格式拼接
   - 文本转 Token IDs
   - 通过 labels 标识出哪部分是输出（只有输出的 token 参与 loss 计算）
3. 加载模型、Tokenizer
4. 定义数据规整器
5. 定义训练超参：学习率、批次大小、...
6. 定义训练器
7. 开始训练
8. 注意：训练后推理时，输入数据的拼接方式要与训练时一致