In [None]:
import os
# Set in the first cell so the variable is already in the process environment
# when the transformers import in the next cell runs.
# NOTE(review): presumably opts in to PyTorch's experimental AOTriton kernels
# on ROCm GPUs — confirm against the installed PyTorch build.
os.environ["TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL"] = "1"

In [None]:
from transformers import AutoTokenizer, AutoModelForMultipleChoice, TrainingArguments, Trainer, pipeline
from datasets import DatasetDict
import evaluate

In [None]:
# Load the dataset saved on disk under ./c3 and remove its "test" split.
# `pop` returns the removed split, so the cell displays it as its output.
# NOTE(review): presumably the test split is dropped because it has no usable
# gold answers for process_function's `choices.index(...)` lookup — confirm.
datasets = DatasetDict.load_from_disk("./c3")
datasets.pop("test")

In [None]:
# Display the validation split to inspect what was loaded.
datasets["validation"]

In [None]:
# Load the tokenizer of the "hfl/chinese-macbert-base" checkpoint; the model
# cell further down loads the same checkpoint name.
tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-macbert-base")

In [None]:
import numpy as np
def process_function(examples, tokenizer=tokenizer, num_choices=4, max_length=256):
    """Tokenize a batch of multiple-choice examples into fixed-shape inputs.

    Every example is expanded into exactly ``num_choices`` text pairs of the
    form (context, question + " " + choice). Examples with fewer choices are
    padded with the filler choice "不知道" so that all examples share the same
    shape.

    Args:
        examples: Batched mapping with the columns "context" (list of
            paragraphs per example), "question" (str), "choice" (list of str)
            and "answer" (str, must be one of the choices).
        tokenizer: Callable tokenizer applied to the (context, question-choice)
            pairs. Defaults to the notebook-level tokenizer.
        num_choices: Number of candidate answers per example after padding.
        max_length: Sequence length every pair is truncated/padded to.

    Returns:
        dict mapping each tokenizer output key to an array of shape
        ``(batch, num_choices, max_length)``, plus "labels" holding the index
        of the correct choice for every example.

    Raises:
        ValueError: If an example has more than ``num_choices`` choices, or if
            its answer is not among its choices.
    """
    contexts = []
    question_choices = []
    labels = []

    for paragraphs, question, choices, answer in zip(
        examples["context"],
        examples["question"],
        examples["choice"],
        examples["answer"],
        strict=True,
    ):
        # Without this guard the reshape below either fails with an opaque
        # error or, when the total row count happens to divide evenly,
        # silently mixes choices of neighbouring examples.
        if len(choices) > num_choices:
            raise ValueError(
                f"example has {len(choices)} choices but at most {num_choices} are supported"
            )

        cur_context = "\n".join(paragraphs)
        # Pad short choice lists with a filler answer so each example
        # contributes exactly `num_choices` rows.
        padded_choices = list(choices) + ["不知道"] * (num_choices - len(choices))
        for choice in padded_choices:
            contexts.append(cur_context)
            question_choices.append(question + " " + choice)

        labels.append(choices.index(answer))

    # Only the context (first segment) is truncated so the question and the
    # candidate answer are always kept intact.
    inputs = tokenizer(
        contexts,
        question_choices,
        truncation="only_first",
        max_length=max_length,
        padding="max_length",
    )

    # Regroup the flat list of pairs into one (num_choices, max_length) block
    # per example.
    inputs = {k: np.array(v).reshape(-1, num_choices, max_length) for k, v in inputs.items()}
    inputs["labels"] = np.array(labels)

    return inputs


In [None]:
# Smoke-test process_function on the first two training examples, dropping the
# raw columns so only the function's outputs remain, then display the result.
test_datasets = datasets["train"].select(range(2)).map(process_function, batched=True, remove_columns=datasets["train"].column_names)
test_datasets

In [None]:
# Expected shape is (2, 4, 256): 2 selected examples x 4 choices x 256 tokens.
np.array(test_datasets["input_ids"]).shape

In [None]:
# Apply the preprocessing to every remaining split in batched mode.
# The columns to remove are taken from the train split.
# NOTE(review): assumes all splits share the train split's columns — confirm.
tokenized_datasets = datasets.map(process_function, batched=True, remove_columns=datasets["train"].column_names)

In [None]:
# Load the checkpoint through the multiple-choice auto class and move it to
# the "cuda" device.
# NOTE(review): the device is hard-coded; this will presumably fail on a
# machine without a GPU visible to PyTorch — confirm the target environment.
model = AutoModelForMultipleChoice.from_pretrained("hfl/chinese-macbert-base").to("cuda")

In [None]:
# Load the "accuracy" metric; compute_metric below calls its `compute` method.
accuracy = evaluate.load("accuracy")

In [None]:
def compute_metric(pred):
    """Score one evaluation pass with the notebook-level accuracy metric.

    ``pred`` unpacks into a pair ``(scores, gold_labels)``. The predicted
    choice for each example is the index of the highest score along the last
    axis; the result of ``accuracy.compute`` is returned unchanged.
    """
    choice_scores, gold_labels = pred
    chosen = np.argmax(choice_scores, axis=-1)
    return accuracy.compute(references=gold_labels, predictions=chosen)

In [None]:
# Training configuration: batches of 16 per device for both training and
# evaluation, a single epoch, a log entry every 50 steps, evaluation and
# checkpointing once per epoch, reloading the best checkpoint at the end,
# and fp16 enabled.
# NOTE(review): the output directory name "./muliple_choice" looks like a
# misspelling of "multiple"; left unchanged because renaming it would change
# where checkpoints are written and looked up.
args = TrainingArguments(
    output_dir="./muliple_choice",
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=1,
    logging_steps=50,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    fp16=True
)

In [None]:
# Wire the model, the training arguments, the tokenized train/validation
# splits and the accuracy callback into a single Trainer instance.
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    compute_metrics=compute_metric
)

In [None]:
# Smoke-test the evaluation loop (including compute_metric) on a single
# validation example instead of the full validation split.
trainer.evaluate(tokenized_datasets["validation"].select(range(1)))

In [None]:
from typing import List
import torch

class MultipleChoicePipeline:
    def __init__(self, model, tokenizer) -> None:
        self.model = model
        self.tokenizer = tokenizer
        self.device = model.device

    def preprocess(self, context, question, choices):
        contexts = []
        question_choice = []
        context = "\n".join(context) if isinstance(context, list) else context
        for choice in choices:
            contexts.append(context)
            question_choice.append(question + " " + choice)

        return self.tokenizer(contexts, question_choice, max_length=256, truncation="only_first", return_tensors="pt", padding=True)

    def predict(self, inputs):
        inputs = {k: v.unsqueeze(0).to(self.device) for k, v in inputs.items()}
        return self.model(**inputs).logits

    def postprocess(self, logits, choices):
        predict = torch.argmax(logits, dim=-1).item()
        return choices[predict]

    def __call__(self, context: str | List[str], question: str, choices: List[str]) -> str:
        inputs = self.preprocess(context, question, choices)
        logits = self.predict(inputs)
        return self.postprocess(logits, choices)


In [None]:
# Wrap the notebook's model and tokenizer in the inference helper defined above.
pipe = MultipleChoicePipeline(model, tokenizer)

In [None]:
# Try the pipeline on a toy example with a plain-string context and six
# candidate answers (one of them repeated); the cell displays the chosen answer.
pipe("小明在北京上班", "小明在哪里上班？", ["北京", "上海", "河北1", "海南", "河北", "海南"])