In [1]:
import evaluate
import numpy as np
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
)
from transformers import TrainingArguments

import torch
from torch import nn
from transformers import Trainer


import os
os.environ['HTTP_PROXY'] = 'http://127.0.0.1:7890'
os.environ['HTTPS_PROXY'] = 'http://127.0.0.1:7890'

In [None]:
# AG_News 英文分类数据集
# ds = load_dataset("fancyzhx/ag_news")

## 中文分类数据集
ds = load_dataset("lansinuote/ChnSentiCorp")

model_name = "bert-base-chinese"

tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    trust_remote_code=True,
)

bert = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    trust_remote_code=True,
    num_labels=2,
)

In [3]:
def tokenize_func(item):
    global tokenizer
    tokenized_inputs = tokenizer(
        item["text"],
        max_length=512,
        truncation=True,
    )
    return tokenized_inputs

tokenized_datasets = ds.map(
    tokenize_func,
    batched=True,
)

In [None]:
# 通过下述命令，查看 trainer.save_model 保存的是否是最好的模型权重。
# 通过md5值和sha1判断是否为同一个文件

# !find . -type f -name "*.safetensors" -exec sha1sum {} \;

In [18]:
from dataclasses import dataclass

@dataclass
class BertCLS:
    def __init__(self, model, train_dataset=None, eval_dataset=None, output_dir="output", epoch=3):
        self.model = model
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.args = self.get_args(output_dir, epoch)
        from transformers import DataCollatorWithPadding
        self.data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
        self.trainer = Trainer(
            model=self.model,
            args=self.args,
            train_dataset=self.train_dataset,
            eval_dataset=self.eval_dataset,
            data_collator=self.data_collator,
            # compute_metrics=compute_metrics,
            tokenizer=tokenizer,
        )
        
    def get_args(self, output_dir, epoch):
        args = TrainingArguments(
            output_dir=output_dir,
            evaluation_strategy="epoch",
            save_strategy="epoch",
            save_total_limit=3,
            learning_rate=2e-5,
            num_train_epochs=epoch,
            weight_decay=0.01,
            per_device_train_batch_size=32,
            per_device_eval_batch_size=16,
            # logging_steps=16,
            save_safetensors=True,
            overwrite_output_dir=True,
            load_best_model_at_end=True,
        )
        return args
    
    def set_args(self, args):
        """
            从外部重新设置 TrainingArguments，args 更新后，trainer也进行更新
        """
        self.args = args
        
        self.trainer = Trainer(
            model=self.model,
            args=self.args,
            train_dataset=self.train_dataset,
            eval_dataset=self.eval_dataset,
            data_collator=self.data_collator,
            # compute_metrics=compute_metrics,
            tokenizer=tokenizer,
        )
        
    def train(self, over_write=False):
        best_model_path = os.path.join(self.args.output_dir, "best_model")
        
        if over_write:
            self.trainer.train()
            self.trainer.save_model()
        elif not os.path.exists(best_model_path):
            self.trainer.train()
            self.trainer.save_model()
        else:
            print(f"预训练权重 {best_model_path} 已存在，且over_write={over_write}。不启动模型训练！")

    def eval(self, eval_dataset):
        predictions = self.trainer.predict(eval_dataset)
        preds = np.argmax(predictions.predictions, axis=-1)
        metric = evaluate.load("glue", "mrpc")
        return metric.compute(predictions=preds, references=predictions.label_ids)
    
    def pred(self, pred_dataset):
        predictions = self.trainer.predict(pred_dataset)
        preds = np.argmax(predictions.predictions, axis=-1)
        return pred_dataset.add_column("pred", preds)

In [12]:
tokenized_datasets

Dataset({
    features: ['text', 'label'],
    num_rows: 1200
})

In [13]:
bert_cls = BertCLS(
    model=bert,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
)



In [15]:
bert_cls.eval(tokenized_datasets["test"])

{'accuracy': 0.5075, 'f1': 0.6729385722191478}

In [16]:
bert_cls.train(over_write=True)

Epoch,Training Loss,Validation Loss,Model Preparation Time
1,No log,0.288048,0.0059
2,No log,0.259681,0.0059
3,No log,0.260843,0.0059


In [17]:
bert_cls.eval(tokenized_datasets["test"])

{'accuracy': 0.9566666666666667, 'f1': 0.9577922077922078}

## 加载 best_model

不训练模型，加载本地模型

In [21]:
bert_cls = BertCLS(
    model=AutoModelForSequenceClassification.from_pretrained("output/best_model"),
)



In [22]:
bert_cls.eval(tokenized_datasets["test"])

{'accuracy': 0.9341666666666667, 'f1': 0.9341117597998332}