In [None]:
import os
# Point both proxy environment variables at the local proxy listening on port 8889.
os.environ.update(
    http_proxy="http://127.0.0.1:8889",
    https_proxy="http://127.0.0.1:8889",
)

In [None]:
# 1 Import
import numpy as np
import torch
import evaluate
from transformers import AutoTokenizer, AutoModelForMultipleChoice, TrainingArguments, Trainer
from datasets import load_dataset, load_from_disk

In [None]:
# Load the C3 "dialog" configuration from the local copy when it exists; otherwise
# download it once and save it there, so a fresh kernel on a fresh checkout can run
# this cell without any manual step.
C3_DIALOG_DIR = "./datasets/c3/dialog"

if os.path.isdir(C3_DIALOG_DIR):
    datasets = load_from_disk(C3_DIALOG_DIR)
else:
    datasets = load_dataset("c3", "dialog")
    datasets.save_to_disk(C3_DIALOG_DIR)

def extract_choice(example):
    """Promote the first question of an example to top-level fields.

    Reads the nested ``example["questions"]`` mapping and copies entry 0 of its
    ``"choice"``, ``"answer"`` and ``"question"`` lists into the new keys
    ``"choice"``, ``"answer"`` and ``"question"``. Any further questions attached
    to the same example are not copied.

    Args:
        example: A single (non-batched) record; it is modified in place.

    Returns:
        The same ``example`` object with the three extra keys set.
    """
    nested = example["questions"]
    example.update(
        choice=nested["choice"][0],
        answer=nested["answer"][0],
        question=nested["question"][0],
    )
    return example

# Apply extract_choice to every example of every split, then drop the nested
# "questions" column whose first entry has just been copied to top-level fields.
# extract_choice works on one example at a time, so map() runs in its default
# non-batched mode; the former `batch_size=True` (a boolean passed to an integer
# option that is only relevant when batched=True) was a stray argument and is gone.
datasets = datasets.map(extract_choice, remove_columns=["questions"])
datasets

In [None]:
# Display the training example at index 1 as a quick look at the record layout.
datasets["train"][1]

In [None]:
# Remove the "test" split so only the remaining splits are tokenized and trained on.
# pop() is given a default so that re-running this cell is a no-op rather than a
# KeyError once the split is already gone.
# NOTE(review): presumably the test split is dropped because its answers cannot be
# turned into labels by process_func — confirm.
datasets.pop("test", None)
datasets

In [None]:
# 2 data preprocessing
# Tokenizer loaded from the "hfl/chinese-macbert-base" checkpoint; process_func uses it.
tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-macbert-base")

def process_func(example):
    """Tokenize a batch of examples into fixed-width multiple-choice features.

    For every example the document lines are joined with newlines into one context
    string, and each candidate answer is paired with it as ``question + " " + choice``.
    Examples with fewer than four candidates are topped up with the placeholder
    choice "无法确定" so that every example contributes exactly four
    (context, question+choice) pairs.

    Args:
        example: A batch (dict of equal-length lists) with the keys ``"documents"``,
            ``"question"``, ``"choice"`` and ``"answer"``.

    Returns:
        A dict holding every field produced by the tokenizer, regrouped so that each
        entry is a list of four per-choice sequences, plus ``"label"``: the index of
        the answer within each example's original choice list.

    Raises:
        ValueError: If an example has more than four choices (it would otherwise
            shift the four-wide grouping and misalign every later label), or if an
            answer is not among its choices.
    """
    num_choices = 4
    filler_choice = "无法确定"

    contexts = []
    question_choices = []
    labels = []
    for document, question, choices, answer in zip(
        example["documents"], example["question"], example["choice"], example["answer"]
    ):
        if len(choices) > num_choices:
            raise ValueError(
                f"expected at most {num_choices} choices per example, got {len(choices)}"
            )
        ctx = "\n".join(document)
        # Top up short choice lists so every example yields exactly num_choices rows.
        padded_choices = list(choices) + [filler_choice] * (num_choices - len(choices))
        contexts.extend([ctx] * num_choices)
        question_choices.extend(question + " " + choice for choice in padded_choices)
        # The label indexes the original list; filler entries are only ever appended
        # after it, so the index stays valid in the padded list.
        labels.append(choices.index(answer))

    # Only the context (first sequence) may be truncated, so the question and the
    # candidate answer are always kept whole.
    flat = tokenizer(contexts, question_choices, truncation="only_first", max_length=256, padding="max_length")
    # Regroup the flat per-pair outputs into one list of num_choices rows per example.
    tokenized_examples = {
        key: [rows[start:start + num_choices] for start in range(0, len(rows), num_choices)]
        for key, rows in flat.items()
    }
    tokenized_examples["label"] = labels
    return tokenized_examples

In [None]:
tokenized_c3 = datasets.map(process_func, batched=True)

In [None]:
# 3 Create model
# Multiple-choice model initialised from the same checkpoint name as the tokenizer.
model = AutoModelForMultipleChoice.from_pretrained("hfl/chinese-macbert-base")

