In [17]:
import os
import evaluate
import numpy as np
from datasets import load_dataset
from transformers import (
    Trainer,
    TrainingArguments,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
)
from dataclasses import dataclass

In [5]:
# Send HTTP and HTTPS traffic from this kernel through the local proxy.
os.environ.update(
    HTTP_PROXY="http://127.0.0.1:7890",
    HTTPS_PROXY="http://127.0.0.1:7890",
)

In [4]:
class BertCLS:
    """Convenience wrapper that wires a model, tokenizer and datasets into a ``Trainer``.

    This is intentionally a plain class. It declares no dataclass fields and
    defines its own ``__init__``, so ``@dataclass`` would only add a field-less
    ``__eq__`` (every instance compares equal) and make instances unhashable.

    Args:
        model: Sequence-classification model handed to ``Trainer``.
        tokenizer: Tokenizer used by the padding collator and passed to ``Trainer``.
        train_dataset: Optional tokenized training dataset.
        eval_dataset: Optional tokenized evaluation dataset. When truthy,
            evaluation runs every epoch and the best checkpoint is reloaded
            at the end of training.
        output_dir: Directory for checkpoints and the final ``best_model`` folder.
        epoch: Number of training epochs.
    """

    def __init__(
        self,
        model,
        tokenizer,
        train_dataset=None,
        eval_dataset=None,
        output_dir="output",
        epoch=3,
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset

        self.data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

        # get_args reads self.eval_dataset, so it must run after the assignments above.
        self.args = self.get_args(output_dir, epoch)
        self.trainer = self._build_trainer()

    def _build_trainer(self):
        """Create a ``Trainer`` from the current model, args, datasets and collator."""
        # NOTE(review): recent transformers releases rename `tokenizer=` to
        # `processing_class=`; confirm against the installed version before upgrading.
        return Trainer(
            model=self.model,
            args=self.args,
            train_dataset=self.train_dataset,
            eval_dataset=self.eval_dataset,
            data_collator=self.data_collator,
            tokenizer=self.tokenizer,
        )

    def get_args(self, output_dir, epoch):
        """Build the default ``TrainingArguments``.

        Args:
            output_dir: Checkpoint directory.
            epoch: Number of training epochs.

        Returns:
            A ``TrainingArguments`` instance. Per-epoch evaluation and
            ``load_best_model_at_end`` are enabled only when an evaluation
            dataset was supplied.
        """
        has_eval = bool(self.eval_dataset)
        # NOTE(review): `evaluation_strategy` is spelled `eval_strategy` in recent
        # transformers releases; confirm against the installed version before upgrading.
        return TrainingArguments(
            output_dir=output_dir,
            evaluation_strategy="epoch" if has_eval else "no",
            save_strategy="epoch",
            save_total_limit=3,
            learning_rate=2e-5,
            num_train_epochs=epoch,
            weight_decay=0.01,
            per_device_train_batch_size=32,
            per_device_eval_batch_size=16,
            save_safetensors=True,
            overwrite_output_dir=True,
            # Picking a "best" checkpoint needs evaluation results.
            load_best_model_at_end=has_eval,
        )

    def set_args(self, args):
        """Replace the ``TrainingArguments`` from outside and rebuild the trainer.

        Args:
            args: New ``TrainingArguments``; the trainer is recreated so that it
                picks them up.
        """
        self.args = args
        self.trainer = self._build_trainer()

    def train(self, epoch=None, over_write=False):
        """Train the model and save it to ``<output_dir>/best_model``.

        Args:
            epoch: If truthy, overrides ``num_train_epochs`` on the current args.
            over_write: Train even when ``<output_dir>/best_model`` already exists.

        Training is skipped (with a printed message) when the ``best_model``
        directory exists and ``over_write`` is false.
        NOTE(review): in that case the saved weights are not loaded into
        ``self.model``; later ``eval``/``pred`` calls use whatever weights the
        in-memory model currently has.
        """
        if epoch:
            self.args.num_train_epochs = epoch

        best_model_path = os.path.join(self.args.output_dir, "best_model")

        if over_write or not os.path.exists(best_model_path):
            self.trainer.train()
            self.trainer.save_model(best_model_path)
        else:
            print(
                f"预训练权重 {best_model_path} 已存在，且over_write={over_write}。不启动模型训练！"
            )

    def eval(self, eval_dataset):
        """Score the model on ``eval_dataset`` with the GLUE/MRPC metric.

        Args:
            eval_dataset: Tokenized dataset that includes labels.

        Returns:
            The dict produced by ``metric.compute`` for the argmax predictions.

        Raises:
            ValueError: If the prediction output carries no labels to compare against.
        """
        predictions = self.trainer.predict(eval_dataset)
        if predictions.label_ids is None:
            raise ValueError(
                "eval_dataset has no labels; use pred() for unlabeled data."
            )
        preds = np.argmax(predictions.predictions, axis=-1)
        metric = evaluate.load("glue", "mrpc")
        return metric.compute(predictions=preds, references=predictions.label_ids)

    def pred(self, pred_dataset):
        """Predict class ids for ``pred_dataset``.

        Args:
            pred_dataset: Tokenized dataset exposing ``add_column``.

        Returns:
            The result of ``pred_dataset.add_column("pred", ...)`` holding the
            argmax over the last axis of the model outputs.
        """
        predictions = self.trainer.predict(pred_dataset)
        preds = np.argmax(predictions.predictions, axis=-1)
        return pred_dataset.add_column("pred", preds)

## model

In [None]:
model_name = "bert-base-chinese"

# bert-base-chinese uses the BERT architecture that ships with transformers, so
# trust_remote_code (which permits running code from the model repository) is
# not needed and is left at its default.
bert = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=2,  # two-way classification head
)

