In [None]:
import os

# Opt in to PyTorch's experimental AOTriton attention kernels on ROCm builds.
# Set it *before* importing torch so the flag is already in the process
# environment no matter when torch first reads it.
os.environ["TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL"] = "1"

import torch

In [None]:
from transformers import AutoTokenizer, AutoModelForQuestionAnswering, DefaultDataCollator, TrainingArguments, Trainer, \
    pipeline
from datasets import load_dataset

In [None]:
# Load the CMRC 2018 reading-comprehension dataset (train / validation / test
# splits are used below); the bare name displays the DatasetDict summary.
datasets = load_dataset("cmrc2018")
datasets

In [None]:
# First raw training example: its id, context, question and answers {text, answer_start}.
datasets["train"][0]

In [None]:
# Offset mappings and BatchEncoding.sequence_ids(), used throughout this notebook,
# need the Rust-backed "fast" tokenizer (requested explicitly here; it is the default).
tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-macbert-base", use_fast=True)

In [None]:
import numpy as np

# A tiny, deterministic sample (the first two training examples) used to
# inspect how overflow-window tokenization and answer labelling behave.
sample_dataset = datasets["train"].select([0, 1])
sample_dataset.column_names

In [None]:
# Questions of the two sampled examples.
sample_dataset["question"]

In [None]:
# Tokenize (question, context) pairs the same way as the training pipeline:
# only the context is truncated, long contexts spill into extra overlapping
# windows, and character offsets are kept so answers can be mapped to tokens.
tokenized_examples = tokenizer(
    text=list(sample_dataset["question"]),
    text_pair=list(sample_dataset["context"]),
    truncation="only_second",
    padding="max_length",
    max_length=328,
    stride=128,
    return_overflowing_tokens=True,
    return_offsets_mapping=True,
    return_tensors="pt",
)

In [None]:
# Fields produced by the tokenizer (includes offset_mapping and overflow_to_sample_mapping).
tokenized_examples.keys()

In [None]:
# One line per tensor: (num_features, max_length) for ids, an extra trailing
# dim of 2 for offsets, and (num_features,) for the feature -> example map.
print(
    *(tokenized_examples[field].shape
      for field in ("input_ids", "offset_mapping", "overflow_to_sample_mapping")),
    sep="\n",
)

In [None]:
# For every feature (window), the index of the sample_dataset example it came from.
tokenized_examples["overflow_to_sample_mapping"]

In [None]:
# Character (start, end) offsets of every token in the first feature.
tokenized_examples["offset_mapping"][0]

In [None]:
# Segment of each token in the first feature: None for special/padding tokens,
# 0 for the question (text), 1 for the context (text_pair).
print(tokenized_examples.sequence_ids(0))

In [None]:
# Number of raw examples in the sample (the length of the context *column*,
# not the character length of a context).
len(sample_dataset["context"])

In [None]:
# Number of features after overflow splitting; larger than the example count
# above whenever a context needed more than one window.
len(tokenized_examples["input_ids"])

In [None]:
def test():
    """Debug helper: derive answer-span token labels for ``sample_dataset``.

    Walks every feature of the global ``tokenized_examples`` (PyTorch tensors,
    possibly several overflow windows per example). For each feature it prints
    the gold answer, searches the context tokens for the span covering it,
    prints the decoded span when found, and finally prints the start/end label
    lists (0 where the answer is not fully inside that window).
    """
    starts, ends = [], []
    sample_of_feature = tokenized_examples["overflow_to_sample_mapping"]

    for feat in range(len(tokenized_examples["input_ids"])):
        sample_idx = sample_of_feature[feat].item()
        gold = sample_dataset["answers"][sample_idx]
        char_begin = gold["answer_start"][0]
        # Inclusive index of the last answer character.
        char_finish = char_begin + len(gold["text"][0]) - 1

        # Context tokens are the run tagged 1, ending right before the next None.
        seq_ids = tokenized_examples.sequence_ids(feat)
        ctx_first = seq_ids.index(1)
        ctx_last = seq_ids.index(None, ctx_first) - 1

        offsets = tokenized_examples["offset_mapping"][feat]
        start_tok = end_tok = None

        print(gold)

        # Only search when the whole answer lies inside this window's context.
        if offsets[ctx_first][0] <= char_begin and char_finish < offsets[ctx_last][1]:
            for tok in range(ctx_first, ctx_last + 1):
                tok_begin, tok_finish = offsets[tok]
                if start_tok is None and tok_begin <= char_begin < tok_finish:
                    start_tok = tok
                if start_tok is not None and tok_begin <= char_finish < tok_finish:
                    end_tok = tok
                elif start_tok is not None and end_tok is not None:
                    # First token past the end token: the span is complete.
                    break

        if start_tok is None or end_tok is None:
            starts.append(0)
            ends.append(0)
        else:
            starts.append(start_tok)
            ends.append(end_tok)
            print(tokenizer.decode(tokenized_examples["input_ids"][feat][start_tok:end_tok + 1]))

    print(starts)
    print(ends)


