In [None]:
import os
# ROCm/PyTorch flag, presumably opting in to the experimental AOTriton attention
# kernels on AMD GPUs -- confirm against your PyTorch version. It sits in the very
# first cell, presumably because it must be set before PyTorch is loaded.
os.environ.update({"TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL": "1"})

In [None]:
from transformers import AutoTokenizer, AutoModelForQuestionAnswering, DefaultDataCollator, TrainingArguments, Trainer, pipeline
from datasets import load_dataset

In [None]:
# CMRC 2018 Chinese reading-comprehension data as a DatasetDict;
# this notebook trains on "train" and evaluates on "validation".
datasets = load_dataset(path="cmrc2018")
datasets

In [None]:
# Peek at one raw training record, e.g. its question, context and answers fields.
datasets["train"][0]

In [None]:
# Tokenizer of the Chinese MacBERT base checkpoint (the model cell loads the same name).
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path="hfl/chinese-macbert-base",
)

In [None]:
# A one-row slice of the training split, used to prototype the preprocessing.
sample_dataset = datasets["train"].select(indices=range(1))
sample_dataset.column_names

In [None]:
# Show the question column of the one-row sample.
sample_dataset["question"]

In [None]:
# Encode question/context as a sentence pair. Only the context (second sequence)
# may be truncated, every row is padded to 512 tokens, and the offsets map each
# token back to its character span so answer labels can be derived later.
tokenized_examples = tokenizer(
    text=list(sample_dataset["question"]),
    text_pair=list(sample_dataset["context"]),
    max_length=512,
    truncation="only_second",
    padding="max_length",
    return_offsets_mapping=True,
)
tokenized_examples.keys()

In [None]:
# Character (start, end) span of every token in the first example.
# print() rather than a bare expression keeps the 512-entry list compact
# (rich display would pretty-print it one tuple per line).
print(tokenized_examples["offset_mapping"][0])

In [None]:
# Segment id per token of the first example: 0 for the first sequence (question),
# 1 for the second (context), None for special/padding tokens. The labelling
# code uses the 1-run to locate the context.
print(tokenized_examples.sequence_ids(0))

In [None]:
def test():
    start_positions = []
    end_positions = []

    for batch_idx, offset in enumerate(tokenized_examples["offset_mapping"]):
        answers = sample_dataset[batch_idx]["answers"]

        answer_start_char = answers["answer_start"][0]
        answer_end_char = answer_start_char + len(answers["text"][0]) - 1

        # context在input_ids里的起始和结束下标
        context_start = tokenized_examples.sequence_ids(batch_idx).index(1)
        context_end = tokenized_examples.sequence_ids(batch_idx).index(None, context_start) - 1

        answer_start = None
        answer_end = None

        if answer_start_char >= offset[context_start][0] and answer_end_char <= offset[context_end][1] - 1:
            for offset_idx in range(context_start, context_end + 1):
                char_start, char_end = offset[offset_idx]
                if char_start <= answer_start_char < char_end and answer_start is None:
                    answer_start = offset_idx
                cur_find_end = False
                if answer_start is not None and char_start <= answer_end_char < char_end:
                    answer_end = offset_idx
                    cur_find_end = True
                if answer_start is not None and answer_end is not None and cur_find_end is False:
                    break

        if answer_start is not None and answer_end is not None:
            start_positions.append(answer_start)
            end_positions.append(answer_end)
        else:
            start_positions.append(0)
            end_positions.append(0)

        print(sample_dataset[batch_idx]["context"][offset[answer_start][0]:offset[answer_end][1]])
        print(tokenizer.decode(tokenized_examples["input_ids"][batch_idx][answer_start:answer_end + 1]))

    print(start_positions)
    print(end_positions)

