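# Getting Started with Fine-Tuning in Hugging Face Transformers

This example walks through the main steps of fine-tuning a model with Transformers:
- Downloading the dataset
- Preprocessing the data
- Configuring training hyperparameters
- Setting up evaluation metrics
- Introducing the Trainer
- Running the training
- Saving the model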

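## The YelpReviewFull Dataset

**Hugging Face dataset: [YelpReviewFull](https://huggingface.co/datasets/yelp_review_full)**

### Dataset Summary

The Yelp reviews dataset consists of reviews from Yelp. It was extracted from the Yelp Dataset Challenge 2015 data.

### Supported Tasks and Leaderboards
Text classification, sentiment classification: the dataset is mainly used for text classification, that is, predicting the sentiment of a given text.

### Languages
The reviews are mainly written in English.

### Dataset Structure

#### Data Instances
A typical data point consists of a text and its corresponding label.

An example from the YelpReviewFull test set: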

```python
{
    'label': 0,
    'text': 'I got \'new\' tires from them and within two weeks got a flat. I took my car to a local mechanic to see if i could get the hole patched, but they said the reason I had a flat was because the previous patch had blown - WAIT, WHAT? I just got the tire and never needed to have it patched? This was supposed to be a new tire. \\nI took the tire over to Flynn\'s and they told me that someone punctured my tire, then tried to patch it. So there are resentful tire slashers? I find that very unlikely. After arguing with the guy and telling him that his logic was far fetched he said he\'d give me a new tire \\"this time\\". \\nI will never go back to Flynn\'s b/c of the way this guy treated me and the simple fact that they gave me a used tire!'
}
```

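#### Data Fields

- 'text': The review text is escaped using double quotes ("), and any internal double quote is escaped by two double quotes (""). Newlines are escaped as a backslash followed by an "n" character, i.e. "\n".
- 'label': Corresponds to the review's score. The class index runs from 0 to 4 and stands for 1 to 5 stars (the example above, with `'label': 0`, is a 1-star review).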
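
As a small illustration of these two fields, the sketch below maps a label index to its star rating and restores escaped newlines in a review. The class names are copied from the sample output shown later in this notebook; the helper names `label_to_stars` and `unescape_review` are hypothetical and not part of the `datasets` API.

```python
# Class names as they appear in the sample output of this notebook
# (index 0 is the lowest rating, index 4 the highest).
STAR_NAMES = ["1 star", "2 star", "3 stars", "4 stars", "5 stars"]

def label_to_stars(label: int) -> str:
    """Map a label index (0-4) to its star-rating name."""
    return STAR_NAMES[label]

def unescape_review(text: str) -> str:
    """Turn the literal two-character sequence backslash + n into a real newline."""
    return text.replace("\\n", "\n")

sample = {"label": 0, "text": "Gone down hill big time\\nI have gone to this restaurant many times."}
print(label_to_stars(sample["label"]))   # 1 star
print(unescape_review(sample["text"]))   # printed on two lines
```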

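#### Data Splits

The Yelp reviews full star dataset is constructed by randomly taking 130,000 training samples and 10,000 test samples for each star rating from 1 to 5. In total there are 650,000 training samples and 50,000 test samples.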
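
The totals follow directly from the per-class counts, as this short check shows (the constant names are only for illustration):

```python
NUM_CLASSES = 5              # star ratings 1 through 5
TRAIN_PER_CLASS = 130_000
TEST_PER_CLASS = 10_000

train_total = NUM_CLASSES * TRAIN_PER_CLASS
test_total = NUM_CLASSES * TEST_PER_CLASS
print(train_total, test_total)  # 650000 50000
```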

## Downloading the Dataset

In [4]:
from datasets import load_dataset

# dataset = load_dataset("yelp_review_full")
dataset = load_dataset("/home/rr-ai/huggingface_datasets/yelp_review_full/")

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


In [5]:
dataset

DatasetDict({
    train: Dataset({
        features: ['label', 'text'],
        num_rows: 650000
    })
    test: Dataset({
        features: ['label', 'text'],
        num_rows: 50000
    })
})

In [6]:
dataset["train"][111]

{'label': 2,
 'text': "As far as Starbucks go, this is a pretty nice one.  The baristas are friendly and while I was here, a lot of regulars must have come in, because they bantered away with almost everyone.  The bathroom was clean and well maintained and the trash wasn't overflowing in the canisters around the store.  The pastries looked fresh, but I didn't partake.  The noise level was also at a nice working level - not too loud, music just barely audible.\\n\\nI do wish there was more seating.  It is nice that this location has a counter at the end of the bar for sole workers, but it doesn't replace more tables.  I'm sure this isn't as much of a problem in the summer when there's the space outside.\\n\\nThere was a treat receipt promo going on, but the barista didn't tell me about it, which I found odd.  Usually when they have promos like that going on, they ask everyone if they want their receipt to come back later in the day to claim whatever the offer is.  Today it was one of th

In [7]:
import random
import pandas as pd
import datasets
from IPython.display import display, HTML

In [8]:
def show_random_elements(dataset, num_examples=10):
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    picks = []
    for _ in range(num_examples):
        pick = random.randint(0, len(dataset)-1)
        while pick in picks:
            pick = random.randint(0, len(dataset)-1)
        picks.append(pick)
    
    df = pd.DataFrame(dataset[picks])
    for column, typ in dataset.features.items():
        if isinstance(typ, datasets.ClassLabel):
            df[column] = df[column].transform(lambda i: typ.names[i])
    display(HTML(df.to_html()))

In [9]:
show_random_elements(dataset["train"])

Unnamed: 0,label,text
0,1 star,"Gone down hill big time\nI have gone to this restaurant many times in my life. In the 60's it was a tavern and bait shop that had boat rental. I had not been to the Nau-Ti-Gal for years, and when my daughter and her BF came up from FL for Christmas we decided to take them there as it is a unique location and I in the past the food was pretty good. Saturday 12/27/14 we got there at 6:30 and the place was dead, which I thought a little odd. My wife ordered the seafood salad, my daughter, her BF and I ordered the Prime Rib special of the night, each meal came with a cup of clam chowder. My wifes seafood salad consided of canned baby shrimp, and canned crab meat on lettuce and tomatoes. The Prime Rib was a dark red color which I later realized was a Sam's pre-cooked Prime, I've had before. Later that night I got a gut ache, which turned into several trips to the bathroom, my wife got it the next morning, I'm thinking it was the chowder. \nI won't be going back to the Nau-Ti-Gal and would not recommend to anyone else..."
1,3 stars,"Agree with other reviewers that for an outlet mall, this has a vast selection of merchants, from the incredible expensive to the people selling SlapChops. If you are looking for a specific item, you are probably SOL. Just dive in, drop into stores that you like, and come home with a bunch of useless crap.Highlight of my recent trip was the David Yerman and Converse outlets. Drop by if you are bored of gambling and partying."
2,4 stars,"Room kinda small but nice..Speedy internet...Pool area well above average with hot tub,suana,and main pool with comfortable[ie: not cold] water..Buy one get one in lobby lounge,along with very tasty complimentary pizza..Continental breakfast and nice noon checkout time..All in all a very comfortable place,and the price IS right!"
3,5 stars,The best Jalape\u00f1o burger in town! Deep fried Jalape\u00f1os a half pound burger that is cooked to perfection!\n\nI have had it three times and each time it was awesome!
4,3 stars,"I only stayed here one night on my way into town. It was a last minute decision and we booked online for a super good deal. \n\nI arrived around 3 am and it took a long time to check in! I was surprised considering only 1 person was ahead of me in line... why the long check in? Even though we were only staying one night and it was super late they still made me pay the 18 dollar plus tax resort fee. So an extra 20 dollars for mediocre ammenties... not worth it to me. He upgraded us to an excutive suite as a consolitation even though we didn't ask for anything. I appreciated the gesture. \n\nAs we were walking up to our room we noticed these orange cones set up, we were carrying heavy duffle bags and one accidently was knocked over, we were literally shouted at by a hotel employee to \""PICK IT UP.\"" I was startled by their attitude and my lack of sleep.\n\nThe room itself was average. The bathrooms were tiny and so was the room. Only one mirror in the entire room in the bathroom of course! If this was an upgrade I wondered what my original room would have been! But the beds were definitely comfortable! I was just there to sleep so this made the night peaceful. \n\nI think the location on the strip is perfect. If you want to use the food court its not inside the hotel...you have to walk outside adjacent directly to Harrahs. It took us awhile to find it."
5,3 stars,"The Morning Squeeze is very laid back and cool atmosphere. Our waitress was at our table quickly and got our drink order. We sat outside and the weather in the morning was just fine and the outdoor misters helped :)\n\nWhen our order arrived...we got the Classic Bennie (Should of got that) and the Jet setter. It was prepared fine, except my bread was a little burnt..Check the picture out.\n\nThe staff was attentive, the food was pretty good and the atmosphere was great. A good breakfast and the prices are decent as well."
6,3 stars,"I had an issue recently that was not handled well. My apartment flooded twice. The first time was due to a part breaking in my washer, the second was because the part they ordered to fix it was \""faulty\"". They replaced the washer after 6-7 calendar days. But only offered to have my carpet cleaned even though this is the second time and that carpet was ruined and smelled. This is one of a few incidents that I've had. There communication with you during something like this is crucial and they definitely dropped the ball. The office staff said I was being hostile and rude but what they didn't understand or help me with was communicating what was happening when and how long things would take. If they had to live in a damp apartment with fans blowing and a two year old wanting to play they would understand. So, this being said, just hope you don't have issues that need attention and don't get smokers below you who smoke inside. I loved this place pretty much for the other 2 1/2 yrs. after finally dealing with Monica the area manager it was all resolved and taken care of. Sucks they have to get approval for everything too because the company is in CA."
7,3 stars,After hearing so much about this place I had to try it. I had split the Earl Club and the Montagu sandwich and I think the Montagu was the best of the 2. On both the meat was warm and moist and the bread was also really good. But I wish the sandwiches had more taste. \n\nMy friend said that the Chipotle chicken was really good so maybe get that one if you want a more tasty sandwich
8,2 star,"I would rate it one star if it wasn't for our salads being good.\n\nA coworker and I ordered two salads for pick up, as we were on our lunch break and didn't want to waste a lot of time. When I asked the guy who answered the phone for two small salads, he asked if I wanted a \""single\"" or \""small.\"" After asking him what the difference was, he told me that the small served 4-6 people and the single was good for one person... Hmm... I guess it would be a single then! From the website, it looked like there were only three options (S/M/L), but whatever.\n\nWhen we showed up to pick up our order, it was pretty clear that they were short on staff. There were 2 people working and one of them was behind the bar pretty much the whole entire time. They were definitely NOT in a hurry even though the restaurant had a good amount of people there... Just a couple of dudes putzing around! \n\nAnyways, he ran our cards and returned with 2 piddly, TINY boxes of salads.. What? After sitting in my car and wondering why we were charged $7 for a snack of a salad, we decided to go back in and see what the deal was. After we told him that it wasn't explained to us what a \""single\"" was, he admitted that he wasn't being exactly clear. \""Yeah, it was probably my fault.\"" Uh, yeah it was! That was all he said, no apology or anything. Now, we are not the type of people who go around looking for freebies, but you would think he would have been a little more... I don't know... Courteous! He ended up running our cards twice, since we wanted a larger order and disappearing in the back for almost half an hour. We would have left if he hadn't charged us, but we were stuck waiting... For...Ev....ER!\n\nWhile we were standing there twiddling our thumbs, it was obvious that some of the other customers felt just as neglected and frustrated. One woman even had to get up out of her seat and walk over to one of them to have something taken care of. Not cool. 
Finally, he brought us our food and still gave no sign that he even cared about what had happened. It's a cute place and the food might be decent, but I would not go back to this location!"
9,4 stars,"Wifee says ' I want Thai '. Ever obedient, I grab the iPad, load the Yelp app and search for Thai. Archi's rates highly, but there's three Archi's I note. Eeny,Meeny,Miny,Mo. Archi's Thai Kitchen it is. Jump in the car with wifee, teen daughter & fussy eating 10 year old. \nIt's easy to find and great parking. Two cop cars outside, I park right by them thinking they were watching the neighborhood and the car will be safe. No cops in the car! Hum? \nGo into Archi's and law enforcement were chilling at a table and seemed to be enjoying a plate of Thai food. The restaurant is bright, not large and reminds me of where you would eat your breakfast in a Hoilday Inn Express. I don't think that's a bad thing but it was very bright. Friendly server sits us immediately and gives us iced water. I order beer & Chardonnay as is my tradition. WHAT? NO BEER OR WINE? I sob..... It's clearly marked on the Yelp review but I didn't notice. Dumb. Anyway the food; we order Sesame Toast and egg rolls to start. The toast is deep fried and too greasy, egg rolls were fine. 10 year old McDonalds aficionado wants plain chicken , they fry him up plain chicken and it was great. Teen wants chicken fried rice and a Mongolian beef dish, ordered 3/10 on the heat/spice scale she thoroughly enjoyed it. Wifee ordered the same, but being a spicy food bigshot, she went for 7/10. Good choice, wifee happy. I take a chicken cashew dish , 5/10 (pathetic , I know ) and was equally happy. There were many Thai clientele eating and the place was charming.\nGreat food, $70 before tip and we will be back to Archi's but I believe the other locations will give me a beer to quench my thirst and alleviate the heat should I go crazy and order my food 6/10 on the spice meter."


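## Preprocessing the Data

Once the dataset has been downloaded locally, use a tokenizer to process the text. For inputs of varying length, padding and truncation strategies bring every sequence to the same length.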
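
To make the two strategies concrete, here is a toy, dependency-free sketch of what padding and truncation do to a list of token ids. It is not the tokenizer's actual implementation: the function name `pad_or_truncate` and the default `pad_id=0` are assumptions for illustration, and a real BERT tokenizer also preserves the closing special token when truncating, which this sketch ignores.

```python
def pad_or_truncate(token_ids, max_length, pad_id=0):
    """Return (input_ids, attention_mask), both exactly max_length long."""
    ids = list(token_ids[:max_length])   # truncation: drop the overflow
    mask = [1] * len(ids)                # 1 marks a real token
    n_pad = max_length - len(ids)
    # padding: fill the remainder with pad_id and mask it out with 0
    return ids + [pad_id] * n_pad, mask + [0] * n_pad

short_ids, short_mask = pad_or_truncate([101, 146, 102], max_length=5)
long_ids, long_mask = pad_or_truncate([101, 1, 2, 3, 4, 5, 102], max_length=5)
print(short_ids, short_mask)  # [101, 146, 102, 0, 0] [1, 1, 1, 0, 0]
print(long_ids, long_mask)    # [101, 1, 2, 3, 4] [1, 1, 1, 1, 1]
```

The attention mask is what lets the model ignore the padded positions, which is why the tokenized examples in this notebook carry an `attention_mask` column next to `input_ids`.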

The `map` method of Datasets applies a preprocessing function to the entire dataset in one pass.
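
With `batched=True`, the preprocessing function receives a dictionary that maps each column name to a list of values, rather than a single example. The pure-Python sketch below imitates that calling convention. `batched_map` and `add_length` are hypothetical helpers, not the `datasets` implementation, and the tiny `batch_size` is chosen only to make the batching visible.

```python
def batched_map(columns, fn, batch_size=2):
    """Call fn on slices of a column dict and attach the columns it returns."""
    n_rows = len(next(iter(columns.values())))
    new_columns = {}
    for start in range(0, n_rows, batch_size):
        # Each batch has the same shape as the input: column name -> list of values.
        batch = {name: values[start:start + batch_size] for name, values in columns.items()}
        for name, values in fn(batch).items():
            new_columns.setdefault(name, []).extend(values)
    return {**columns, **new_columns}

def add_length(batch):
    """Example preprocessing function: one new value per row in the batch."""
    return {"num_chars": [len(text) for text in batch["text"]]}

data = {"label": [0, 4, 2], "text": ["bad", "great food", "okay"]}
mapped = batched_map(data, add_length)
print(mapped["num_chars"])  # [3, 10, 4]
```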

The following pads every example to the maximum length and processes the whole dataset:

In [10]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("/home/rr-ai/huggingface_models/bert-base-cased/")

def tokenize_function(examples):
    # return tokenizer(examples["text"], padding="max_length", truncation=True)
    return tokenizer(examples["text"], padding="max_length", max_length=512, truncation=True)


tokenized_datasets = dataset.map(tokenize_function, batched=True)

Map: 100%|██████████████████████████████████████████████████████████████████████████████| 50000/50000 [00:05<00:00, 9925.95 examples/s]


In [11]:
show_random_elements(tokenized_datasets["train"], num_examples=1)

Unnamed: 0,label,text,input_ids,token_type_ids,attention_mask
0,3 stars,"I've made it a point to always go to a champagne brunch everytime I go to Vegas. My first champagne brunch encounter was at Paris. Last time we went to the Buffet at Wynn, and this time it was time to try Rio's World Buffet. They are all good in their own way - this one was HUGE though. I had to walk a long ways to get to the end of it, and it wasn't half bad. I'm going to say that I still like Paris best though. The atmosphere there is really nice, and I think the desserts there are the best. Rio's buffet had a lot of different foods that the other buffets don't have though, since it's supposed to be an International Buffet, and so there are different stations for different countries. Pretty good.","[101, 146, 112, 1396, 1189, 1122, 170, 1553, 1106, 1579, 1301, 1106, 170, 17673, 9304, 14480, 1451, 4974, 146, 1301, 1106, 6554, 119, 1422, 1148, 17673, 9304, 14480, 8107, 1108, 1120, 2123, 119, 4254, 1159, 1195, 1355, 1106, 1103, 139, 9435, 2105, 1120, 160, 20906, 117, 1105, 1142, 1159, 1122, 1108, 1159, 1106, 2222, 5470, 112, 188, 1291, 139, 9435, 2105, 119, 1220, 1132, 1155, 1363, 1107, 1147, 1319, 1236, 118, 1142, 1141, 1108, 145, 2591, 16523, 1463, 119, 146, 1125, 1106, 2647, 170, 1263, 3242, 1106, 1243, 1106, 1103, 1322, 1104, 1122, 117, 1105, 1122, 1445, 112, 189, 1544, ...]","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...]","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...]"


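### Sampling the Data

Use 1,000 samples to demonstrate small-scale training on BERT (based on the PyTorch Trainer).

The `shuffle()` method randomly reorders the rows of the dataset. For more control over the shuffling algorithm, pass the `generator` argument to supply a different `numpy.random.Generator`.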
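
Chaining `shuffle(seed=42)` with `select(range(1000))` conceptually draws a reproducible random permutation of the row indices and keeps the first 1,000. The sketch below uses Python's standard `random` module as a stand-in, so the actual indices will differ from what the library picks. `sample_indices` is a hypothetical helper.

```python
import random

def sample_indices(num_rows, k, seed):
    """Shuffle row indices with a fixed seed, then keep the first k."""
    indices = list(range(num_rows))
    random.Random(seed).shuffle(indices)
    return indices[:k]

first = sample_indices(num_rows=650_000, k=1000, seed=42)
second = sample_indices(num_rows=650_000, k=1000, seed=42)
print(first == second)    # True: the same seed yields the same subset
print(len(set(first)))    # 1000: rows are drawn without replacement
```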

In [12]:
small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(1000))
small_eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(1000))

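## Fine-Tuning Configuration

### Loading the BERT Model

When the model is loaded from the base `bert-base-cased` weights, a warning tells us that some weights are being discarded (the pre-training heads) and others are randomly initialized (the new `classifier` layer). This is entirely normal when fine-tuning: we are removing the head used for the masked language modeling pre-training task and replacing it with a new head for which there are no pretrained weights. The library therefore warns that the model should be fine-tuned before being used for inference, which is exactly what we are about to do.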

In [13]:
from transformers import AutoModelForSequenceClassification

# model = AutoModelForSequenceClassification.from_pretrained("/home/rr-ai/huggingface_models/bert-base-cased/", num_labels=5)
model = AutoModelForSequenceClassification.from_pretrained("/home/rr-ai/python-project/LLM-quickstart/transformers/models/bert-base-cased-finetune-yelp/checkpoint-60500/", num_labels=5)

### Training Hyperparameters (TrainingArguments)

Full list of parameters and their defaults: https://huggingface.co/docs/transformers/v4.36.1/en/main_classes/trainer#transformers.TrainingArguments

Source definition: https://github.com/huggingface/transformers/blob/v4.36.1/src/transformers/training_args.py#L161

**The most important setting: the path where model weights are saved (`output_dir`)**

In [14]:
from transformers import TrainingArguments

model_dir = "models/bert-base-cased-finetune-yelp"

# logging_steps defaults to 500; given our training data size and step count, set it to 100
training_args = TrainingArguments(output_dir=model_dir,
                                  per_device_train_batch_size=16,
                                  num_train_epochs=3,
                                  logging_steps=100,
                                  report_to=[])

In [15]:
# Full hyperparameter configuration
print(training_args)

TrainingArguments(
_n_gpu=1,
adafactor=False,
adam_beta1=0.9,
adam_beta2=0.999,
adam_epsilon=1e-08,
auto_find_batch_size=False,
bf16=False,
bf16_full_eval=False,
data_seed=None,
dataloader_drop_last=False,
dataloader_num_workers=0,
dataloader_persistent_workers=False,
dataloader_pin_memory=True,
ddp_backend=None,
ddp_broadcast_buffers=None,
ddp_bucket_cap_mb=None,
ddp_find_unused_parameters=None,
ddp_timeout=1800,
debug=[],
deepspeed=None,
disable_tqdm=False,
dispatch_batches=None,
do_eval=False,
do_predict=False,
do_train=False,
eval_accumulation_steps=None,
eval_delay=0,
eval_steps=None,
evaluation_strategy=IntervalStrategy.NO,
fp16=False,
fp16_backend=auto,
fp16_full_eval=False,
fp16_opt_level=O1,
fsdp=[],
fsdp_config={'min_num_params': 0, 'xla': False, 'xla_fsdp_grad_ckpt': False},
fsdp_min_num_params=0,
fsdp_transformer_layer_cls_to_wrap=None,
full_determinism=False,
gradient_accumulation_steps=1,
gradient_checkpointing=False,
gradient_checkpointing_kwargs=None,
greater_is_better=

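### Evaluating Metrics During Training (Evaluate)

The **[Hugging Face Evaluate library](https://huggingface.co/docs/evaluate/index)** provides dozens of evaluation methods across domains (natural language processing, computer vision, reinforcement learning, and more), each available with a single line of code. The full list of supported metrics: **https://huggingface.co/evaluate-metric**

The Trainer does not automatically evaluate model performance during training, so we need to pass it a function that computes and reports metrics.

The Evaluate library provides a simple accuracy metric, which you can load with the `evaluate.load` function.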

In [16]:
import numpy as np
import evaluate

metric = evaluate.load("/home/rr-ai/huggingface_models/evaluate_metric/accuracy")



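Next, call the metric's `compute` function to calculate the accuracy of the predictions.

Before passing predictions to `compute`, convert the logits into predicted classes (**all Transformers models return logits**).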

In [17]:
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)
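
To see what `compute_metrics` does on a small batch, here is a dependency-free sketch of the same two steps: take the argmax of each row of logits, then compare against the labels. The logits are made-up values for illustration, and `argmax` and `accuracy` are hypothetical stand-ins for `np.argmax` and the loaded accuracy metric.

```python
def argmax(row):
    """Index of the largest value in a list of logits."""
    return max(range(len(row)), key=lambda i: row[i])

def accuracy(logits, labels):
    """Fraction of rows whose highest-scoring class equals the label."""
    predictions = [argmax(row) for row in logits]
    correct = sum(int(p == y) for p, y in zip(predictions, labels))
    return {"accuracy": correct / len(labels)}

logits = [
    [2.1, 0.3, -1.0, 0.0, -0.5],   # predicted class 0
    [0.1, 0.2, 1.7, 0.4, 0.0],     # predicted class 2
    [-1.2, 0.0, 0.3, 0.9, 2.5],    # predicted class 4
    [0.0, 1.1, 0.9, 0.2, -0.3],    # predicted class 1
]
labels = [0, 2, 3, 1]
print(accuracy(logits, labels))  # {'accuracy': 0.75}
```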

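#### Monitoring Metrics During Training

To monitor how evaluation metrics change during training, specify the `evaluation_strategy` parameter in `TrainingArguments` so that metrics are reported at the end of each epoch.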

In [18]:
from transformers import TrainingArguments, Trainer

training_args = TrainingArguments(output_dir=model_dir,
                                  evaluation_strategy="epoch", 
                                  per_device_train_batch_size=16,
                                  num_train_epochs=3,
                                  logging_steps=30)

## Start Training

### Instantiating the Trainer

A `kernel version` warning may appear here; for now it does not affect running this example.

In [19]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=small_train_dataset,
    eval_dataset=small_eval_dataset,
    compute_metrics=compute_metrics,
)

In [132]:
len(small_train_dataset['input_ids'][2])

512

In [133]:
len(small_train_dataset['token_type_ids'][2])

512

In [134]:
len(small_train_dataset['attention_mask'][2])

512

In [None]:
small_train_dataset['text'][0]

In [136]:
small_eval_dataset

Dataset({
    features: ['label', 'text', 'input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 1000
})

## Monitoring GPU Usage with nvidia-smi

To watch GPU usage in real time, poll with the `watch` command: `watch -n 1 nvidia-smi`:

```shell
Every 1.0s: nvidia-smi                                                   Wed Dec 20 14:37:41 2023

Wed Dec 20 14:37:41 2023
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.129.03             Driver Version: 535.129.03   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  Tesla T4                       Off | 00000000:00:0D.0 Off |                    0 |
| N/A   64C    P0              69W /  70W |   6665MiB / 15360MiB |     98%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A     18395      C   /root/miniconda3/bin/python                6660MiB |
+---------------------------------------------------------------------------------------+
```
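
If you would rather collect the same numbers from Python, a minimal sketch is to call `nvidia-smi` with its CSV query flags and parse the result. The helper names `parse_gpu_line` and `query_gpus` are hypothetical, and the parser assumes every queried field is a plain integer, which may not hold on GPUs that report `[N/A]` for some values.

```python
import shutil
import subprocess

QUERY = "utilization.gpu,memory.used,memory.total"

def parse_gpu_line(line):
    """Parse one CSV line such as '98, 6665, 15360' into a dict."""
    util, used, total = (int(field.strip()) for field in line.split(","))
    return {"util_percent": util, "memory_used_mib": used, "memory_total_mib": total}

def query_gpus():
    """Return one dict per GPU, or an empty list if nvidia-smi is unavailable or fails."""
    if shutil.which("nvidia-smi") is None:
        return []
    result = subprocess.run(
        ["nvidia-smi", f"--query-gpu={QUERY}", "--format=csv,noheader,nounits"],
        capture_output=True, text=True,
    )
    if result.returncode != 0:
        return []
    return [parse_gpu_line(line) for line in result.stdout.splitlines() if line.strip()]

# Values taken from the nvidia-smi snapshot above.
print(parse_gpu_line("98, 6665, 15360"))
```

On a machine with the NVIDIA driver installed, calling `query_gpus()` inside a loop gives a rough equivalent of the `watch` command.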

In [137]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy
1,1.2528,1.056195,0.549
2,0.8822,0.958202,0.598
3,0.6211,0.980633,0.604


TrainOutput(global_step=189, training_loss=0.9535309423214544, metrics={'train_runtime': 49.1149, 'train_samples_per_second': 61.081, 'train_steps_per_second': 3.848, 'total_flos': 789354427392000.0, 'train_loss': 0.9535309423214544, 'epoch': 3.0})
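
The `global_step=189` reported above can be reproduced from the run's settings. The sketch assumes a single device, no gradient accumulation, and that the last partial batch is kept, which matches `_n_gpu=1`, `gradient_accumulation_steps=1`, and `dataloader_drop_last=False` in the printed arguments. `total_optimizer_steps` is a hypothetical helper.

```python
import math

def total_optimizer_steps(num_examples, batch_size, epochs):
    """Optimizer updates for a single-device run that keeps the last partial batch."""
    steps_per_epoch = math.ceil(num_examples / batch_size)
    return steps_per_epoch * epochs

# 1000 examples / 16 per batch = 62.5, rounded up to 63 steps per epoch, times 3 epochs
print(total_optimizer_steps(num_examples=1000, batch_size=16, epochs=3))  # 189
```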

In [138]:
small_test_dataset = tokenized_datasets["test"].shuffle(seed=64).select(range(100))

In [139]:
trainer.evaluate(small_test_dataset)

{'eval_loss': 1.0549412965774536,
 'eval_accuracy': 0.54,
 'eval_runtime': 0.4308,
 'eval_samples_per_second': 232.112,
 'eval_steps_per_second': 30.175,
 'epoch': 3.0}

### Saving the Model and Training State

- Use `trainer.save_model` to save the model; it can later be reloaded with `from_pretrained()`.
- Use `trainer.save_state` to save the training state.

In [140]:
trainer.save_model(model_dir)

In [141]:
trainer.save_state()

In [142]:
# trainer.model.save_pretrained("./")
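
As a sketch of the reload step, the helper below loads the saved classifier back from the output directory. The function name `load_finetuned_model` is hypothetical, and the default path assumes the `model_dir` used in this notebook. Because no tokenizer was passed to the `Trainer` here, the tokenizer most likely still needs to be loaded from the original `bert-base-cased` location.

```python
def load_finetuned_model(model_dir="models/bert-base-cased-finetune-yelp"):
    """Reload the classifier written by trainer.save_model(model_dir)."""
    # Imported lazily so that defining this helper does not require transformers.
    from transformers import AutoModelForSequenceClassification
    return AutoModelForSequenceClassification.from_pretrained(model_dir)

# Example (requires the saved directory to exist):
# model = load_finetuned_model()
```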

## Homework: Train on the full YelpReviewFull dataset and see how high the accuracy can get

In [154]:
from transformers import TrainingArguments, Trainer

training_args = TrainingArguments(output_dir=model_dir,
                                  evaluation_strategy="epoch", 
                                  per_device_train_batch_size=32,
                                  num_train_epochs=3,
                                  save_steps=5000,
                                  logging_steps=30)

In [155]:
train_dataset = tokenized_datasets["train"].shuffle(seed=42)
eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(10000))

In [156]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics,
)

In [None]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy
1,0.7642,0.739845,0.6753


Checkpoint destination directory models/bert-base-cased-finetune-yelp/checkpoint-500 already exists and is non-empty.Saving will proceed but saved results may be invalid.
Checkpoint destination directory models/bert-base-cased-finetune-yelp/checkpoint-1000 already exists and is non-empty.Saving will proceed but saved results may be invalid.
Checkpoint destination directory models/bert-base-cased-finetune-yelp/checkpoint-1500 already exists and is non-empty.Saving will proceed but saved results may be invalid.
Checkpoint destination directory models/bert-base-cased-finetune-yelp/checkpoint-2000 already exists and is non-empty.Saving will proceed but saved results may be invalid.


In [20]:
test_dataset = tokenized_datasets["test"].shuffle(seed=64).select(range(10000))

In [21]:
trainer.evaluate(test_dataset)

{'eval_loss': 0.7332530617713928,
 'eval_accuracy': 0.6952,
 'eval_runtime': 42.9271,
 'eval_samples_per_second': 232.953,
 'eval_steps_per_second': 29.119}

In [22]:
trainer.save_state()