# Getting Started with Fine-Tuning in Hugging Face Transformers

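This example walks through the main steps of fine-tuning a model with Transformers:
- Downloading the dataset
- Preprocessing the data
- Configuring training hyperparameters
- Setting up evaluation metrics for training
- Introducing the Trainer
- Training in practice
- Saving the model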

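## The YelpReviewFull dataset

**Hugging Face dataset: [YelpReviewFull](https://huggingface.co/datasets/yelp_review_full)**

### Dataset summary

The Yelp reviews dataset consists of reviews from Yelp. It was extracted from the Yelp Dataset Challenge 2015 data.

### Supported tasks and leaderboards
Text classification, sentiment classification: the dataset is mainly used for text classification, i.e. predicting the sentiment of a given text.

### Languages
The reviews are written mainly in English.

### Dataset structure

#### Data instances
A typical data point consists of a text and its corresponding label.

An example from the YelpReviewFull test set looks like this: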

```python
{
    'label': 0,
    'text': 'I got \'new\' tires from them and within two weeks got a flat. I took my car to a local mechanic to see if i could get the hole patched, but they said the reason I had a flat was because the previous patch had blown - WAIT, WHAT? I just got the tire and never needed to have it patched? This was supposed to be a new tire. \\nI took the tire over to Flynn\'s and they told me that someone punctured my tire, then tried to patch it. So there are resentful tire slashers? I find that very unlikely. After arguing with the guy and telling him that his logic was far fetched he said he\'d give me a new tire \\"this time\\". \\nI will never go back to Flynn\'s b/c of the way this guy treated me and the simple fact that they gave me a used tire!'
}
```

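#### Data fields

- 'text': The review text, escaped with double quotes ("); any internal double quote is escaped by two double quotes (""). New lines are escaped as a backslash followed by an "n" character, i.e. "\n".
- 'label': The class index of the review score. It ranges from 0 to 4 and corresponds to 1 to 5 stars (the example above, with label 0, is a 1-star review).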
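The two field conventions above can be handled with a few lines of plain Python. This is a minimal sketch, assuming the class names `"1 star"`, `"2 star"`, `"3 stars"`, `"4 stars"`, `"5 stars"` exactly as they appear in the `show_random_elements` output further down, and assuming newlines are stored as a literal backslash followed by `n`; the helper names are hypothetical, not part of `datasets`.

```python
# Class names as displayed by show_random_elements further down (assumed ClassLabel names)
STAR_NAMES = ["1 star", "2 star", "3 stars", "4 stars", "5 stars"]


def label_to_stars(label: int) -> int:
    """Map a 0-based class index (0..4) to a 1..5 star rating."""
    if not 0 <= label < len(STAR_NAMES):
        raise ValueError(f"label out of range: {label}")
    return label + 1


def label_to_name(label: int) -> str:
    """Map a 0-based class index to its human-readable class name."""
    return STAR_NAMES[label_to_stars(label) - 1]


def unescape_newlines(text: str) -> str:
    """Turn the literal two-character sequence backslash + 'n' into a real newline."""
    return text.replace("\\n", "\n")


print(label_to_stars(0), label_to_name(0))       # 1 1 star
print(unescape_newlines("first line\\nsecond line"))
```

The model itself is trained on the raw 0 to 4 indices; the mapping only matters for display and reporting.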

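#### Data splits

The Yelp reviews full star dataset was built by randomly taking 130,000 training samples and 10,000 test samples for each star rating from 1 to 5, for a total of 650,000 training samples and 50,000 test samples.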
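The totals follow directly from the per-class counts, and the "equal samples per star" property can be checked on a label column. A small sketch; `is_balanced` is a hypothetical helper, and applying it as `is_balanced(dataset["train"]["label"], TRAIN_PER_CLASS)` should return `True` if the dataset matches this description.

```python
from collections import Counter

NUM_CLASSES = 5
TRAIN_PER_CLASS = 130_000
TEST_PER_CLASS = 10_000

train_total = NUM_CLASSES * TRAIN_PER_CLASS  # 650,000 training samples
test_total = NUM_CLASSES * TEST_PER_CLASS    # 50,000 test samples


def is_balanced(labels, per_class, num_classes=NUM_CLASSES):
    """True if labels contain exactly the classes 0..num_classes-1, each per_class times."""
    counts = Counter(labels)
    if set(counts) != set(range(num_classes)):
        return False
    return all(n == per_class for n in counts.values())


print(train_total, test_total)  # 650000 50000
```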

## 下载数据集

In [None]:
pip install datasets

Collecting datasets
  Downloading datasets-2.15.0-py3-none-any.whl (521 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m521.2/521.2 kB[0m [31m4.3 MB/s[0m eta [36m0:00:00[0m
Collecting pyarrow-hotfix (from datasets)
  Downloading pyarrow_hotfix-0.6-py3-none-any.whl (7.9 kB)
Collecting dill<0.3.8,>=0.3.0 (from datasets)
  Downloading dill-0.3.7-py3-none-any.whl (115 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m115.3/115.3 kB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0m
Collecting multiprocess (from datasets)
  Downloading multiprocess-0.70.15-py310-none-any.whl (134 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m12.5 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: pyarrow-hotfix, dill, multiprocess, datasets
Successfully installed datasets-2.15.0 dill-0.3.7 multiprocess-0.70.15 pyarrow-hotfix-0.6


In [None]:
from datasets import load_dataset

dataset = load_dataset("yelp_review_full")

In [None]:
dataset

DatasetDict({
    train: Dataset({
        features: ['label', 'text'],
        num_rows: 650000
    })
    test: Dataset({
        features: ['label', 'text'],
        num_rows: 50000
    })
})

In [None]:
dataset["train"][10]

{'label': 0,
 'text': "Owning a driving range inside the city limits is like a license to print money.  I don't think I ask much out of a driving range.  Decent mats, clean balls and accessible hours.  Hell you need even less people now with the advent of the machine that doles out the balls.  This place has none of them.  It is april and there are no grass tees yet.  BTW they opened for the season this week although it has been golfing weather for a month.  The mats look like the carpet at my 107 year old aunt Irene's house.  Worn and thread bare.  Let's talk about the hours.  This place is equipped with lights yet they only sell buckets of balls until 730.  It is still light out.  Finally lets you have the pit to hit into.  When I arrived I wasn't sure if this was a driving range or an excavation site for a mastodon or a strip mining operation.  There is no grass on the range. Just mud.  Makes it a good tool to figure out how far you actually are hitting the ball.  Oh, they are cash 

In [None]:
import random
import pandas as pd
import datasets
from IPython.display import display, HTML

In [None]:
def show_random_elements(dataset, num_examples=10):
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    picks = []
    for _ in range(num_examples):
        pick = random.randint(0, len(dataset)-1)
        while pick in picks:
            pick = random.randint(0, len(dataset)-1)
        picks.append(pick)

    df = pd.DataFrame(dataset[picks])
    for column, typ in dataset.features.items():
        if isinstance(typ, datasets.ClassLabel):
            df[column] = df[column].transform(lambda i: typ.names[i])
    display(HTML(df.to_html()))
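`show_random_elements` draws distinct indices with a retry loop. The standard library's `random.sample` samples without replacement in one call, so an equivalent index picker could look like the sketch below; `pick_distinct_indices` is a hypothetical helper, not part of `datasets`.

```python
import random


def pick_distinct_indices(n_rows, num_examples=10, seed=None):
    """Return num_examples distinct row indices from range(n_rows), in random order."""
    if num_examples > n_rows:
        raise ValueError("Can't pick more elements than there are in the dataset.")
    rng = random.Random(seed)
    # random.sample draws without replacement, so no duplicate check is needed
    return rng.sample(range(n_rows), num_examples)


picks = pick_distinct_indices(650_000, num_examples=10, seed=0)
print(picks)
```

Inside `show_random_elements`, `picks = pick_distinct_indices(len(dataset), num_examples)` would replace the `while` loop without changing what gets displayed.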

In [None]:
show_random_elements(dataset["train"])

Unnamed: 0,label,text
0,3 stars,"Ok again I'm going against the grain and sticking to what I know, good southern bbq! This place is ok but far from the mark. The beef ribs were too fatty, the pork ribs weren't fall off the bone and skimpy. The flavor was pure smoke only. No real sense of the bbq flavor. The yams were drenched in liquid sauce and couldn't mix it in so it ruined it for me. And sorry to say as I was excited for cobbler, but this is far from cobbler! Gosh this is making a trip down south seem overdue! \nWe called ahead our order and still waited a long time as they just said overwhelmed by orders. The little one wouldn't eat the ribs she was so excited for either. I so wanted to like this place so I can have a bbq place in town based on a lot of people's reviews but needless to say I'll keep on looking...."
1,3 stars,"I'm a pedicure junkie - I like to get my toes done at least once a month, I don't care if it's too cold out for anyone to ever see them but me! So, a couple of months ago, I moved to a new part of town. Classic Nails is up the street from my new place (along with about 600 other salons), and I decided to pop in...\n\nPros:\n1. Reasonably priced (I think my spa pedi was $25+tip)\n2. Comfy massage chairs\n3. No wait\n4. Brought me a bottle of water\n5. Trashy magazines to read (ie \""People,\"" not \""Hustler\"")\n6. Toes looked great at the end\n\nCons:\n1.The lady hurt me a few times with the cuticle trimmers and the slougher. When I exclaimed, \""Ow!' for the third time, she said, \""You have sensitive feet!' Yeah, sensitive to being cut open!\n\n2. My pedicure started chipping a bit after a week. And normally I can go about three weeks before anything even starts to get sad.\n\nOverall, it was a decent place - but I will probably keep looking. There are too many nail salons around to settle!\n\nLast week I had a girl friend in town, and I took her to my old place, T Nail Spa, @ 4616 W. Sahara. It's further away, but at least I knew I'd get a good pedi and no pain!"
2,1 star,"Thought I'd give this place a try since I live close, won't be back...\nLittle old lady that did my pedi was nice but don't go here if you're looking for a relaxing, pampering experience. She kept sucking leftover food through her teeth and the \""massage\"" was anything but that. They should call it a lotion application because that's exactly what it was. And I know most people go for pedicures for the massage part. Come on, you all know it's true!\nAfter she was done she didn't say a word and just left and sat down to her phone. Then another nail tech finished hers and sat right next to her, took her shoes off and put them up on a chair next to her! I couldn't believe it, so I had to take a picture which a I have posted it.\nFinally, I started to get up and she came over to put my shoes on. Once she told me my total (which I will say, their prices are cheap) I asked her to add $5 to the total. She pointed at my purse and said \""You have cash for a tip?\"" I love cash too but really? You actually asked your customers that?? If you're wanting a quick, inexpensive, not relaxing pedicure this is your place. Otherwise, keeping looking."
3,2 star,"i was super excited to try quiet storm when i moved to pittsburgh because all of the reviews i read were fantastic. however, i have visited multiple times now continuously hoping for improvement and each time i've been disappointed. let's start with the good things. QS has a great atmosphere, it definitely helped the transition from austin. very chill \""hippie\"" vibe. however, the servers although friendly can sometimes be a touch \""hippie-elitest\"". the food was alright. their homefries were quite tasty but little else has tantalized my taste buds and i have tried many things. i do love their thai dressing, very yummy on the salads. will i continue to go back? probably, wishful thinking is a good thing."
4,1 star,"If I could give a zero star rating I would, while the hotel was beautiful, the timeshare people were nasty. We didn't stay here but got roped into a 2-3 hour timeshare presentation in exchange for free show tickets. We told them right off the bat we were interested in owning a timeshare but did not have the money to pay that day. They were rude, insulting and desperate for business. 6 hours later we finally got to leave with our \""sold out\"" show tickets in hand, where we had to rush over to MGM to reserve our seats so we didn't \""miss\"" the show (the theater was less than half full) It was the most uncomfortable scam I've ever endured. I would NEVER stay at this hotel or any of it's affiliates simply because of this experience."
5,5 stars,"I love this place! I've been to many frozen yogurt shops (in California, Hawaii and Arizona) and this one is my favorite.\n\nThere are tons of toppings and \""bottomings\"" to choose from, although they could offer more fruit choices. The quality/texture of the frozen yogurt is very good. I love their salted caramel flavor, especially when swirled with the sour apple (those two flavors need to be offered all the time!). \n\nThey have a frequent buyer program and I like that you don't have to carry around a card but simply have to tell them your name upon check out.\n\nThe girls working there seem friendly and are very good about providing samples, cleaning off the tables, and emptying the trash."
6,2 star,"I agree with Sean, as a Chicago native every good 'street food' type joint has the hot dog statue outside and the sports posters and stuff lining the walls.\nToday saw me ordering a gyro and I really would not recommend this as a gyro place. The pita was good enough, but the meat was weird-dry and VERY lemony...tzatziki would hold it's own I suppose but the meat totally took away from the experience. Seven bucks for a gyro and small pop. Take that for what it's worth..........\nI've eaten lunch here previously and I thought the beef sandwiches were okay, not great by any means but I'd order a beef sandwich way before the gyros should I lose my mind and patronize this establishment again."
7,4 stars,"I love this place and we go here often on our lunch break because service is fast, fresh, and friendly. \n\nA couple of must have items for me:\nChorizo, eggs, potato, and cheese breakfast burrito\nChicken taco salad with the hard shell. The cilantro lime dressing is amazing!\nThe Ceviche here is one of the best I have had."
8,2 star,First time here and all the reviews are true. Pushy and rude. Mediocre wash. Just go in with that attitude.
9,2 star,"Getting your hair cut here is worth the money if you don't plan on getting your hair cut every 6 weeks like you are supposed to.\n\nHowever, do not get your nails done here. We called to make an appointment for gels. They booked us. We arrived and they told us we were going to get a complimentary hot wax dip, which I started and as they started to have me pick out the color of nail polish we started to clarify that this was gel nails (because gels don't have nail polish, they are clear). As one of my hands had already been dipped they told us that actually they don't offer the gel service there and there must have been some confusion. We should have left then, but since I had already dipped my hand we decided to stay and get mani-s. They CHARGED me for the wax dip and our nails didn't last two days."


## Preprocessing the data

After downloading the dataset, we use a tokenizer to process the text. Because the inputs vary in length, padding and truncation are used to bring them to a uniform length.
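To make padding and truncation concrete, here is a pure-Python toy that works on already-tokenized ids. It is a sketch, not how `tokenizer(...)` is implemented: it ignores special tokens and assumes pad id 0, which matches `padding_idx=0` in the model printout further down. `pad_and_truncate` is a hypothetical helper.

```python
def pad_and_truncate(ids, max_length, pad_id=0):
    """Cut ids to max_length, then right-pad with pad_id.

    Returns (input_ids, attention_mask); the mask is 1 for real tokens, 0 for padding.
    """
    ids = list(ids[:max_length])               # truncation
    n_pad = max_length - len(ids)
    input_ids = ids + [pad_id] * n_pad          # padding
    attention_mask = [1] * len(ids) + [0] * n_pad
    return input_ids, attention_mask


print(pad_and_truncate([101, 1573, 1314, 102], max_length=6))
# ([101, 1573, 1314, 102, 0, 0], [1, 1, 1, 1, 0, 0])
print(pad_and_truncate([101, 1573, 1314, 1480, 146, 102], max_length=4))
# ([101, 1573, 1314, 1480], [1, 1, 1, 1])
```

The second call shows why naive truncation is not enough for BERT: it drops the closing `[SEP]` (id 102), which a real BERT tokenizer keeps when `truncation=True`.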

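The `map` method of Datasets applies a preprocessing function to the entire dataset in one pass; with `batched=True` the function receives many examples at once, which speeds up tokenization.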
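With `batched=True`, the function passed to `map` receives a dict mapping column names to lists (one batch, 1000 rows by default) and returns a dict of columns that are added next to the existing ones; that is why `input_ids`, `token_type_ids` and `attention_mask` appear alongside `label` and `text` in the output further down. The toy below imitates that contract with plain lists. `toy_batched_map` and `fake_tokenize` are hypothetical and ignore caching, multiprocessing and Arrow storage.

```python
def toy_batched_map(columns, fn, batch_size=1000):
    """Apply fn to slices of a column dict and merge the returned columns."""
    n_rows = len(next(iter(columns.values())))
    out = {name: [] for name in columns}
    for start in range(0, n_rows, batch_size):
        batch = {name: values[start:start + batch_size] for name, values in columns.items()}
        result = fn(batch)
        merged = {**batch, **result}  # new columns are added, same-named ones overwritten
        for name, values in merged.items():
            out.setdefault(name, []).extend(values)
    return out


def fake_tokenize(batch):
    # Stand-in for tokenize_function: counts whitespace-separated words per text
    return {"num_tokens": [len(t.split()) for t in batch["text"]]}


data = {"text": ["great food", "slow service again", "ok"], "label": [4, 1, 2]}
print(toy_batched_map(data, fake_tokenize, batch_size=2))
```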

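Below, the whole dataset is processed with the pad-to-maximum-length strategy (`padding="max_length"` pads every sequence to the tokenizer's maximum length, 512 tokens for `bert-base-cased`):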

In [None]:
pip install transformers

In [None]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")


def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True)


tokenized_datasets = dataset.map(tokenize_function, batched=True)

Map:   0%|          | 0/50000 [00:00<?, ? examples/s]

In [None]:
show_random_elements(tokenized_datasets["train"], num_examples=1)

Unnamed: 0,label,text,input_ids,token_type_ids,attention_mask
0,1 star,So last night I was starving and came down from my room at the Monte Carlo and decided to order some food to go back to my room and was extremely disappointed. How ever while I was waiting the bartender was extremely knowledgeable about all the beers they had and was very friendly. I ordered the steak quesadilla and it was absolutely disgusting. It was like someone dumped a can of shredded roast beef on top of government cheese. It was extremely greasy and tasted terrible. After two bites of salty mushy what ever questionable meat I was eating I spit it out and went back down stairs of my hotel Lobby and got a subway sandwich. Hopefully the rest of the food there is better than that. But when you have STEAK written down it should be just that not what looked and tasted like pureed canned roast beef! I can only imagine what the chicken quesadilla taste like. Gross!,"[101, 1573, 1314, 1480, 146, 1108, 20285, 1105, 1338, 1205, 1121, 1139, 1395, 1120, 1103, 10046, 9503, 1105, 1879, 1106, 1546, 1199, 2094, 1106, 1301, 1171, 1106, 1139, 1395, 1105, 1108, 4450, 9333, 119, 1731, 1518, 1229, 146, 1108, 2613, 1103, 18343, 1108, 4450, 3044, 1895, 1164, 1155, 1103, 23147, 1152, 1125, 1105, 1108, 1304, 4931, 119, 146, 2802, 1103, 26704, 15027, 23417, 5878, 1105, 1122, 1108, 7284, 22852, 119, 1135, 1108, 1176, 1800, 14632, 170, 1169, 1104, 188, 8167, 23372, 1174, 187, 20219, 14413, 1113, 1499, 1104, 1433, 9553, 119, 1135, 1108, 4450, 176, 11811, 5821, 1105, 12876, 6434, ...]","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...]","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...]"


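### Data sampling

Use 1,000 training samples (and 1,000 evaluation samples) to demonstrate small-scale training of BERT with the PyTorch-based Trainer.

`shuffle()` randomly reorders the rows of the dataset. For more control over the shuffling algorithm, pass a different `numpy.random.Generator` through the `generator` parameter.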

In [None]:
small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(1000))
small_eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(1000))
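`shuffle(seed=42).select(range(1000))` amounts to "permute the row indices with a seeded RNG, then keep the first 1000". The stdlib sketch below mimics that idea on plain index lists; it only illustrates the semantics and will not reproduce the exact rows `datasets` picks, since the library uses its own NumPy-based generator (for the `generator` route, something like `generator=np.random.default_rng(42)` is passed instead of `seed`). `shuffled_subset` is a hypothetical helper.

```python
import random


def shuffled_subset(n_rows, k, seed):
    """Return the first k indices of a seeded random permutation of range(n_rows)."""
    indices = list(range(n_rows))
    random.Random(seed).shuffle(indices)
    return indices[:k]


train_idx = shuffled_subset(650_000, 1000, seed=42)
again = shuffled_subset(650_000, 1000, seed=42)
print(train_idx[:5], train_idx == again)  # the same seed gives the same subset
```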

## 微调训练配置

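### Loading the BERT model

The warning below tells us that some weights (`classifier.weight` and `classifier.bias`) are not in the checkpoint and have been randomly initialized, while the checkpoint's pretraining head is not used. This is completely normal when fine-tuning: we drop the head used for BERT's pretraining tasks (masked language modeling and next sentence prediction) and replace it with a new sequence classification head. Since there are no pretrained weights for this new head, the library warns that the model should be fine-tuned before being used for inference, which is exactly what we are about to do.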

In [None]:
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=5)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
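Conceptually, the newly initialized `classifier` is a single linear layer from BERT's 768-dimensional pooled output to `num_labels=5` logits (dropout is applied before it during training). The NumPy sketch below shows only the shapes and the math: the random weights stand in for the untrained head (normal with std 0.02, mirroring BERT's default `initializer_range`), and the pooled vectors are random placeholders, not real BERT outputs.

```python
import numpy as np

hidden_size, num_labels, batch_size = 768, 5, 4
rng = np.random.default_rng(0)

# Untrained head: stand-ins for classifier.weight / classifier.bias
W = rng.normal(0.0, 0.02, size=(num_labels, hidden_size))  # PyTorch Linear stores (out, in)
b = np.zeros(num_labels)

# Placeholder for BERT's pooled [CLS] representation, one row per example
pooled = rng.normal(size=(batch_size, hidden_size))

logits = pooled @ W.T + b  # shape (batch_size, num_labels)

# Softmax, only to show that logits are unnormalized scores
probs = np.exp(logits - logits.max(axis=-1, keepdims=True))
probs /= probs.sum(axis=-1, keepdims=True)
print(logits.shape, probs.sum(axis=-1))
```

Until fine-tuning updates `W` and `b`, these logits carry no information about the review, which is exactly what the warning is about.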


### Training hyperparameters (TrainingArguments)

Full list of parameters and their default values: https://huggingface.co/docs/transformers/v4.36.1/en/main_classes/trainer#transformers.TrainingArguments

Source definition: https://github.com/huggingface/transformers/blob/v4.36.1/src/transformers/training_args.py#L161

**Most important setting: the path where model weights are saved (`output_dir`)**

In [None]:
!pip install transformers[torch]

Collecting accelerate>=0.20.3 (from transformers[torch])
  Downloading accelerate-0.25.0-py3-none-any.whl (265 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m265.7/265.7 kB[0m [31m8.3 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: accelerate
Successfully installed accelerate-0.25.0


In [None]:
!pip install accelerate -U



In [None]:
!pip list

Package                          Version
-------------------------------- ---------------------
absl-py                          1.4.0
accelerate                       0.25.0
aiohttp                          3.9.1
aiosignal                        1.3.1
alabaster                        0.7.13
albumentations                   1.3.1
altair                           4.2.2
anyio                            3.7.1
appdirs                          1.4.4
argon2-cffi                      23.1.0
argon2-cffi-bindings             21.2.0
array-record                     0.5.0
arviz                            0.15.1
astropy                          5.3.4
astunparse                       1.6.3
async-timeout                    4.0.3
atpublic                         4.0
attrs                            23.1.0
audioread                        3.0.1
autograd                         1.6.2
Babel                            2.14.0
backcall                         0.2.0
beautifulsoup4                   4.11.2
b

In [None]:
!pip install pytorch

Collecting pytorch
  Downloading pytorch-1.0.2.tar.gz (689 bytes)
  Preparing metadata (setup.py) ... done
Building wheels for collected packages: pytorch
  error: subprocess-exited-with-error
  
  × python setup.py bdist_wheel did not run successfully.
  │ exit code: 1
  ╰─> See above for output.
  
  note: This error originates from a subprocess, and is likely not a problem with pip.
  Building wheel for pytorch (setup.py) ... error
  ERROR: Failed building wheel for pytorch
  Running setup.py clean for pytorch
Failed to build pytorch
ERROR: Could not build wheels for pytorch, which is required to install pyproject.toml-based projects

Note: `pytorch` on PyPI is a placeholder package that fails to build, as the output above shows. PyTorch itself is published as `torch`; it is a dependency of `transformers[torch]` installed above (and comes preinstalled on Colab), so this cell can be skipped.

In [None]:
model

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12,

In [None]:
!ls -l /usr/local/lib/python3.10/dist-packages/transformers/models/bert

total 344
-rw-r--r-- 1 root root 10150 Dec 19 14:33 configuration_bert.py
-rw-r--r-- 1 root root 10490 Dec 19 14:33 convert_bert_original_tf2_checkpoint_to_pytorch.py
-rw-r--r-- 1 root root  2159 Dec 19 14:33 convert_bert_original_tf_checkpoint_to_pytorch.py
-rw-r--r-- 1 root root  4098 Dec 19 14:33 convert_bert_pytorch_checkpoint_to_original_tf.py
-rw-r--r-- 1 root root  7606 Dec 19 14:33 convert_bert_token_dropping_original_tf2_checkpoint_to_pytorch.py
-rw-r--r-- 1 root root  6057 Dec 19 14:33 __init__.py
-rw-r--r-- 1 root root 84063 Dec 19 14:33 modeling_bert.py
-rw-r--r-- 1 root root 63600 Dec 19 14:33 modeling_flax_bert.py
-rw-r--r-- 1 root root 85690 Dec 19 14:33 modeling_tf_bert.py
drwxr-xr-x 2 root root  4096 Dec 19 14:33 __pycache__
-rw-r--r-- 1 root root 14883 Dec 19 14:33 tokenization_bert_fast.py
-rw-r--r-- 1 root root 25175 Dec 19 14:33 tokenization_bert.py
-rw-r--r-- 1 root root 11757 Dec 19 14:33 tokenization_bert_tf.py


In [None]:
from transformers import TrainingArguments

model_dir = "models/bert-base-cased"

# logging_steps defaults to 500; given our small training set and total step count, set it to 100
training_args = TrainingArguments(output_dir=f"{model_dir}/test_trainer",
                                  logging_dir=f"{model_dir}/test_trainer/runs",
                                  logging_steps=100)
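A quick calculation shows why 100 is a better logging interval than the default 500 here. Assuming the `TrainingArguments` defaults `per_device_train_batch_size=8` and `num_train_epochs=3`, a single GPU and no gradient accumulation, 1000 training examples give 125 steps per epoch and 375 steps in total, which matches `global_step=375` in the training output further down. With `logging_steps=500`, no intermediate training loss would be logged at all. `logging_points` is a hypothetical helper.

```python
import math

num_examples = 1000
per_device_batch_size = 8  # TrainingArguments default
num_epochs = 3             # TrainingArguments default

# dataloader_drop_last=False keeps the last partial batch, hence ceil
steps_per_epoch = math.ceil(num_examples / per_device_batch_size)
total_steps = steps_per_epoch * num_epochs


def logging_points(total_steps, logging_steps):
    """Global steps that are multiples of logging_steps, i.e. where the training loss is logged."""
    return list(range(logging_steps, total_steps + 1, logging_steps))


print(steps_per_epoch, total_steps)      # 125 375
print(logging_points(total_steps, 100))  # [100, 200, 300]
print(logging_points(total_steps, 500))  # []
```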

In [None]:
# Print the full hyperparameter configuration
print(training_args)

TrainingArguments(
_n_gpu=1,
adafactor=False,
adam_beta1=0.9,
adam_beta2=0.999,
adam_epsilon=1e-08,
auto_find_batch_size=False,
bf16=False,
bf16_full_eval=False,
data_seed=None,
dataloader_drop_last=False,
dataloader_num_workers=0,
dataloader_pin_memory=True,
ddp_backend=None,
ddp_broadcast_buffers=None,
ddp_bucket_cap_mb=None,
ddp_find_unused_parameters=None,
ddp_timeout=1800,
debug=[],
deepspeed=None,
disable_tqdm=False,
dispatch_batches=None,
do_eval=False,
do_predict=False,
do_train=False,
eval_accumulation_steps=None,
eval_delay=0,
eval_steps=None,
evaluation_strategy=no,
fp16=False,
fp16_backend=auto,
fp16_full_eval=False,
fp16_opt_level=O1,
fsdp=[],
fsdp_config={'min_num_params': 0, 'xla': False, 'xla_fsdp_grad_ckpt': False},
fsdp_min_num_params=0,
fsdp_transformer_layer_cls_to_wrap=None,
full_determinism=False,
gradient_accumulation_steps=1,
gradient_checkpointing=False,
gradient_checkpointing_kwargs=None,
greater_is_better=None,
group_by_length=False,
half_precision_backend=au

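### Evaluating metrics during training (Evaluate)

The **[Hugging Face Evaluate library](https://huggingface.co/docs/evaluate/index)** gives access to dozens of evaluation methods across domains (natural language processing, computer vision, reinforcement learning, and more) with a single line of code. **Full list of supported metrics: https://huggingface.co/evaluate-metric**

By default, the Trainer does not compute metrics such as accuracy during training; even with evaluation enabled, it only reports the validation loss. To get task metrics, we pass the Trainer a function that computes and reports them.

The Evaluate library provides a simple accuracy metric, which can be loaded with the `evaluate.load` function: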

In [None]:
pip install numpy evaluate

Collecting evaluate
  Downloading evaluate-0.4.1-py3-none-any.whl (84 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m1.9 MB/s[0m eta [36m0:00:00[0m
Collecting responses<0.19 (from evaluate)
  Downloading responses-0.18.0-py3-none-any.whl (38 kB)
Installing collected packages: responses, evaluate
Successfully installed evaluate-0.4.1 responses-0.18.0


In [None]:
import numpy as np
import evaluate

metric = evaluate.load("accuracy")


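Next, call the `compute` function to calculate the accuracy of the predictions.

Before passing predictions to `compute`, the logits must be converted into predicted class indices (**Transformers sequence classification models return raw logits, not labels**). Taking the `argmax` over the last axis does this.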

In [None]:
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)
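Because `compute_metrics` receives a `(logits, labels)` pair, it can be sanity-checked without any training. The sketch below replaces Evaluate's accuracy with an equivalent NumPy expression (the fraction of positions where the prediction equals the reference) so that it runs offline; the logits are made-up values. The Trainer prefixes the returned key, which is why it shows up as `eval_accuracy` in the evaluation output further down.

```python
import numpy as np


def compute_metrics_offline(eval_pred):
    """Same logic as compute_metrics, with accuracy computed directly in NumPy."""
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)  # highest-scoring class per example
    return {"accuracy": float((predictions == labels).mean())}


# Made-up logits for 4 examples over 5 classes, and their true labels
logits = np.array([
    [2.0, 0.1, 0.0, -1.0, 0.3],  # predicts class 0
    [0.0, 0.2, 1.5, 0.1, 0.0],   # predicts class 2
    [0.1, 0.0, 0.0, 0.2, 3.0],   # predicts class 4
    [1.0, 2.0, 0.0, 0.0, 0.0],   # predicts class 1
])
labels = np.array([0, 2, 3, 1])
print(compute_metrics_offline((logits, labels)))  # {'accuracy': 0.75}
```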

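#### Monitoring metrics during training

To monitor how the evaluation metrics change during training, set the `evaluation_strategy` parameter of `TrainingArguments`; with `"epoch"`, metrics are computed and reported at the end of each epoch. (Newer Transformers releases rename this parameter to `eval_strategy`.)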

In [None]:
from transformers import TrainingArguments, Trainer

training_args = TrainingArguments(output_dir=f"{model_dir}/test_trainer",
                                  evaluation_strategy="epoch",
                                  logging_dir=f"{model_dir}/test_trainer/runs",
                                  logging_steps=100)
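With `evaluation_strategy="epoch"`, evaluation on `small_eval_dataset` runs once per epoch. Assuming the default batch size of 8 and 3 epochs, 1000 training examples mean 125 optimizer steps per epoch, so evaluation happens at global steps 125, 250 and 375, one row each in the epoch table further down. `epoch_eval_steps` is a hypothetical helper for a single device without gradient accumulation.

```python
import math


def epoch_eval_steps(num_examples, batch_size=8, num_epochs=3):
    """Global steps at which an end-of-epoch evaluation runs."""
    steps_per_epoch = math.ceil(num_examples / batch_size)
    return [steps_per_epoch * epoch for epoch in range(1, num_epochs + 1)]


print(epoch_eval_steps(1000))  # [125, 250, 375]
```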

## Start training

### Instantiating the Trainer

A `kernel version` warning may be printed here; it does not affect running this example.

In [None]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=small_train_dataset,
    eval_dataset=small_eval_dataset,
    compute_metrics=compute_metrics,
)

## Monitoring GPU usage with nvidia-smi

To watch GPU usage in real time, use the `watch` command to poll `nvidia-smi` every second with `watch -n 1 nvidia-smi`:

```shell
Every 1.0s: nvidia-smi                                                   Wed Dec 20 14:37:41 2023

Wed Dec 20 14:37:41 2023
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.129.03             Driver Version: 535.129.03   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  Tesla T4                       Off | 00000000:00:0D.0 Off |                    0 |
| N/A   64C    P0              69W /  70W |   6665MiB / 15360MiB |     98%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A     18395      C   /root/miniconda3/bin/python                6660MiB |
+---------------------------------------------------------------------------------------+
```
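For recording GPU usage from a script instead of watching the table, `nvidia-smi` also has a CSV query mode (`--query-gpu=... --format=csv,noheader,nounits`). The sketch below parses one such line; `parse_gpu_line` and `query_gpu` are hypothetical helpers, `query_gpu` needs an NVIDIA driver to actually run, and the sample string reuses the Tesla T4 numbers from the table above.

```python
import subprocess

FIELDS = ["memory.used", "memory.total", "utilization.gpu"]


def parse_gpu_line(line):
    """Parse one 'used, total, util' line (MiB, MiB, %) into a dict."""
    used, total, util = (int(v.strip()) for v in line.split(","))
    return {"mem_used_mib": used, "mem_total_mib": total,
            "util_pct": util, "mem_frac": used / total}


def query_gpu():
    """Run nvidia-smi once and parse one line per GPU (requires an NVIDIA driver)."""
    out = subprocess.run(
        ["nvidia-smi", f"--query-gpu={','.join(FIELDS)}", "--format=csv,noheader,nounits"],
        capture_output=True, text=True, check=True,
    ).stdout
    return [parse_gpu_line(line) for line in out.strip().splitlines()]


sample = "6665, 15360, 98"  # values from the T4 table above
print(parse_gpu_line(sample))
```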

In [None]:
trainer.train()

| Epoch | Training Loss | Validation Loss | Accuracy |
|-------|---------------|-----------------|----------|
| 1 | 1.3822 | 1.097836 | 0.509 |
| 2 | 1.0397 | 1.053801 | 0.556 |
| 3 | 0.718 | 1.058087 | 0.603 |


TrainOutput(global_step=375, training_loss=0.9595832722981771, metrics={'train_runtime': 384.6745, 'train_samples_per_second': 7.799, 'train_steps_per_second': 0.975, 'total_flos': 789354427392000.0, 'train_loss': 0.9595832722981771, 'epoch': 3.0})

In [None]:
small_test_dataset = tokenized_datasets["test"].shuffle(seed=64).select(range(100))

In [None]:
trainer.evaluate(small_test_dataset)

{'eval_loss': 1.1512675285339355,
 'eval_accuracy': 0.51,
 'eval_runtime': 3.5107,
 'eval_samples_per_second': 28.484,
 'eval_steps_per_second': 3.703,
 'epoch': 3.0}
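One caveat: `small_test_dataset` and `small_eval_dataset` are both drawn from the same `test` split, only with different seeds (64 vs 42), so a few examples can appear in both. With 1000 and 100 indices drawn from 50,000 rows, the expected overlap is 100 × 1000 / 50,000 = 2 examples. The stdlib simulation below illustrates this; it does not reproduce the indices `datasets` actually selects, and the helper names are hypothetical.

```python
import random


def expected_overlap(n_rows, k1, k2):
    """Expected intersection size of two independent uniform samples without replacement."""
    return k1 * k2 / n_rows


def simulated_overlap(n_rows, k1, seed1, k2, seed2):
    """Overlap between two seeded random subsets of range(n_rows)."""
    a = set(random.Random(seed1).sample(range(n_rows), k1))
    b = set(random.Random(seed2).sample(range(n_rows), k2))
    return len(a & b)


print(expected_overlap(50_000, 1000, 100))           # 2.0
print(simulated_overlap(50_000, 1000, 42, 100, 64))  # some small count
```

For a strictly held-out check, the evaluation indices could be excluded before sampling the test subset.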

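### Saving the model and training state

- Use `trainer.save_model` to save the model; it can be reloaded later with `from_pretrained()`.
- Use `trainer.save_state` to save the training state (written as `trainer_state.json` in `output_dir`).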

In [None]:
trainer.save_model(f"{model_dir}/finetuned-trainer")

In [None]:
trainer.save_state()
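`trainer.save_model` wrote `config.json`, `model.safetensors` and `training_args.bin` (see the listing further down), but no tokenizer files, because no tokenizer was passed to the `Trainer`. Reloading therefore needs the fine-tuned weights from `finetuned-trainer` plus the original `bert-base-cased` tokenizer. A sketch under that assumption; `has_model_files` and `load_finetuned` are hypothetical helpers, and `load_finetuned` requires `transformers` and the saved files.

```python
from pathlib import Path

# Files that trainer.save_model produced in this notebook's listing
EXPECTED_FILES = ("config.json", "model.safetensors")


def has_model_files(path):
    """Check that a save_model output directory contains the config and the weights."""
    p = Path(path)
    return all((p / name).is_file() for name in EXPECTED_FILES)


def load_finetuned(path, base_checkpoint="bert-base-cased"):
    """Reload the fine-tuned classifier and the (unchanged) base tokenizer."""
    # Imported lazily so the file check above works without transformers installed
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    model = AutoModelForSequenceClassification.from_pretrained(path)
    tokenizer = AutoTokenizer.from_pretrained(base_checkpoint)  # no tokenizer files were saved
    return model, tokenizer
```

Usage: `model, tokenizer = load_finetuned(f"{model_dir}/finetuned-trainer")`, after `has_model_files(...)` returns `True`. Passing `tokenizer=tokenizer` when creating the `Trainer` would make `save_model` store the tokenizer files as well.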

## Homework: train on the full YelpReviewFull dataset and see how high the accuracy can get

## Saving files to Google Drive

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!pwd

/content


In [None]:
!mkdir -p /content/drive/MyDrive/AI/FineTune
!cp -r models /content/drive/MyDrive/AI/FineTune
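The same copy can be done from Python with `shutil.copytree`; `dirs_exist_ok=True` (Python 3.8+) merges into an existing destination, which is convenient when re-running the cell after further training. `backup_dir` is a hypothetical helper; the paths in the usage comment are this notebook's, and the demo runs on temporary directories.

```python
import shutil
import tempfile
from pathlib import Path


def backup_dir(src, dst_parent):
    """Copy directory src into dst_parent/<src name>, merging with any existing copy."""
    src = Path(src)
    dst = Path(dst_parent) / src.name
    shutil.copytree(src, dst, dirs_exist_ok=True)  # like `cp -r src dst_parent/`
    return dst


# In Colab: backup_dir("/content/models", "/content/drive/MyDrive/AI/FineTune")

# Self-contained demo on temporary directories
with tempfile.TemporaryDirectory() as tmp:
    src = Path(tmp) / "models" / "bert-base-cased"
    src.mkdir(parents=True)
    (src / "config.json").write_text("{}")
    copied = backup_dir(Path(tmp) / "models", Path(tmp) / "drive")
    backup_dir(Path(tmp) / "models", Path(tmp) / "drive")  # second run merges, no error
    names = sorted(p.name for p in copied.rglob("*"))

print(names)  # ['bert-base-cased', 'config.json']
```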

In [None]:
!ls -l /content/drive/MyDrive/AI/FineTune/models/bert-base-cased

total 8
drwx------ 2 root root 4096 Dec 21 05:11 finetuned-trainer
drwx------ 3 root root 4096 Dec 21 05:11 test_trainer


In [None]:
!ls -l -h /content/models/bert-base-cased/finetuned-trainer

total 414M
-rw-r--r-- 1 root root  955 Dec 21 05:06 config.json
-rw-r--r-- 1 root root 414M Dec 21 05:06 model.safetensors
-rw-r--r-- 1 root root 4.5K Dec 21 05:06 training_args.bin


In [None]:
!ls -l -h /content/models/bert-base-cased/test_trainer/runs

total 12K
-rw-r--r-- 1 root root 6.1K Dec 21 05:06 events.out.tfevents.1703134803.ec665cbc0fa0.4954.0
-rw-r--r-- 1 root root  411 Dec 21 05:06 events.out.tfevents.1703135191.ec665cbc0fa0.4954.1


In [None]:
!ls -l -h /content/models/bert-base-cased/test_trainer

total 8.0K
drwxr-xr-x 2 root root 4.0K Dec 21 05:06 runs
-rw-r--r-- 1 root root 1.9K Dec 21 05:06 trainer_state.json
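The `trainer_state.json` written by `trainer.save_state()` holds a `log_history` list with the logged training losses and the per-epoch evaluation results, so the metric curve can be read back with the standard `json` module. In this sketch, the inline JSON is a trimmed miniature of the file: it keeps only the evaluation entries, with values copied from the epoch table above and step numbers assuming 125 steps per epoch. The real file has more entries and keys. `eval_curve` is a hypothetical helper.

```python
import json

# Trimmed miniature of trainer_state.json (evaluation entries only)
sample_state = """
{
  "global_step": 375,
  "log_history": [
    {"epoch": 1.0, "step": 125, "eval_loss": 1.097836, "eval_accuracy": 0.509},
    {"epoch": 2.0, "step": 250, "eval_loss": 1.053801, "eval_accuracy": 0.556},
    {"epoch": 3.0, "step": 375, "eval_loss": 1.058087, "eval_accuracy": 0.603}
  ]
}
"""


def eval_curve(state, metric="eval_accuracy"):
    """Return (epoch, value) pairs for every log entry that contains the metric."""
    return [(e["epoch"], e[metric]) for e in state["log_history"] if metric in e]


state = json.loads(sample_state)
print(eval_curve(state))  # [(1.0, 0.509), (2.0, 0.556), (3.0, 0.603)]

# For the real file (path from this notebook):
# with open(f"{model_dir}/test_trainer/trainer_state.json") as f:
#     state = json.load(f)
```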


In [None]:
!du -h /content/models/

414M	/content/models/bert-base-cased/finetuned-trainer
16K	/content/models/bert-base-cased/test_trainer/runs
24K	/content/models/bert-base-cased/test_trainer
414M	/content/models/bert-base-cased
414M	/content/models/