In [None]:
# 4 Evaluation
import numpy as np
# Accuracy metric object from the `evaluate` library; comput_metric calls its compute().
accuracy =  evaluate.load("accuracy")

def comput_metric(pred):
    """Turn raw evaluation outputs into an accuracy score.

    Args:
        pred: A two-item iterable ``(scores, labels)``; the arg-max over the last
            axis of ``scores`` is taken as the predicted choice for each row.

    Returns:
        The result of ``accuracy.compute`` for the predicted choices against
        ``labels``.
    """
    scores, labels = pred
    # The position of the highest score along the last axis is the predicted choice.
    chosen = np.argmax(scores, axis=-1)
    return accuracy.compute(predictions=chosen, references=labels)

In [None]:
# 5 Train args
# Training configuration: checkpoints are written under ./models/mc, the model is
# evaluated and saved once per epoch, and the best checkpoint is reloaded at the end.
args = TrainingArguments(
    output_dir="./models/mc",
    per_device_train_batch_size=16,
    per_device_eval_batch_size=32,
    num_train_epochs=3,
    logging_strategy="steps",
    logging_steps=50,
    # NOTE(review): newer transformers releases spell this option `eval_strategy` —
    # confirm against the installed version before upgrading.
    evaluation_strategy="epoch",
    save_strategy="epoch",
    # NOTE(review): no metric_for_best_model is given, so the library default decides
    # which checkpoint counts as "best" — confirm that is the intended criterion.
    load_best_model_at_end=True,
    # Request half precision only when a CUDA device is present, so the notebook also
    # runs on CPU-only machines where an unconditional fp16=True is rejected.
    fp16=torch.cuda.is_available(),
)

In [None]:
# Show which device the training arguments resolved to.
args.device

In [None]:
# 6 Trainer
# Wire the model, the training arguments, the tokenized train/validation splits and
# the accuracy callback into a Trainer.
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tokenized_c3["train"],
    eval_dataset=tokenized_c3["validation"],
    compute_metrics=comput_metric,
)

In [None]:
# 7 Train
# Start the training run configured on `trainer`.
trainer.train()

In [None]:
# 8 Prediction
class MultipleChoicePipeline:
    """Minimal inference helper that picks one answer from a list of candidates.

    Calling an instance runs three steps: ``preprocess`` (tokenize one
    context/question pair against every candidate), ``predict`` (one forward pass
    over all candidates as a single example) and ``postprocess`` (map the
    highest-scoring index back to the candidate text).
    """

    def __init__(self, model, tokenizer) -> None:
        """Store the model and tokenizer and remember the model's device."""
        self.model = model
        self.tokenizer = tokenizer
        # Inputs are moved to the device the model is on at construction time.
        self.device = model.device

    def preprocess(self, context, question, choices):
        """Tokenize ``context`` against ``question + " " + choice`` for every choice.

        Args:
            context: The passage text.
            question: The question text.
            choices: Non-empty sequence of candidate answer strings.

        Returns:
            The tokenizer output as PyTorch tensors, one row per candidate.

        Raises:
            ValueError: If ``choices`` is empty.
        """
        if not choices:
            raise ValueError("choices must contain at least one candidate answer")
        contexts = [context] * len(choices)
        question_choices = [question + " " + choice for choice in choices]
        # Use the tokenizer handed to this instance rather than a notebook global.
        # padding="max_length" gives every candidate row the same length, which is
        # required to stack them into one tensor, and matches process_func.
        return self.tokenizer(
            contexts,
            question_choices,
            truncation="only_first",
            max_length=256,
            padding="max_length",
            return_tensors="pt",
        )

    def predict(self, inputs):
        """Run the model on one preprocessed example and return its logits.

        Each tensor gets a leading batch axis of size 1 and is moved to
        ``self.device`` before the forward pass.
        """
        inputs = {k: v.unsqueeze(0).to(self.device) for k, v in inputs.items()}
        # Inference only: switch off dropout and skip gradient bookkeeping.
        self.model.eval()
        with torch.no_grad():
            return self.model(**inputs).logits

    def postprocess(self, logits, choices):
        """Return the candidate whose logit is largest along the last axis."""
        return choices[logits.argmax(dim=-1).cpu().item()]

    def __call__(self, context, question, choices):
        """Answer ``question`` about ``context`` by choosing one of ``choices``."""
        inputs = self.preprocess(context, question, choices)
        logits = self.predict(inputs)
        return self.postprocess(logits, choices)

In [None]:
# Wrap the model and tokenizer in the prediction helper.
pipe = MultipleChoicePipeline(model, tokenizer)


In [None]:
# Try the pipeline on a hand-written example with four candidate answers; the cell
# output is the candidate the model scores highest.
pipe(
    context="我是一个首尔人，我爱我的祖国",
    question="我是哪国人？",
    choices=["中国人", "美国人", "日本人", "韩国人"]
)