tokenizer = AutoTokenizer.from_pretrained(model_name)

## dataset

In [8]:
# Load the ChnSentiCorp dataset by its Hub id; the cells below use its "train" split.
ds = load_dataset("lansinuote/ChnSentiCorp")

In [19]:
# Peek at the first training example to see the raw fields ('text' and 'label').
ds["train"][0]

{'text': '选择珠江花园的原因就是方便，有电动扶梯直接到达海边，周围餐馆、食廊、商场、超市、摊位一应俱全。酒店装修一般，但还算整洁。 泳池在大堂的屋顶，因此很小，不过女儿倒是喜欢。 包的早餐是西式的，还算丰富。 服务吗，一般',
 'label': 1}

In [None]:
# encoded_dataset = dataset.map(
    # preprocess_data, 
    # batched=True, 
    # remove_columns=dataset['train'].column_names)

In [20]:
# DataCollatorWithPadding is already imported in the notebook's first cell, so
# it is not re-imported here.
# The collator pads each batch at collate time rather than up front.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

## tokenize

In [21]:
def tokenize_func_pad(item):
    """Tokenize ``item["text"]`` and pad/truncate the result to exactly 512 tokens.

    Uses the notebook-level ``tokenizer``. Returns whatever the tokenizer
    returns for the given text.
    """
    encode_options = {
        "max_length": 512,
        "truncation": True,
        "padding": "max_length",
    }
    return tokenizer(item["text"], **encode_options)

In [22]:
def tokenize_func(item):
    """Tokenize ``item["text"]``, truncating at 512 tokens and adding no padding.

    Uses the notebook-level ``tokenizer``. Returns whatever the tokenizer
    returns for the given text.
    """
    return tokenizer(item["text"], truncation=True, max_length=512)

In [63]:
# batched=True passes lists of texts to the tokenizer instead of one row at a
# time; each example is still padded/truncated to 512 tokens, only the mapping
# is faster.
batch_dataset_pad = ds["train"].map(
    tokenize_func_pad, batched=True, remove_columns=["text"]
)

Map:   0%|          | 0/9600 [00:00<?, ? examples/s]

In [31]:
# Have the padded dataset return torch tensors from here on (changes the dataset object in place).
batch_dataset_pad.set_format("torch")

In [34]:
# Every row was padded to max_length=512, so the whole column forms one (rows, 512) tensor.
batch_dataset_pad["input_ids"].shape

torch.Size([9600, 512])

In [62]:
# batched=True passes lists of texts to the tokenizer instead of one row at a
# time; no padding is applied, so each example keeps its own length.
batch_dataset = ds["train"].map(
    tokenize_func, batched=True, remove_columns=["text"]
)

Map:   0%|          | 0/9600 [00:00<?, ? examples/s]

In [None]:
# Inspect the token ids of the first 10 examples from the unpadded dataset.
batch_dataset["input_ids"][:10]

In [41]:
# Have the unpadded dataset return torch tensors from here on (changes the dataset object in place).
batch_dataset.set_format("torch")

In [50]:
# Show the dataset summary (features and row count).
# NOTE(review): the saved output lists no 'label' column although the map call only
# removes 'text' — confirm labels are present before training on this dataset.
batch_dataset

Dataset({
    features: ['input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 9600
})

In [56]:
# Collate the first 16 variable-length rows by hand to see what the collator does:
# the saved output is a (16, 126) batch, i.e. padded to a common length well below 512.
data = [batch_dataset[i] for i in range(16)]
data_collator(data)["input_ids"].shape

torch.Size([16, 126])

In [61]:
# Show the padded dataset summary (features and row count).
# NOTE(review): the saved output lists no 'label' column — confirm labels are present before training.
batch_dataset_pad

Dataset({
    features: ['input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 9600
})

In [66]:
# Train for one epoch on the max_length-padded dataset. No output_dir is given, so
# BertCLS uses its default "output" and train() saves the result to output/best_model.
BertCLS(bert, tokenizer, batch_dataset_pad, epoch=1).train()



Step,Training Loss


In [72]:
# Token count of the first unpadded example (105 in the saved output, versus 512 when padded to max_length).
len(batch_dataset[0]["input_ids"])

105

In [73]:
# Use a dedicated output_dir: with the default "output", a best_model directory left
# by an earlier BertCLS(...).train() run makes train() skip training for this run.
# NOTE(review): `bert` is the same model object used by earlier cells; reload it first
# if this run is meant to start from the original pretrained weights.
BertCLS(
    bert,
    tokenizer,
    batch_dataset,
    output_dir="output_dynamic_padding",
    epoch=1,
).train()



Step,Training Loss