test()

In [None]:
def _find_answer_token_span(offsets, sequence_ids, answers):
    """Locate the first gold answer of one example inside one tokenized feature.

    Args:
        offsets: Per-token ``(char_start, char_end)`` pairs for the feature
            (``char_end`` exclusive), as produced with ``return_offsets_mapping``.
        sequence_ids: ``BatchEncoding.sequence_ids(i)`` for the same feature;
            context (``text_pair``) tokens are tagged ``1``, special and
            padding tokens ``None``.
        answers: Dict with parallel "text" and "answer_start" lists.

    Returns:
        ``(start_token, end_token)`` as inclusive token indices, or ``None``
        when the example has no gold answer, the answer is not fully inside
        this feature's context window, or no token offsets cover its first or
        last character.
    """
    if not answers["answer_start"] or not answers["text"]:
        # No gold answer (e.g. an unlabeled split): nothing to locate.
        return None

    answer_char_start = answers["answer_start"][0]
    # Inclusive index of the last answer character.
    answer_char_end = answer_char_start + len(answers["text"][0]) - 1

    # The context is the run of tokens tagged 1, ending right before the next
    # token whose sequence id is None.
    context_start = sequence_ids.index(1)
    context_end = sequence_ids.index(None, context_start) - 1

    # A window that cuts through the answer gets no span.
    if not (offsets[context_start][0] <= answer_char_start
            and answer_char_end < offsets[context_end][1]):
        return None

    answer_token_start = None
    for idx in range(context_start, context_end + 1):
        token_char_start, token_char_end = offsets[idx]
        if answer_token_start is None and token_char_start <= answer_char_start < token_char_end:
            answer_token_start = idx
        # Only look for the end once the start is known; it may be the same token.
        if answer_token_start is not None and token_char_start <= answer_char_end < token_char_end:
            return answer_token_start, idx
    return None


def process_function(examples, tokenizer=tokenizer, max_length=328, stride=128):
    """Turn a batch of raw QA examples into model features with span labels.

    Each (question, context) pair is tokenized with ``truncation="only_second"``
    and ``return_overflowing_tokens=True``. A long context therefore yields
    several windows of ``max_length`` tokens, with ``stride`` tokens of overlap.
    For every window, the answer's token span becomes ``start_positions`` /
    ``end_positions``. When the example has no answer, or the answer is not
    fully inside the window, both labels point at the [CLS] token.

    Args:
        examples: Batch dict as passed by ``Dataset.map(batched=True)`` with
            "id", "question", "context" and "answers" columns; each answers
            entry holds "text" and "answer_start" lists (may be empty).
        tokenizer: Fast tokenizer (offset mappings and ``sequence_ids`` are
            required). Defaults to the notebook's global ``tokenizer``.
        max_length: Maximum number of tokens per feature.
        stride: Number of overlapping tokens between consecutive windows.

    Returns:
        The tokenizer's ``BatchEncoding`` plus "start_positions",
        "end_positions" and "example_ids" (id of the source example of each
        feature). "overflow_to_sample_mapping" is removed; in "offset_mapping"
        every non-context token's offset is replaced by ``None``.
    """
    inputs = tokenizer(
        text=examples["question"],
        text_pair=examples["context"],
        return_offsets_mapping=True,
        return_overflowing_tokens=True,
        stride=stride,
        max_length=max_length,
        truncation="only_second",
        padding="max_length",
    )

    overflow_to_sample_mapping = inputs.pop("overflow_to_sample_mapping")

    start_positions = []
    end_positions = []
    example_ids = []

    for feature_idx in range(len(inputs["input_ids"])):
        sample_idx = overflow_to_sample_mapping[feature_idx]
        offsets = inputs["offset_mapping"][feature_idx]
        # Computed once per feature; the original re-computed it for every token.
        sequence_ids = inputs.sequence_ids(feature_idx)
        cls_index = inputs["input_ids"][feature_idx].index(tokenizer.cls_token_id)

        span = _find_answer_token_span(offsets, sequence_ids, examples["answers"][sample_idx])
        if span is None:
            start_positions.append(cls_index)
            end_positions.append(cls_index)
        else:
            start_positions.append(span[0])
            end_positions.append(span[1])

        example_ids.append(examples["id"][sample_idx])
        inputs["offset_mapping"][feature_idx] = [
            offset if sequence_ids[i] == 1 else None
            for i, offset in enumerate(offsets)
        ]

    inputs["start_positions"] = start_positions
    inputs["end_positions"] = end_positions
    inputs["example_ids"] = example_ids

    return inputs

