In [None]:
import os
# Send both HTTP and HTTPS traffic from this kernel through the same local proxy.
os.environ.update(
    {
        "http_proxy": "http://127.0.0.1:8889",
        "https_proxy": "http://127.0.0.1:8889",
    }
)

In [None]:
 # 1 Import
from datasets import load_dataset, load_from_disk
from transformers import AutoTokenizer, AutoModelForQuestionAnswering, TrainingArguments, Trainer, DefaultDataCollator


In [None]:
# 2 Load dataset
# Reuse the on-disk copy when it exists; otherwise download cmrc2018 once and
# save it, so the notebook also runs from a fresh checkout without having to
# comment/uncomment lines by hand.
if os.path.isdir("./datasets/local/cmrc2018"):
    datasets = load_from_disk("./datasets/local/cmrc2018")
else:
    datasets = load_dataset('cmrc2018', cache_dir="./datasets")
    datasets.save_to_disk("./datasets/local/cmrc2018")

In [None]:
# Show the first record of the train split to see which fields an example carries.
datasets["train"][0]

In [None]:
# 3 data preprocessing
# Load the tokenizer that belongs to the hfl/chinese-macbert-base checkpoint.
tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-macbert-base")
# Bare name as last expression: display the tokenizer as the cell output.
tokenizer

In [None]:
# Keep only the first 10 training examples as a small sample for trying out
# the tokenization and answer-labelling steps below.
sample_dataset = datasets["train"].select(range(10))

In [None]:
# Tokenize the sampled question/context pairs into fixed-length windows.
tokenized_samples = tokenizer(
    text=sample_dataset["question"],
    text_pair=sample_dataset["context"],
    # Windowing: only the second sequence (the context) is truncated, and the
    # truncated remainder is returned as additional, overlapping windows.
    truncation="only_second",
    max_length=384,
    stride=128,
    return_overflowing_tokens=True,
    # Extras used later to map character-level answers onto tokens.
    return_offsets_mapping=True,
    padding="max_length",
)

tokenized_samples.keys()

In [None]:
# One entry per tokenized window; the loop further down uses each entry as the
# index of the sample the window belongs to.
tokenized_samples.overflow_to_sample_mapping

In [None]:
# Number of entries in the mapping, i.e. how many windows the sample produced.
len(tokenized_samples.overflow_to_sample_mapping)

In [None]:
# Turn the first four windows back into text and print them one per line.
decoded_windows = tokenizer.batch_decode(tokenized_samples["input_ids"][:4])
for window_text in decoded_windows:
    print(window_text)

In [None]:
# Read the window -> sample index mapping without removing it from the encoding.
# A destructive pop() here would raise KeyError when this cell is re-run and
# would also break re-running the cells that display the mapping.
sample_mapping = tokenized_samples["overflow_to_sample_mapping"]

In [None]:
# Dry run of the labelling logic on the sample: for every window, convert the
# character-level answer span into token positions and print the decoded span.
answers_column = sample_dataset["answers"]
for idx, sample_idx in enumerate(sample_mapping):
    answer = answers_column[sample_idx]
    start_char = answer["answer_start"][0]
    end_char = start_char + len(answer["text"][0])

    # Context tokens are the ones with sequence id 1; the first None after
    # them marks the end of the context inside this window.
    seq_ids = tokenized_samples.sequence_ids(idx)
    context_start = seq_ids.index(1)
    context_end = seq_ids.index(None, context_start) - 1

    offset = tokenized_samples["offset_mapping"][idx]

    answer_outside_window = (
        offset[context_end][1] < start_char or offset[context_start][0] > end_char
    )
    if answer_outside_window:
        # The answer does not overlap this window: point both ends at token 0.
        start_token_pos = end_token_pos = 0
    else:
        # Move forward to the first context token that starts at/after the answer start.
        start_token_pos = context_start
        while start_token_pos <= context_end and offset[start_token_pos][0] < start_char:
            start_token_pos += 1

        # Move backward to the last context token that ends at/before the answer end.
        end_token_pos = context_end
        while end_token_pos >= start_token_pos and offset[end_token_pos][1] > end_char:
            end_token_pos -= 1

    print(answer, start_char, end_char, context_start, context_end, start_token_pos, end_token_pos)
    print("decode:", tokenizer.decode(tokenized_samples["input_ids"][idx][start_token_pos:end_token_pos+1]))
    

