# Fine-tuning a model with the Trainer API or Keras

Install the Transformers, Datasets, and Evaluate libraries to run this notebook.

In [None]:
!pip install datasets evaluate transformers[sentencepiece]

### Training

In [1]:
from datasets import load_dataset
from transformers import AutoTokenizer, DataCollatorWithPadding

raw_datasets = load_dataset("glue", "mrpc")
# checkpoint = "bert-base-uncased"
checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)


def tokenize_function(example):
    # Tokenize the two sentences together as one input pair
    return tokenizer(example["sentence1"], example["sentence2"], truncation=True)


tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)  # dynamic padding: pad each batch to its longest example


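DataCollatorWithPadding pads each batch only to the length of the longest example in that batch, instead of padding the whole dataset to one fixed length. The following is a minimal pure-Python sketch of that idea, not the library's implementation: the helper pad_batch is hypothetical, the token ids are toy values, and pad id 0 is an assumption (it is the [PAD] id in the BERT uncased vocabulary).

```python
from typing import Dict, List

PAD_ID = 0  # assumption: [PAD] id of BERT-style uncased vocabularies


def pad_batch(batch: List[List[int]], pad_id: int = PAD_ID) -> Dict[str, List[List[int]]]:
    """Hypothetical helper: pad every sequence to the longest one in THIS batch."""
    max_len = max(len(ids) for ids in batch)
    input_ids = [ids + [pad_id] * (max_len - len(ids)) for ids in batch]
    # 1 marks a real token, 0 marks padding that attention should ignore
    attention_mask = [[1] * len(ids) + [0] * (max_len - len(ids)) for ids in batch]
    return {"input_ids": input_ids, "attention_mask": attention_mask}


# Two toy batches of different maximum length
batch_a = [[101, 2023, 102], [101, 2023, 2003, 1037, 102]]
batch_b = [[101, 102], [101, 7592, 102]]

print(len(pad_batch(batch_a)["input_ids"][0]))  # 5
print(len(pad_batch(batch_b)["input_ids"][0]))  # 3
```

Each batch ends up only as long as it needs to be; with fixed-length padding, both batches would be padded to the same global length.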
The Transformers library provides the Trainer class, which makes fine-tuning a pretrained model very convenient.

(1) Step one: define the hyperparameters.

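That is, create a TrainingArguments object, which is passed to the Trainer as an argument. It holds the hyperparameters used for training and evaluation. The code below sets only the output directory, where the Trainer saves training logs and intermediate checkpoints (including the model weights). This single argument is enough here; everything else keeps its default value.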

In [2]:
from transformers import TrainingArguments

training_args = TrainingArguments("test-trainer")  # output_dir: where logs and checkpoints are written
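Because only the output directory is given, the remaining hyperparameters keep their defaults, notably per_device_train_batch_size=8, num_train_epochs=3 and logging_steps=500. A rough sketch of what that implies for MRPC, assuming a single device and no gradient accumulation (total_training_steps is an illustrative helper, not a Transformers function):

```python
import math

# Assumed TrainingArguments defaults when only output_dir is given
# (single device, no gradient accumulation):
PER_DEVICE_TRAIN_BATCH_SIZE = 8
NUM_TRAIN_EPOCHS = 3
LOGGING_STEPS = 500

NUM_TRAIN_EXAMPLES = 3668  # rows in the MRPC training split


def total_training_steps(n_examples: int, batch_size: int, epochs: int) -> int:
    """Optimizer steps for a full run; the last, partial batch still counts as a step."""
    steps_per_epoch = math.ceil(n_examples / batch_size)
    return steps_per_epoch * epochs


total = total_training_steps(NUM_TRAIN_EXAMPLES, PER_DEVICE_TRAIN_BATCH_SIZE, NUM_TRAIN_EPOCHS)
# The training loss is logged every LOGGING_STEPS global steps
logged_at = list(range(LOGGING_STEPS, total + 1, LOGGING_STEPS))
print(total)      # 1377
print(logged_at)  # [500, 1000]
```

So a full default run is 459 steps per epoch, 1377 steps in total, with the training loss printed only twice.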

(2) Step two: define the model.

In [3]:
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)

(3) Step three: define the Trainer object.

In [4]:
from transformers import Trainer

trainer = Trainer(
    model,
    training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
)

In [6]:
trainer.train()


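Training is started by calling the Trainer's train() method. Running the cell above starts fine-tuning and logs the training loss every 500 steps. It does not, however, evaluate the model, for two reasons: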

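(1) We did not set an evaluation strategy. The evaluation_strategy argument has to be set to "steps" (evaluate every eval_steps steps, which is another argument) or "epoch" (evaluate at the end of every epoch). In recent Transformers releases this argument is named eval_strategy.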

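(2) We did not pass a compute_metrics() function to the Trainer. Without it, even with evaluation_strategy set, evaluation reports only the loss and no other metrics (for classification, the cross-entropy loss is printed, but not accuracy or similar metrics).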
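To make the distinction in (2) concrete: the evaluation loss is the mean cross-entropy that the model computes on its own, while accuracy needs an extra argmax-and-compare step, which is what compute_metrics supplies. A NumPy sketch on toy logits; cross_entropy and accuracy are illustrative helpers, not Transformers functions.

```python
import numpy as np


def cross_entropy(logits: np.ndarray, labels: np.ndarray) -> float:
    """Mean cross-entropy of integer labels under softmax(logits)."""
    # Subtract the row max for numerical stability before exponentiating
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # Pick the log-probability of the gold label in each row
    return float(-log_probs[np.arange(len(labels)), labels].mean())


def accuracy(logits: np.ndarray, labels: np.ndarray) -> float:
    """Fraction of rows whose highest-scoring class equals the gold label."""
    return float((logits.argmax(axis=-1) == labels).mean())


# Toy logits for 3 examples and 2 classes
logits = np.array([[2.0, 0.0], [0.0, 2.0], [1.0, 1.0]])
labels = np.array([0, 1, 1])

print(round(cross_entropy(logits, labels), 4))  # what evaluation reports without compute_metrics
print(accuracy(logits, labels))                 # needs compute_metrics
```

In the tied third row, argmax returns the first index (0), so that example counts as wrong for accuracy, while the loss only reflects the 50/50 probability.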


### Evaluation

In [None]:
predictions = trainer.predict(tokenized_datasets["validation"])
print(predictions.predictions.shape, predictions.label_ids.shape)

(408, 2) (408,)

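predict() returns a named tuple with three fields: predictions, label_ids and metrics. Here predictions holds the logits (one row of two scores for each of the 408 validation examples) and label_ids holds the 408 gold labels. metrics contains the loss on the dataset that was passed in and, if a compute_metrics function was set, the other metrics as well. A compute_metrics function must return a dictionary, and its key/value pairs are printed during training.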
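To make the three fields concrete, here is a hypothetical stand-in built with typing.NamedTuple. It is not the Transformers class itself; it only mirrors the field names described above and is filled with placeholder zeros, not real predictions.

```python
from typing import Dict, NamedTuple, Optional

import numpy as np


class PredictionOutputSketch(NamedTuple):
    """Hypothetical stand-in mirroring the fields returned by Trainer.predict()."""
    predictions: np.ndarray              # logits, shape (num_examples, num_labels)
    label_ids: Optional[np.ndarray]      # gold labels, shape (num_examples,)
    metrics: Optional[Dict[str, float]]  # loss (+ compute_metrics results, if set)


out = PredictionOutputSketch(
    predictions=np.zeros((408, 2)),
    label_ids=np.zeros(408, dtype=int),
    metrics={"test_loss": 0.0},  # placeholder value, not a real result
)

# A named tuple supports both attribute access and plain tuple unpacking
logits, labels, metrics = out
print(out.predictions.shape, out.label_ids.shape)  # (408, 2) (408,)
```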

In [None]:
# Convert logits to predicted labels (index of the largest logit in each row)
import numpy as np

preds = np.argmax(predictions.predictions, axis=-1)
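np.argmax picks, for each row of logits, the column with the largest score, which is the predicted class. A toy example, assuming the MRPC label convention 0 = not equivalent, 1 = equivalent. It also checks that applying softmax first would not change the result, since softmax preserves the order of the scores within a row.

```python
import numpy as np

# Toy logits for 4 sentence pairs: column 0 = "not equivalent", column 1 = "equivalent"
logits = np.array([[1.2, -0.8],
                   [-0.3, 2.1],
                   [0.1, 0.4],
                   [3.0, 2.9]])

preds = np.argmax(logits, axis=-1)  # index of the larger logit in each row
print(preds)  # [0 1 1 0]

# Softmax is monotonic within a row, so probabilities give the same argmax
probs = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)
assert (np.argmax(probs, axis=-1) == preds).all()
```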

In [None]:
# Compute the metrics
import evaluate

metric = evaluate.load("glue", "mrpc")
metric.compute(predictions=preds, references=predictions.label_ids)

{'accuracy': 0.8578431372549019, 'f1': 0.8996539792387542}
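The MRPC metric reports accuracy and F1. A from-scratch sketch of both, assuming F1 is the binary F1 of the positive class (label 1, "equivalent"), which is scikit-learn's default for binary labels. accuracy_and_f1 is an illustrative helper and the arrays are toy data, not the validation results above.

```python
import numpy as np


def accuracy_and_f1(preds: np.ndarray, refs: np.ndarray) -> dict:
    """Accuracy plus binary F1 with class 1 as the positive class."""
    tp = int(((preds == 1) & (refs == 1)).sum())  # true positives
    fp = int(((preds == 1) & (refs == 0)).sum())  # false positives
    fn = int(((preds == 0) & (refs == 1)).sum())  # false negatives
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {"accuracy": float((preds == refs).mean()), "f1": f1}


preds = np.array([1, 1, 0, 1, 0, 0])
refs = np.array([1, 0, 0, 1, 1, 0])
print(accuracy_and_f1(preds, refs))
```

Here 2 true positives, 1 false positive and 1 false negative give precision = recall = F1 = 2/3, and 4 of 6 predictions match, so accuracy is also 2/3.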

### Combining training and evaluation

In [None]:
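def compute_metrics(eval_preds):
    # The metric is reloaded on every evaluation call; loading it once outside would also work
    metric = evaluate.load("glue", "mrpc")
    logits, labels = eval_preds  # the Trainer passes (logits, labels) for the evaluation set
    predictions = np.argmax(logits, axis=-1)  # logits -> class ids
    return metric.compute(predictions=predictions, references=labels)  # dict with 'accuracy' and 'f1'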

In [None]:
training_args = TrainingArguments(
    "test-trainer",
    evaluation_strategy="epoch",  # at the end of each epoch, compute accuracy/F1 on the validation set (eval_strategy in newer releases)
)
model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)

trainer = Trainer(
    model,
    training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)
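With evaluation_strategy="epoch", the Trainer evaluates once after every epoch. A simplified sketch of when evaluation would run under each strategy, using the MRPC numbers with default settings (459 steps per epoch, i.e. ceil(3668 / 8), and 3 epochs). eval_points is a hypothetical helper, and eval_steps=500 is an assumption matching what eval_steps falls back to (logging_steps) when it is left unset.

```python
from typing import List


def eval_points(strategy: str, steps_per_epoch: int, epochs: int, eval_steps: int = 500) -> List[int]:
    """Global steps at which evaluation would run (simplified model, hypothetical helper)."""
    total = steps_per_epoch * epochs
    if strategy == "epoch":
        # once at the end of every epoch
        return [steps_per_epoch * e for e in range(1, epochs + 1)]
    if strategy == "steps":
        # every eval_steps global steps
        return list(range(eval_steps, total + 1, eval_steps))
    # "no": no evaluation during training
    return []


print(eval_points("epoch", 459, 3))  # [459, 918, 1377]
print(eval_points("steps", 459, 3))  # [500, 1000]
print(eval_points("no", 459, 3))     # []
```

Under "epoch", each of the three evaluations reports the validation loss together with the accuracy and F1 returned by compute_metrics.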