In [None]:
def process_function(examples, tokenizer=tokenizer, max_length=512):
    """Tokenize a batch of question/context pairs and attach answer-span labels.

    Meant for ``Dataset.map(..., batched=True)``: every value in ``examples``
    is a list with one entry per example. Each pair is encoded as a sentence
    pair, truncating only the context and padding to ``max_length``. The
    character-level answer span is then mapped onto token indices within the
    context segment.

    Args:
        examples: batch dict with ``"question"``, ``"context"`` and ``"answers"``
            lists; each answer is ``{"answer_start": [...], "text": [...]}`` and
            only its first entry is used.
        tokenizer: tokenizer supporting ``return_offsets_mapping`` (i.e. a fast
            tokenizer); the default is bound when this cell is executed.
        max_length: fixed sequence length used for truncation and padding.

    Returns:
        The tokenizer's ``BatchEncoding`` without ``offset_mapping``, plus
        ``start_positions`` / ``end_positions`` lists of inclusive token indices.
        Examples with no answer, or whose answer was cut off by truncation, are
        labelled with index 0 for both positions.
    """
    inputs = tokenizer(
        text=examples["question"],
        text_pair=examples["context"],
        return_offsets_mapping=True,
        max_length=max_length, truncation="only_second", padding="max_length"
    )

    # Offsets are only needed to build the labels; they are not a model input.
    offset_mapping = inputs.pop("offset_mapping")

    start_positions = []
    end_positions = []

    for batch_idx, (offset, answers) in enumerate(zip(offset_mapping, examples["answers"])):
        answer_start = None
        answer_end = None

        # Examples without an annotated answer fall through to the (0, 0) label.
        if answers["answer_start"] and answers["text"]:
            answer_start_char = answers["answer_start"][0]
            # Inclusive index of the last answer character.
            answer_end_char = answer_start_char + len(answers["text"][0]) - 1

            # First and last token index of the context within input_ids.
            sequence_ids = inputs.sequence_ids(batch_idx)
            context_start = sequence_ids.index(1)
            context_end = sequence_ids.index(None, context_start) - 1

            # Only search when the whole answer survived truncation.
            if answer_start_char >= offset[context_start][0] and answer_end_char <= offset[context_end][1] - 1:
                for offset_idx in range(context_start, context_end + 1):
                    char_start, char_end = offset[offset_idx]
                    if char_start <= answer_start_char < char_end and answer_start is None:
                        answer_start = offset_idx
                    cur_find_end = False
                    if answer_start is not None and char_start <= answer_end_char < char_end:
                        answer_end = offset_idx
                        cur_find_end = True
                    # Stop at the first token after the end token has been matched.
                    if answer_start is not None and answer_end is not None and cur_find_end is False:
                        break

        if answer_start is not None and answer_end is not None:
            start_positions.append(answer_start)
            end_positions.append(answer_end)
        else:
            start_positions.append(0)
            end_positions.append(0)

    inputs["start_positions"] = start_positions
    inputs["end_positions"] = end_positions
    return inputs

In [None]:
# Tokenize and label every split in batches, dropping the raw columns
# (taken from the train split) so only model inputs and labels remain.
processed_datasets = datasets.map(
    process_function,
    batched=True,
    remove_columns=datasets["train"].column_names,
)

In [None]:
# Inspect the splits and the feature columns produced by process_function.
processed_datasets

In [None]:
# Display the processed train split on its own.
processed_datasets["train"]

In [None]:
# MacBERT encoder with a span-prediction (start/end) QA head; since this is a base
# checkpoint the QA head is presumably newly initialised -- check the load warning.
# No explicit .to("cuda"): the Trainer moves the model to args.device itself, so
# this cell also works on machines without a CUDA/ROCm GPU.
model = AutoModelForQuestionAnswering.from_pretrained("hfl/chinese-macbert-base")

In [None]:
args = TrainingArguments(
    output_dir="checkpoints-for-mrc",
    # Schedule and batch sizes.
    num_train_epochs=3,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=64,
    # Evaluate and checkpoint once at the end of every epoch.
    eval_strategy="epoch",
    save_strategy="epoch",
    # Logging.
    logging_steps=50,
    report_to=["tensorboard"],
)

In [None]:
# Display-only preview of the first quarter of the processed train split;
# the result is not stored (the Trainer builds its own subset).
processed_datasets["train"].select(range(len(processed_datasets["train"]) // 4))

In [None]:
# Display-only preview of the first quarter of the processed validation split;
# the result is not stored (the Trainer builds its own subset).
processed_datasets["validation"].select(range(len(processed_datasets["validation"]) // 4))

In [None]:
trainer = Trainer(
    model=model,
    args=args,
    data_collator=DefaultDataCollator(),
    # Train/evaluate on a seeded random 10% of each split to keep runs short.
    train_dataset=processed_datasets["train"].shuffle(seed=42).select(
        range(len(processed_datasets["train"]) // 10)
    ),
    eval_dataset=processed_datasets["validation"].shuffle(seed=42).select(
        range(len(processed_datasets["validation"]) // 10)
    ),
)

In [None]:
# Fine-tune; with eval_strategy/save_strategy="epoch" this evaluates and
# writes a checkpoint at the end of each epoch.
trainer.train()

In [None]:
# save_model() writes the model weights/config to args.output_dir, but it only writes
# tokenizer files when the Trainer was given a tokenizer (the Trainer above was not).
# Save the tokenizer explicitly so the reload cell can load it from the same directory.
trainer.save_model()
tokenizer.save_pretrained(trainer.args.output_dir)

In [None]:
# Reload the fine-tuned model and its tokenizer from the local output directory
# (local_files_only: never fall back to the Hub).
checkpoint_dir = "./checkpoints-for-mrc"
model = AutoModelForQuestionAnswering.from_pretrained(checkpoint_dir, local_files_only=True)
tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir, local_files_only=True)

In [None]:
# Wrap the reloaded model and tokenizer in an extractive question-answering pipeline.
pipe = pipeline(task="question-answering", model=model, tokenizer=tokenizer)

In [None]:
# Smoke test on a toy example; a well-trained model should extract "北京" (Beijing)
# as the answer span from the context.
pipe(question="小明在哪里上班？", context="小明在北京上班。")