In [None]:
def process_func(examples):
    """Tokenize a batch of QA examples into windows and label each window's answer span.

    Every question/context pair is tokenized with the module-level ``tokenizer``
    into one or more windows of 384 tokens (stride 128, only the context is
    truncated). For each window the first reference answer is converted from
    character positions to token positions.

    Args:
        examples: batch in column form with the keys ``"id"``, ``"question"``,
            ``"context"`` and ``"answers"``; each answer entry provides
            ``"answer_start"`` and ``"text"`` lists, of which the first element
            is used.

    Returns:
        The tokenizer output (without ``overflow_to_sample_mapping``) extended by:
          * ``start_positions`` / ``end_positions``: token indices of the answer
            inside each window, or ``0`` / ``0`` when the answer does not
            overlap the window;
          * ``example_ids``: id of the example each window was cut from;
          * ``offset_mapping``: per window, the character offsets of context
            tokens and ``None`` for every other token.
    """
    tokenized_examples = tokenizer(text=examples["question"],
                                   text_pair=examples["context"],
                                   max_length=384,
                                   return_offsets_mapping=True,
                                   return_overflowing_tokens=True,
                                   stride=128,
                                   truncation="only_second",
                                   padding="max_length",)

    # Index of the source example for every produced window.
    sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")

    start_positions = []
    end_positions = []
    example_ids = []

    for idx, sample_idx in enumerate(sample_mapping):
        answer = examples["answers"][sample_idx]
        start_char = answer["answer_start"][0]
        end_char = start_char + len(answer["text"][0])

        # Context tokens carry sequence id 1; the first None after them ends the
        # context. Computed once per window and reused for the masking below.
        sequence_ids = tokenized_examples.sequence_ids(idx)
        context_start = sequence_ids.index(1)
        context_end = sequence_ids.index(None, context_start) - 1

        offset = tokenized_examples["offset_mapping"][idx]

        if offset[context_end][1] < start_char or offset[context_start][0] > end_char:
            # The answer does not overlap this window: label it with token 0.
            start_token_pos = 0
            end_token_pos = 0
        else:
            # First context token that starts at or after the answer start.
            token_id = context_start
            while token_id <= context_end and offset[token_id][0] < start_char:
                token_id += 1
            start_token_pos = token_id

            # Last context token that ends at or before the answer end. The scan
            # must be allowed to walk back as far as the start token; bounding it
            # by context_end would stop after a single step.
            token_id = context_end
            while token_id >= start_token_pos and offset[token_id][1] > end_char:
                token_id -= 1
            end_token_pos = token_id

        start_positions.append(start_token_pos)
        end_positions.append(end_token_pos)
        example_ids.append(examples["id"][sample_idx])
        # Keep offsets only for context tokens so that post-processing can
        # recognise (and skip) question/special/padding positions via None.
        tokenized_examples["offset_mapping"][idx] = [
            (span if sequence_ids[k] == 1 else None)
            for k, span in enumerate(offset)
        ]

    tokenized_examples["start_positions"] = start_positions
    tokenized_examples["end_positions"] = end_positions
    tokenized_examples["example_ids"] = example_ids

    return tokenized_examples
    
    

In [None]:
# Run process_func over the loaded dataset in batches. The original columns
# (taken from the train split's column names) are removed from the result.
# NOTE(review): presumably required because one example can expand into several
# windows, so the old columns would no longer line up row-for-row - confirm.
tokenized_datasets = datasets.map(process_func, batched=True, remove_columns=datasets["train"].column_names)

In [None]:
# Display the tokenized dataset object as the cell output.
tokenized_datasets

In [None]:
# Show the masked offset mapping of the first training window. Index the row
# first and the column second: column-first access would load the whole
# "offset_mapping" column just to display a single entry.
tokenized_datasets["train"][0]["offset_mapping"]

In [None]:
# get model output
import numpy as np
import collections