In [None]:
# Dry run of the feature pipeline on the first 10 raw training examples.
test_datasets = datasets["train"].select(range(10)).map(
    function=process_function,
    batched=True,
    remove_columns=datasets["train"].column_names,
)
test_datasets

In [None]:
# Featurize every split; the raw columns are dropped. This assumes every split
# has the same raw columns as train.
processed_datasets = datasets.map(
    function=process_function,
    batched=True,
    remove_columns=datasets["train"].column_names,
)

In [None]:
# Processed training features; there can be more rows than raw examples
# because long contexts are split into several windows.
processed_datasets["train"]

In [None]:
def metrics(pred):
    """Token-level accuracy of predicted answer boundaries, for ``compute_metrics``.

    Args:
        pred: ``EvalPrediction`` (or an equivalent 2-tuple). ``pred[0]`` is
            ``(start_logits, end_logits)``. ``pred[1]`` is
            ``(start_positions, end_positions)`` or ``None``. Logits are
            assumed to be ``(num_features, seq_len)`` arrays and labels
            ``(num_features,)``; TODO confirm if the model head changes.

    Returns:
        dict with "start_accuracy", "end_accuracy" and "exact_span_match" (the
        fraction of features whose argmax start AND end both equal the labels),
        or an empty dict when no labels are available.
    """
    import numpy as np

    start_logits, end_logits = pred[0]
    labels = pred[1]
    if labels is None:
        return {}
    start_labels, end_labels = labels

    start_hits = np.argmax(start_logits, axis=-1) == np.asarray(start_labels)
    end_hits = np.argmax(end_logits, axis=-1) == np.asarray(end_labels)
    return {
        "start_accuracy": float(start_hits.mean()),
        "end_accuracy": float(end_hits.mean()),
        "exact_span_match": float((start_hits & end_hits).mean()),
    }

In [None]:
# Fall back to CPU when no GPU is visible instead of failing on .to("cuda")
# (PyTorch ROCm builds also expose AMD GPUs under the "cuda" device name).
model = AutoModelForQuestionAnswering.from_pretrained("hfl/chinese-macbert-base").to(
    "cuda" if torch.cuda.is_available() else "cpu"
)

In [None]:
args = TrainingArguments(
    output_dir="checkpoints-for-mrc",
    num_train_epochs=3,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=64,
    # Evaluate and checkpoint once per epoch; log training loss every 50 steps.
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_steps=50,
    report_to=["tensorboard"],
)

In [None]:
# Preview: the first quarter of the processed training features.
processed_datasets["train"].select(range(len(processed_datasets["train"]) // 4))

In [None]:
# Preview: the first quarter of the processed validation features.
processed_datasets["validation"].select(range(len(processed_datasets["validation"]) // 4))

In [None]:
# Train/evaluate on a seeded random 10% of each split to keep runs short.
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=processed_datasets["train"]
    .shuffle(seed=42)
    .select(range(len(processed_datasets["train"]) // 10)),
    eval_dataset=processed_datasets["validation"]
    .shuffle(seed=42)
    .select(range(len(processed_datasets["validation"]) // 10)),
    data_collator=DefaultDataCollator(),
    compute_metrics=metrics,
)

In [None]:
# Fine-tune; evaluation and checkpointing happen once per epoch (see args).
trainer.train()

In [None]:
# Smoke-test evaluation (and the metrics callback) on a single test feature.
trainer.evaluate(eval_dataset=processed_datasets["test"].select(range(1)))

In [None]:
trainer.save_model()
# The Trainer above was not given the tokenizer, so save_model() writes only the
# model. Save the tokenizer into the same directory so that
# AutoTokenizer.from_pretrained(<output_dir>, local_files_only=True) can reload it.
tokenizer.save_pretrained(args.output_dir)

In [None]:
# Reload the fine-tuned model and its tokenizer from the training output
# directory (both must have been saved there; no Hub download is attempted).
model = AutoModelForQuestionAnswering.from_pretrained(
    pretrained_model_name_or_path="./checkpoints-for-mrc", local_files_only=True
)
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path="./checkpoints-for-mrc", local_files_only=True
)

In [None]:
# Wrap the reloaded model and tokenizer in an extractive question-answering pipeline.
pipe = pipeline(task="question-answering", model=model, tokenizer=tokenizer)

In [None]:
# Smoke test on a hand-written pair: "Where does Xiao Ming work?" /
# "Xiao Ming works in Beijing." (the expected answer span is "北京", Beijing).
pipe(question="小明在哪里上班？", context="小明在北京上班。")