In [None]:
# cd to your workstation, if necessary
# NOTE: this is an absolute path; adjust it when running elsewhere. Later cells
# open MODEL_NAME + '/label_map.json' relative to the current working directory.
%cd /home/aistudio/work/

In [2]:
# Import the standard-library and third-party modules used by this notebook
import numpy as np
import time
import json
import paddle
import paddle.nn.functional as F

In [None]:
!pip install --upgrade paddlenlp==2.1

In [4]:
# Import PaddleNLP and the batching helpers (Tuple, Pad) used by predict()
import paddlenlp as ppnlp
from paddlenlp.data import Tuple, Pad

In [5]:
# Checkpoint to load. The stock pretrained weights can be selected instead with:
# MODEL_NAME = "roberta-wwm-ext-large"
# Here the model is restored from previously saved parameters.
MODEL_NAME = 'test_0.83_finished'

# Fine-tuning network: giving the model name and the number of classes is all
# that is needed; a fully connected classification layer is placed on top of
# the pretrained model.
model = ppnlp.transformers.RobertaForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_classes=8,  # this is an 8-way classification task
)

# Tokenizer that converts raw text into the input format the model accepts.
# The tokenizer class must correspond to the chosen model (see the PaddleNLP docs).
tokenizer = ppnlp.transformers.RobertaTokenizer.from_pretrained(MODEL_NAME)

In [6]:
# Prediction helper: classify one text and return its most likely labels
@paddle.no_grad()
def predict(model, data, tokenizer, label_map, batch_size=1, top_k=3):
    """Classify a single text and return its highest-scoring labels.

    Args:
        model: Classification model. It is switched to eval mode and called as
            ``model(input_ids, segment_ids)``; the result is treated as logits
            with the classes on axis 1.
        data: The text to classify, passed to ``tokenizer(text=...)``.
        tokenizer: Tokenizer matching ``model``. Its output must provide
            ``"input_ids"`` and ``"token_type_ids"``.
        label_map: Mapping from integer class index to label.
        batch_size: Unused; kept so existing callers keep working.
        top_k: How many labels to return. Defaults to 3, the value that was
            previously hard-coded.

    Returns:
        list: Labels from ``label_map`` ordered from most to least probable,
        containing at most ``top_k`` entries.

    Raises:
        ValueError: If ``top_k`` is smaller than 1.
    """
    if top_k < 1:
        raise ValueError("top_k must be >= 1")

    # Convert the raw text into the ids the model consumes (truncated to 256 tokens).
    encoded_inputs = tokenizer(text=data, max_seq_len=256)
    example = (encoded_inputs["input_ids"], encoded_inputs["token_type_ids"])

    # Each field gets its own pad value: token ids are padded with the pad token
    # id, segment ids with the pad token *type* id.
    batchify_fn = Tuple(
        Pad(axis=0, pad_val=tokenizer.pad_token_id),       # input ids
        Pad(axis=0, pad_val=tokenizer.pad_token_type_id),  # segment ids
    )

    model.eval()

    # Build a batch holding the single example and run the forward pass.
    input_ids, segment_ids = batchify_fn([example])
    input_ids = paddle.to_tensor(input_ids)
    segment_ids = paddle.to_tensor(segment_ids)
    logits = model(input_ids, segment_ids)
    probs = F.softmax(logits, axis=1)

    # Class indices of the only example, sorted by descending probability.
    ranked = paddle.argsort(probs, axis=1, descending=True).numpy()[0]
    best = ranked[:top_k].tolist()

    return [label_map[i] for i in best]

In [7]:
# Load the class-index -> label mapping stored next to the model parameters.
# The file is read explicitly as UTF-8 so that non-ASCII label text is decoded
# the same way regardless of the platform's default encoding.
with open(MODEL_NAME + '/label_map.json', encoding='utf-8') as f:
    label_map_str = json.load(f)
# JSON object keys are always strings; convert them back to integer class indices.
label_map = {int(k): v for k, v in label_map_str.items()}

In [8]:
# Text to classify (placeholder value; replace it with the real input)
text = '...'

In [9]:
# Run the prediction for the text defined above
result = predict(model, text, tokenizer, label_map)

In [None]:
# Show the labels returned by predict()
print(result)