def get_result(start_logits, end_logits, examples, features, n_best=2, max_answer_length=30):
    """Turn start/end logits into one predicted answer text per example.

    Args:
        start_logits: per-feature start scores, indexable as ``start_logits[i]``.
        end_logits: per-feature end scores, indexable as ``end_logits[i]``.
        examples: iterable of original examples; each provides ``"id"``,
            ``"context"`` and ``"answers"["text"]``.
        features: tokenized windows; ``features["example_ids"]`` yields the source
            example id of every window and ``features[i]["offset_mapping"]`` the
            window's offsets, with ``None`` at positions that must not be part
            of an answer.
        n_best: how many of the highest-scoring start and end positions are
            combined per window (default 2, the previously hard-coded value).
        max_answer_length: maximum answer length in tokens; longer candidates
            are discarded (default 30, the previously hard-coded value).

    Returns:
        ``(predictions, references)``: ``predictions`` maps example id to the
        best-scoring answer text (``""`` when no valid candidate exists) and
        ``references`` maps example id to the list of reference answer texts.
    """
    predictions = {}
    references = {}

    # Group window indices by the example they were cut from.
    example_to_features = collections.defaultdict(list)
    for idx, example_id in enumerate(features["example_ids"]):
        example_to_features[example_id].append(idx)

    for example in examples:
        example_id = example["id"]
        context = example["context"]
        answers = []
        for feature_idx in example_to_features[example_id]:
            start_logit = start_logits[feature_idx]
            end_logit = end_logits[feature_idx]
            offset = features[feature_idx]["offset_mapping"]
            # Highest-scoring positions first.
            start_indexes = np.argsort(start_logit)[::-1][:n_best].tolist()
            end_indexes = np.argsort(end_logit)[::-1][:n_best].tolist()
            for start_index in start_indexes:
                for end_index in end_indexes:
                    # Positions without an offset are outside the context.
                    if offset[start_index] is None or offset[end_index] is None:
                        continue
                    # The span must run forwards ...
                    if end_index < start_index:
                        continue
                    # ... and must not exceed the allowed length.
                    if end_index - start_index + 1 > max_answer_length:
                        continue
                    answers.append({
                        "text": context[offset[start_index][0]:offset[end_index][1]],
                        "score": start_logit[start_index] + end_logit[end_index],
                    })
        if len(answers) > 0:
            best_answer = max(answers, key=lambda x: x["score"])
            predictions[example_id] = best_answer["text"]
        else:
            predictions[example_id] = ""

        references[example_id] = example["answers"]["text"]

    return predictions, references


In [None]:
# 4 Evaluation function
from cmrc_eval import evaluate_cmrc

def metric(pred):
    """Compute the CMRC evaluation result for a set of model predictions.

    Args:
        pred: prediction object whose first element is the
            ``(start_logits, end_logits)`` pair.

    Returns:
        Whatever ``evaluate_cmrc`` returns for the decoded predictions and
        their references.

    The split is not passed in, so it is inferred from the number of prediction
    rows: a count equal to the number of tokenized validation features selects
    the validation split, anything else is treated as the test split.
    """
    start_logits, end_logits = pred[0]
    if start_logits.shape[0] == len(tokenized_datasets["validation"]):
        split = "validation"
    else:
        split = "test"
    predictions, references = get_result(start_logits=start_logits,
                                         end_logits=end_logits,
                                         examples=datasets[split],
                                         features=tokenized_datasets[split])
    return evaluate_cmrc(predictions, references)

In [None]:
# 5 Create Model
# Load the hfl/chinese-macbert-base checkpoint into a question-answering model.
# NOTE(review): the checkpoint name suggests a base model without a trained QA
# head, so the head is presumably freshly initialised - confirm from the load log.
model = AutoModelForQuestionAnswering.from_pretrained("hfl/chinese-macbert-base")

In [None]:
# 6 TrainArgs
# Training configuration: 3 epochs, evaluation every 100 steps, logging every
# 50 steps, one checkpoint per epoch written below ./models/models_for_qa.

args = TrainingArguments(
    output_dir="./models/models_for_qa",
    per_device_train_batch_size=32,
    per_device_eval_batch_size=64,
    # NOTE(review): newer transformers releases presumably spell this argument
    # `eval_strategy` - confirm against the installed version.
    evaluation_strategy="steps",
    eval_steps=100,
    save_strategy="epoch",
    # Left disabled. NOTE(review): enabling it presumably requires the save
    # strategy to match the evaluation strategy - confirm before turning it on.
    # load_best_model_at_end=True,
    logging_steps=50,
    logging_strategy="steps",
    num_train_epochs=3,
)

In [None]:
# 7 Trainer
# Wire model, arguments, tokenized data and the metric function together.
# NOTE(review): the periodic evaluation runs on the "test" split, not on
# "validation" - confirm this is intended, since test results then influence
# how training is monitored.

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["test"],
    data_collator=DefaultDataCollator(),
    compute_metrics=metric,
)

import torch
# Ask PyTorch to release cached GPU memory before training starts.
torch.cuda.empty_cache()


In [None]:
# 8 Train
# Start fine-tuning with the trainer configured above.
trainer.train()

In [None]:
# 9 pipeline
# Wrap the fine-tuned model and its tokenizer in a question-answering pipeline.
# The first GPU is used when CUDA is available; otherwise the pipeline falls
# back to the CPU (-1) instead of failing on a hard-coded device index.

import torch
from transformers import pipeline
pipe = pipeline("question-answering", model=model, tokenizer=tokenizer,
                device=0 if torch.cuda.is_available() else -1)
pipe