# 文本分类实例

In [1]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
import datasets

In [2]:
# Load the hotel-review CSV through `datasets` and clean it.
# Fixed seed for the train/test split below: without it the split is reshuffled
# randomly on every run, so the metrics reported later cannot be reproduced.
SPLIT_SEED = 42

raw_dataset = datasets.load_dataset("csv", data_files="./ChnSentiCorp_htl_all.csv", split="train")
print(len(raw_dataset))
# Drop rows whose "review" field is missing (None).
dataset = raw_dataset.filter(lambda x: x["review"] is not None)
print(len(dataset))
# Split the cleaned data: 10% is held out as the test set, the rest is used for training.
split_dataset = dataset.train_test_split(test_size=0.1, seed=SPLIT_SEED)
split_dataset

Generating train split: 0 examples [00:00, ? examples/s]

7766


Filter:   0%|          | 0/7766 [00:00<?, ? examples/s]

7765


DatasetDict({
    train: Dataset({
        features: ['label', 'review'],
        num_rows: 6988
    })
    test: Dataset({
        features: ['label', 'review'],
        num_rows: 777
    })
})

In [3]:
# Tokenization: load the tokenizer that belongs to the "hfl/rbt3" checkpoint.
tokenizer = AutoTokenizer.from_pretrained("hfl/rbt3")

def process_function(example):
    """Tokenize review text and attach the label under the "labels" key.

    Args:
        example: Mapping with a "review" entry (passed to the tokenizer) and a
            "label" entry. It is applied through `map(..., batched=True)`, so
            each entry is presumably a list of values — verify against caller.

    Returns:
        The tokenizer output, truncated to at most 128 tokens per text, with
        an extra "labels" entry copied from `example["label"]`.
    """
    encoded = tokenizer(
        example["review"],
        truncation=True,
        max_length=128,
    )
    # Store the label under "labels" alongside the tokenizer fields.
    encoded["labels"] = example["label"]
    return encoded

# Tokenize both splits in batches. The original columns are removed so that only
# the tokenizer fields and "labels" remain in the result.
original_columns = split_dataset["train"].column_names
tokenized_datasets = split_dataset.map(
    function=process_function,
    batched=True,
    remove_columns=original_columns,
)

tokenized_datasets

Map:   0%|          | 0/6988 [00:00<?, ? examples/s]

Map:   0%|          | 0/777 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 6988
    })
    test: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 777
    })
})

In [5]:
# Create the model: load the "hfl/rbt3" checkpoint with a sequence-classification head.
# The classifier weights are not part of the checkpoint and are newly initialized
# (see the warning printed by this cell), so the model must be fine-tuned before use.
model = AutoModelForSequenceClassification.from_pretrained("hfl/rbt3")

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at hfl/rbt3 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [16]:
# Set up the evaluation metrics
import evaluate

# Load both metrics by name; the order of the names matches the unpacking targets.
acc_metric, f1_metric = (evaluate.load(metric_name) for metric_name in ("accuracy", "f1"))

In [17]:
def eval_metric(eval_predict):
    """Compute accuracy and F1 for a Trainer evaluation pass.

    Args:
        eval_predict: A `(predictions, labels)` pair as supplied by the Trainer.
            The predictions are raw model outputs that have not yet been
            converted to predicted classes.

    Returns:
        A single dict holding the results of the accuracy metric followed by
        the results of the F1 metric.
    """
    model_outputs, references = eval_predict
    # Convert model outputs to class ids by taking the argmax over the last axis.
    predicted_classes = model_outputs.argmax(axis=-1)

    scores = {}
    for metric in (acc_metric, f1_metric):
        scores.update(metric.compute(predictions=predicted_classes, references=references))
    return scores

In [None]:
# Configure TrainingArguments
TRAIN_BATCH_SIZE = 64   # per-device batch size used while training
EVAL_BATCH_SIZE = 128   # per-device batch size used while evaluating
LOGGING_STEPS = 10      # log every 10 steps

train_args = TrainingArguments(
    output_dir="./checkpoints",
    per_device_train_batch_size=TRAIN_BATCH_SIZE,
    per_device_eval_batch_size=EVAL_BATCH_SIZE,
    logging_steps=LOGGING_STEPS,
)
train_args

In [21]:
# Create the Trainer
from transformers import DataCollatorWithPadding

# The tokenization step requests no padding, so batches are padded by this collator.
padding_collator = DataCollatorWithPadding(tokenizer)

trainer = Trainer(
    model=model,
    args=train_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["test"],
    data_collator=padding_collator,
    compute_metrics=eval_metric,
)

In [22]:
# Train the model using the Trainer configured in this notebook.
trainer.train()

Step,Training Loss


KeyboardInterrupt: 

In [18]:
# Evaluate the model; the returned dict holds the loss, the accuracy/F1 values
# produced by `eval_metric`, and runtime statistics.
trainer.evaluate()

{'eval_loss': 0.4116431176662445,
 'eval_accuracy': 0.8996138996138996,
 'eval_f1': 0.9257142857142857,
 'eval_runtime': 1.6669,
 'eval_samples_per_second': 466.146,
 'eval_steps_per_second': 58.793,
 'epoch': 3.0}

In [19]:
# Run prediction on the test split; the result is a PredictionOutput carrying the
# raw per-class scores (`predictions`) and the reference labels (`label_ids`).
trainer.predict(tokenized_datasets["test"])

PredictionOutput(predictions=array([[ 2.920311 , -2.8374043],
       [ 1.7479247, -1.065124 ],
       [ 2.635941 , -2.4907978],
       ...,
       [-3.3409631,  3.8871646],
       [-3.4347029,  4.1078176],
       [ 2.9083965, -2.6232598]], dtype=float32), label_ids=array([0, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0,
       1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 1,
       1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 1,
       1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1,
       0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1,
       1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1,
       1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1,
       1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1,
       1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,
    