# Preparing the dataset

In [1]:
from datasets import load_dataset

# 用你的本地 SQuAD 数据集文件路径替换这里的路径
local_bioasq_train_file = 'data/BioASQ/train/BioASQ-train-list.json'
local_bioasq_val_file = 'data/BioASQ/val/BioASQ-val-list.json'
local_bioasq_test_file = 'data/BioASQ/test/BioASQ-test-list.json'

# 加载本地 SQuAD 数据集
bioasq_dataset = load_dataset('json', data_files={'train': local_bioasq_train_file, 'val': local_bioasq_val_file, 'test': local_bioasq_test_file}, field='data')

# 访问数据集的子集，例如 'train' 和 'validation'
train_data = bioasq_dataset['train']
val_data = bioasq_dataset['val']
test_data = bioasq_dataset['test']

Found cached dataset json (/home/aoyuli/.cache/huggingface/datasets/json/default-90eae69d8c584b52/0.0.0/fe5dd6ea2639a6df622901539cb550cf8797e5a6b2dd7af1cf934bed8e233e6e)


  0%|          | 0/3 [00:00<?, ?it/s]

In [2]:
from datasets import Dataset, DatasetDict

def flatten_dataset(dataset):
    flattened_examples = []

    for example in dataset:
        for paragraph in example['paragraphs']:
            context = paragraph['context']
            for qa in paragraph['qas']:
                flattened_example = {
                    'question': qa['question'],
                    'context': context,
                    'answers': {
                        'text': [qa['answers'][0]['text']],
                        'answer_start': [qa['answers'][0]['answer_start']],
                    },
                    'title': 'BioASQ',
                    'id': qa['id']
                }
                flattened_examples.append(flattened_example)

    return flattened_examples

# train_examples = flatten_dataset(train_data)[:320]
# val_examples = flatten_dataset(val_data)[:40]
# test_examples = flatten_dataset(test_data)[:40]

train_examples = flatten_dataset(train_data)
val_examples = flatten_dataset(val_data)
test_examples = flatten_dataset(test_data)

train_data = Dataset.from_dict({key: [example[key] for example in train_examples] for key in train_examples[0].keys()})
val_data = Dataset.from_dict({key: [example[key] for example in val_examples] for key in val_examples[0].keys()})
test_data = Dataset.from_dict({key: [example[key] for example in test_examples] for key in val_examples[0].keys()})


In [3]:
# raw_datasets = train_data.train_test_split(test_size=0.2)
raw_datasets = DatasetDict({
    'train': train_data,
    'validation': val_data,
    'test': test_data
})

In [4]:
raw_datasets

DatasetDict({
    train: Dataset({
        features: ['question', 'context', 'answers', 'title', 'id'],
        num_rows: 6878
    })
    validation: Dataset({
        features: ['question', 'context', 'answers', 'title', 'id'],
        num_rows: 859
    })
    test: Dataset({
        features: ['question', 'context', 'answers', 'title', 'id'],
        num_rows: 861
    })
})

In [5]:
print("Context: ", raw_datasets["train"][0]["context"])
print("Question: ", raw_datasets["train"][0]["question"])
print("Answer: ", raw_datasets["train"][0]["answers"])

Context:  Rivaroxaban versus warfarin in Japanese patients with nonvalvular atrial fibrillation for the secondary prevention of stroke: a subgroup analysis of J-ROCKET AF. BACKGROUND: The overall analysis of the rivaroxaban versus warfarin in Japanese patients with atrial fibrillation (J-ROCKET AF) trial revealed that rivaroxaban was not inferior to warfarin with respect to the primary safety outcome. In addition, there was a strong trend for a reduction in the rate of stroke/systemic embolism with rivaroxaban compared with warfarin. METHODS: In this subanalysis of the J-ROCKET AF trial, we investigated the consistency of safety and efficacy profile of rivaroxaban versus warfarin among the subgroups of patients with previous stroke, transient ischemic attack, or non-central nervous system systemic embolism (secondary prevention group) and those without (primary prevention group). RESULTS: Patients in the secondary prevention group were 63.6% of the overall population of J-ROCKET AF. In

In [6]:
raw_datasets["train"].filter(lambda x: len(x["answers"]["text"]) != 1)

Filter:   0%|          | 0/6878 [00:00<?, ? examples/s]

Dataset({
    features: ['question', 'context', 'answers', 'title', 'id'],
    num_rows: 0
})

In [7]:
print(raw_datasets["test"][0]["answers"])
print(raw_datasets["test"][2]["answers"])

{'answer_start': [931], 'text': ['met']}
{'answer_start': [1323], 'text': ['COL5A2']}


In [8]:
print(raw_datasets["test"][2]["context"])
print(raw_datasets["test"][2]["question"])

Clinical and genetic aspects of Ehlers-Danlos syndrome, classic type. Classic Ehlers-Danlos syndrome is a heritable connective tissue disorder characterized by skin hyperextensibility, fragile and soft skin, delayed wound healing with formation of atrophic scars, easy bruising, and generalized joint hypermobility. It comprises Ehlers-Danlos syndrome type I and Ehlers-Danlos syndrome type II, but it is now apparent that these form a continuum of clinical findings and differ only in phenotypic severity. It is currently estimated that approximately 50% of patients with a clinical diagnosis of classic Ehlers-Danlos syndrome harbor mutations in the COL5A1 and the COL5A2 gene, encoding the α1 and the α2-chain of type V collagen, respectively. However, because no prospective molecular studies of COL5A1 and COL5A2 have been performed in a clinically well-defined patient group, this number may underestimate the real proportion of patients with classic Ehlers-Danlos syndrome harboring a mutation

## Preprocessing the trainset

In [9]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

In [10]:
tokenizer.is_fast

True

In [11]:
context = raw_datasets["train"][0]["context"]
question = raw_datasets["train"][0]["question"]

inputs = tokenizer(question, context)
tokenizer.decode(inputs["input_ids"])

'[CLS] what medication were compared in the rocket af trial? [SEP] rivaroxaban versus warfarin in japanese patients with nonvalvular atrial fibrillation for the secondary prevention of stroke : a subgroup analysis of j - rocket af. background : the overall analysis of the rivaroxaban versus warfarin in japanese patients with atrial fibrillation ( j - rocket af ) trial revealed that rivaroxaban was not inferior to warfarin with respect to the primary safety outcome. in addition, there was a strong trend for a reduction in the rate of stroke / systemic embolism with rivaroxaban compared with warfarin. methods : in this subanalysis of the j - rocket af trial, we investigated the consistency of safety and efficacy profile of rivaroxaban versus warfarin among the subgroups of patients with previous stroke, transient ischemic attack, or non - central nervous system systemic embolism ( secondary prevention group ) and those without ( primary prevention group ). results : patients in the secon

In [12]:
inputs = tokenizer(
    question,
    context,
    max_length=100,
    truncation="only_second",
    stride=50,
    return_overflowing_tokens=True,
)

for ids in inputs["input_ids"]:
    print(tokenizer.decode(ids))

[CLS] what medication were compared in the rocket af trial? [SEP] rivaroxaban versus warfarin in japanese patients with nonvalvular atrial fibrillation for the secondary prevention of stroke : a subgroup analysis of j - rocket af. background : the overall analysis of the rivaroxaban versus warfarin in japanese patients with atrial fibrillation ( j - rocket af ) trial revealed that rivaroxaban was not inferior to warfarin with respect to [SEP]
[CLS] what medication were compared in the rocket af trial? [SEP]. background : the overall analysis of the rivaroxaban versus warfarin in japanese patients with atrial fibrillation ( j - rocket af ) trial revealed that rivaroxaban was not inferior to warfarin with respect to the primary safety outcome. in addition, there was a strong trend for a reduction in the rate of stroke / systemic embolism with rivaroxaban compared with warfarin [SEP]
[CLS] what medication were compared in the rocket af trial? [SEP]oxaban was not inferior to warfarin with 

In [13]:
inputs = tokenizer(
    question,
    context,
    max_length=100,
    truncation="only_second",
    stride=50,
    return_overflowing_tokens=True,
    return_offsets_mapping=True,
)
inputs.keys()

dict_keys(['input_ids', 'attention_mask', 'offset_mapping', 'overflow_to_sample_mapping'])

In [14]:
inputs["overflow_to_sample_mapping"]

[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

In [15]:
inputs = tokenizer(
    raw_datasets["train"][2:6]["question"],
    raw_datasets["train"][2:6]["context"],
    max_length=100,
    truncation="only_second",
    stride=50,
    return_overflowing_tokens=True,
    return_offsets_mapping=True,
)

print(f"The 4 examples gave {len(inputs['input_ids'])} features.")
print(f"Here is where each comes from: {inputs['overflow_to_sample_mapping']}.")

The 4 examples gave 40 features.
Here is where each comes from: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3].


In [16]:
answers = raw_datasets["train"][2:6]["answers"]
start_positions = []
end_positions = []

for i, offset in enumerate(inputs["offset_mapping"]):
    sample_idx = inputs["overflow_to_sample_mapping"][i]
    answer = answers[sample_idx]
    start_char = answer["answer_start"][0]
    end_char = answer["answer_start"][0] + len(answer["text"][0])
    sequence_ids = inputs.sequence_ids(i)

    # Find the start and end of the context
    idx = 0
    while sequence_ids[idx] != 1:
        idx += 1
    context_start = idx
    while sequence_ids[idx] == 1:
        idx += 1
    context_end = idx - 1

    # If the answer is not fully inside the context, label is (0, 0)
    if offset[context_start][0] > start_char or offset[context_end][1] < end_char:
        start_positions.append(0)
        end_positions.append(0)
    else:
        # Otherwise it's the start and end token positions
        idx = context_start
        while idx <= context_end and offset[idx][0] <= start_char:
            idx += 1
        start_positions.append(idx - 1)

        idx = context_end
        while idx >= context_start and offset[idx][1] >= end_char:
            idx -= 1
        end_positions.append(idx + 1)

start_positions, end_positions

([77,
  55,
  33,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  63,
  27,
  0,
  0,
  0,
  0,
  0,
  0,
  20,
  0,
  0,
  0,
  0,
  77,
  42,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0],
 [81,
  59,
  37,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  64,
  28,
  0,
  0,
  0,
  0,
  0,
  0,
  20,
  0,
  0,
  0,
  0,
  81,
  46,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0])

In [17]:
idx = 0
sample_idx = inputs["overflow_to_sample_mapping"][idx]
answer = answers[sample_idx]["text"][0]

start = start_positions[idx]
end = end_positions[idx]
labeled_answer = tokenizer.decode(inputs["input_ids"][idx][start : end + 1])

print(f"Theoretical answer: {answer}, labels give: {labeled_answer}")

Theoretical answer: Campath-1H, labels give: campath - 1h


In [18]:
idx = 4
sample_idx = inputs["overflow_to_sample_mapping"][idx]
answer = answers[sample_idx]["text"][0]

decoded_example = tokenizer.decode(inputs["input_ids"][idx])
print(f"Theoretical answer: {answer}, decoded example: {decoded_example}")

Theoretical answer: Campath-1H, decoded example: [CLS] what are the names of anti - cd52 monoclonal antibody that is used for treatment of multiple sclerosis patients? [SEP] for 18 months after a single pulse of campath - 1h. the first dose of monoclonal antibody was associated with a transient rehearsal of previous symptoms caused by the release of mediators that impede conduction at previously demyelinated sites ; this effect remained despite selective blockade of tumor necrosis factor - alpha. disease activity persisted for several weeks after [SEP]


In [19]:
max_length = 384
stride = 128


def preprocess_training_examples(examples):
    questions = [q.strip() for q in examples["question"]]
    inputs = tokenizer(
        questions,
        examples["context"],
        max_length=max_length,
        truncation="only_second",
        stride=stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    offset_mapping = inputs.pop("offset_mapping")
    sample_map = inputs.pop("overflow_to_sample_mapping")
    answers = examples["answers"]
    start_positions = []
    end_positions = []

    for i, offset in enumerate(offset_mapping):
        sample_idx = sample_map[i]
        answer = answers[sample_idx]
        start_char = answer["answer_start"][0]
        end_char = answer["answer_start"][0] + len(answer["text"][0])
        sequence_ids = inputs.sequence_ids(i)

        # Find the start and end of the context
        idx = 0
        while sequence_ids[idx] != 1:
            idx += 1
        context_start = idx
        while sequence_ids[idx] == 1:
            idx += 1
        context_end = idx - 1

        # If the answer is not fully inside the context, label is (0, 0)
        if offset[context_start][0] > start_char or offset[context_end][1] < end_char:
            start_positions.append(0)
            end_positions.append(0)
        else:
            # Otherwise it's the start and end token positions
            idx = context_start
            while idx <= context_end and offset[idx][0] <= start_char:
                idx += 1
            start_positions.append(idx - 1)

            idx = context_end
            while idx >= context_start and offset[idx][1] >= end_char:
                idx -= 1
            end_positions.append(idx + 1)

    inputs["start_positions"] = start_positions
    inputs["end_positions"] = end_positions
    return inputs

In [20]:
train_dataset = raw_datasets["train"].map(
    preprocess_training_examples,
    batched=True,
    remove_columns=raw_datasets["train"].column_names,
)
len(raw_datasets["train"]), len(train_dataset)

Map:   0%|          | 0/6878 [00:00<?, ? examples/s]

(6878, 10378)

## Preprocessing the validation set

In [21]:
def preprocess_validation_examples(examples):
    questions = [q.strip() for q in examples["question"]]
    inputs = tokenizer(
        questions,
        examples["context"],
        max_length=max_length,
        truncation="only_second",
        stride=stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    sample_map = inputs.pop("overflow_to_sample_mapping")
    example_ids = []

    for i in range(len(inputs["input_ids"])):
        sample_idx = sample_map[i]
        example_ids.append(examples["id"][sample_idx])

        sequence_ids = inputs.sequence_ids(i)
        offset = inputs["offset_mapping"][i]
        inputs["offset_mapping"][i] = [
            o if sequence_ids[k] == 1 else None for k, o in enumerate(offset)
        ]

    inputs["example_id"] = example_ids
    return inputs

In [22]:
validation_dataset = raw_datasets["validation"].map(
    preprocess_validation_examples,
    batched=True,
    remove_columns=raw_datasets["validation"].column_names,
)

test_dataset = raw_datasets["test"].map(
    preprocess_validation_examples,
    batched=True,
    remove_columns=raw_datasets["test"].column_names,
)
print(len(raw_datasets["validation"]), len(validation_dataset))
print(len(raw_datasets["test"]), len(test_dataset))
eval_set = validation_dataset

Map:   0%|          | 0/859 [00:00<?, ? examples/s]

Map:   0%|          | 0/861 [00:00<?, ? examples/s]

859 1304
861 1308


# Fine-tuning the model with the Trainer API

## Post processing

In [23]:
small_eval_set = raw_datasets["test"]
model_checkpoint = "distilbert-base-uncased"
trained_checkpoint = "distilbert-base-cased-distilled-squad"

tokenizer = AutoTokenizer.from_pretrained(trained_checkpoint)
eval_set = small_eval_set.map(
    preprocess_validation_examples,
    batched=True,
    remove_columns=raw_datasets["test"].column_names,
)

Map:   0%|          | 0/861 [00:00<?, ? examples/s]

In [24]:
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

In [25]:
# import torch
# from transformers import AutoModelForQuestionAnswering

# eval_set_for_model = eval_set.remove_columns(["example_id", "offset_mapping"])
# eval_set_for_model.set_format("torch")

# device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# batch = {k: eval_set_for_model[k].to(device) for k in eval_set_for_model.column_names}
# trained_model = AutoModelForQuestionAnswering.from_pretrained(model_checkpoint).to(device)

# with torch.no_grad():
#     outputs = trained_model(**batch)

In [26]:
# start_logits = outputs.start_logits.cpu().numpy()
# end_logits = outputs.end_logits.cpu().numpy()

In [27]:
# import collections

# example_to_features = collections.defaultdict(list)
# for idx, feature in enumerate(eval_set):
#     example_to_features[feature["example_id"]].append(idx)

In [45]:
import numpy as np

n_best = 20
max_answer_length = 30
# predicted_answers = []

# for example in small_eval_set:
#     example_id = example["id"]
#     context = example["context"]
#     answers = []

#     for feature_index in example_to_features[example_id]:
#         start_logit = start_logits[feature_index]
#         end_logit = end_logits[feature_index]
#         offsets = eval_set["offset_mapping"][feature_index]

#         start_indexes = np.argsort(start_logit)[-1 : -n_best - 1 : -1].tolist()
#         end_indexes = np.argsort(end_logit)[-1 : -n_best - 1 : -1].tolist()
#         for start_index in start_indexes:
#             for end_index in end_indexes:
#                 # Skip answers that are not fully in the context
#                 if offsets[start_index] is None or offsets[end_index] is None:
#                     continue
#                 # Skip answers with a length that is either < 0 or > max_answer_length.
#                 if (
#                     end_index < start_index
#                     or end_index - start_index + 1 > max_answer_length
#                 ):
#                     continue

#                 answers.append(
#                     {
#                         "text": context[offsets[start_index][0] : offsets[end_index][1]],
#                         "logit_score": start_logit[start_index] + end_logit[end_index],
#                     }
#                 )

#     best_answer = max(answers, key=lambda x: x["logit_score"])
#     predicted_answers.append({"id": example_id, "prediction_text": best_answer["text"]})

In [29]:
# from datasets import load_metric

# metric = load_metric("squad")

  metric = load_metric("squad")


In [30]:
# theoretical_answers = [
#     {"id": ex["id"], "answers": ex["answers"]} for ex in small_eval_set
# ]

In [31]:
# print(predicted_answers[0])
# print(theoretical_answers[0])

{'id': '511a4d391159fa8212000003_040', 'prediction_text': 'transferases, and'}
{'id': '511a4d391159fa8212000003_040', 'answers': {'answer_start': [931], 'text': ['met']}}


In [31]:
# metric.compute(predictions=predicted_answers, references=theoretical_answers)

In [32]:
from tqdm.auto import tqdm
from datasets import load_metric
import collections
import torch
from transformers import AutoModelForQuestionAnswering

def compute_metrics(start_logits, end_logits, features, examples):
    example_to_features = collections.defaultdict(list)
    for idx, feature in enumerate(features):
        example_to_features[feature["example_id"]].append(idx)

    predicted_answers = []
    for example in tqdm(examples):
        example_id = example["id"]
        context = example["context"]
        answers = []

        # Loop through all features associated with that example
        for feature_index in example_to_features[example_id]:
            start_logit = start_logits[feature_index]
            end_logit = end_logits[feature_index]
            offsets = features[feature_index]["offset_mapping"]

            start_indexes = np.argsort(start_logit)[-1 : -n_best - 1 : -1].tolist()
            end_indexes = np.argsort(end_logit)[-1 : -n_best - 1 : -1].tolist()
            for start_index in start_indexes:
                for end_index in end_indexes:
                    # Skip answers that are not fully in the context
                    if offsets[start_index] is None or offsets[end_index] is None:
                        continue
                    # Skip answers with a length that is either < 0 or > max_answer_length
                    if (
                        end_index < start_index
                        or end_index - start_index + 1 > max_answer_length
                    ):
                        continue

                    answer = {
                        "text": context[offsets[start_index][0] : offsets[end_index][1]],
                        "logit_score": start_logit[start_index] + end_logit[end_index],
                    }
                    answers.append(answer)

        # Select the answer with the best score
        if len(answers) > 0:
            best_answer = max(answers, key=lambda x: x["logit_score"])
            predicted_answers.append(
                {"id": example_id, "prediction_text": best_answer["text"]}
            )
        else:
            predicted_answers.append({"id": example_id, "prediction_text": ""})

    theoretical_answers = [{"id": ex["id"], "answers": ex["answers"]} for ex in examples]
    return metric.compute(predictions=predicted_answers, references=theoretical_answers)

In [33]:
# compute_metrics(start_logits, end_logits, eval_set, small_eval_set)

## Fine tune the model

In [45]:
# model = AutoModelForQuestionAnswering.from_pretrained(model_checkpoint)

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForQuestionAnswering: ['vocab_projector.bias', 'vocab_transform.bias', 'vocab_layer_norm.bias', 'vocab_transform.weight', 'vocab_projector.weight', 'vocab_layer_norm.weight']
- This IS expected if you are initializing DistilBertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForQuestionAnswering were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this mode

In [34]:
# from transformers import TrainingArguments

# args = TrainingArguments(
#     "bert-finetuned-squad",
#     evaluation_strategy="no",
#     save_strategy="epoch",
#     learning_rate=2e-5,
#     num_train_epochs=3,
#     weight_decay=0.01,
#     fp16=True,
#     # push_to_hub=True,
# )

In [35]:
# from transformers import Trainer

# trainer = Trainer(
#     model=model,
#     args=args,
#     train_dataset=train_dataset,
#     eval_dataset=validation_dataset,
#     tokenizer=tokenizer,
# )
# trainer.train()

In [36]:
# predictions, _, _ = trainer.predict(validation_dataset)
# start_logits, end_logits = predictions
# compute_metrics(start_logits, end_logits, validation_dataset, raw_datasets["test"])

# Torch Training Loop

In [37]:
from torch.utils.data import DataLoader
from transformers import default_data_collator

train_dataset.set_format("torch")
validation_set = validation_dataset.remove_columns(["example_id", "offset_mapping"])
validation_set.set_format("torch")
test_set = test_dataset.remove_columns(["example_id", "offset_mapping"])
test_set.set_format("torch")

train_dataloader = DataLoader(
    train_dataset,
    shuffle=True,
    collate_fn=default_data_collator,
    batch_size=8,
)
eval_dataloader = DataLoader(
    validation_set, collate_fn=default_data_collator, batch_size=8
)
test_dataloader = DataLoader(
    test_set, collate_fn=default_data_collator, batch_size=8
)

In [38]:
# model_checkpoint = 'sultan/BioM-ELECTRA-Large-SQuAD2-BioASQ8B'
model = AutoModelForQuestionAnswering.from_pretrained(model_checkpoint)

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForQuestionAnswering: ['vocab_transform.weight', 'vocab_transform.bias', 'vocab_projector.weight', 'vocab_layer_norm.bias', 'vocab_projector.bias', 'vocab_layer_norm.weight']
- This IS expected if you are initializing DistilBertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForQuestionAnswering were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this mode

In [39]:
from torch.optim import AdamW

optimizer = AdamW(model.parameters(), lr=2e-5)

In [40]:
from accelerate import Accelerator

accelerator = Accelerator()
model, optimizer, train_dataloader, eval_dataloader, test_dataloader = accelerator.prepare(
    model, optimizer, train_dataloader, eval_dataloader, test_dataloader
)

In [41]:
from transformers import get_scheduler

num_train_epochs = 2
num_update_steps_per_epoch = len(train_dataloader)
num_training_steps = num_train_epochs * num_update_steps_per_epoch

lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)

In [47]:
from tqdm.auto import tqdm
import torch
import numpy as np

output_dir = "bert-finetuned-squad-accelerate"
progress_bar = tqdm(range(num_training_steps))

for epoch in range(num_train_epochs):
    # Training
    model.train()
    for step, batch in enumerate(train_dataloader):
        outputs = model(**batch)
        loss = outputs.loss
        accelerator.backward(loss)

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)

    # Evaluation
    model.eval()
    start_logits = []
    end_logits = []
    accelerator.print("Evaluation!")
    for batch in tqdm(eval_dataloader):
        with torch.no_grad():
            outputs = model(**batch)

        start_logits.append(accelerator.gather(outputs.start_logits).cpu().numpy())
        end_logits.append(accelerator.gather(outputs.end_logits).cpu().numpy())

    start_logits = np.concatenate(start_logits)
    end_logits = np.concatenate(end_logits)
    start_logits = start_logits[: len(validation_dataset)]
    end_logits = end_logits[: len(validation_dataset)]

    metrics = compute_metrics(
        start_logits, end_logits, validation_dataset, raw_datasets["validation"]
    )
    print(f"epoch {epoch}:", metrics)

    # Save and upload
    # accelerator.wait_for_everyone()
    # unwrapped_model = accelerator.unwrap_model(model)
    # unwrapped_model.save_pretrained(output_dir, save_function=accelerator.save)
    # if accelerator.is_main_process:
    #     tokenizer.save_pretrained(output_dir)
        # repo.push_to_hub(
        #     commit_message=f"Training in progress epoch {epoch}", blocking=False
        # )

  0%|          | 0/2596 [00:00<?, ?it/s]

Evaluation!


  0%|          | 0/163 [00:00<?, ?it/s]

  0%|          | 0/859 [00:00<?, ?it/s]

epoch 0: {'exact_match': 30.384167636786962, 'f1': 37.17356400273935}
Evaluation!


  0%|          | 0/163 [00:00<?, ?it/s]

  0%|          | 0/859 [00:00<?, ?it/s]

epoch 1: {'exact_match': 30.384167636786962, 'f1': 37.17356400273935}


In [48]:
# Evaluation
model.eval()
start_logits = []
end_logits = []
accelerator.print("Evaluation on testset!")
for batch in tqdm(test_dataloader):
    with torch.no_grad():
        outputs = model(**batch)

    start_logits.append(accelerator.gather(outputs.start_logits).cpu().numpy())
    end_logits.append(accelerator.gather(outputs.end_logits).cpu().numpy())

start_logits = np.concatenate(start_logits)
end_logits = np.concatenate(end_logits)
start_logits = start_logits[: len(test_dataset)]
end_logits = end_logits[: len(test_dataset)]

metrics = compute_metrics(
    start_logits, end_logits, test_dataset, raw_datasets["test"]
)
print(f"Test metrics:", metrics)

Evaluation on testset!


  0%|          | 0/164 [00:00<?, ?it/s]

  0%|          | 0/861 [00:00<?, ?it/s]

Test metrics: {'exact_match': 33.44947735191638, 'f1': 40.57459053743507}


In [58]:
# accelerator.wait_for_everyone()
# unwrapped_model = accelerator.unwrap_model(model)
# unwrapped_model.save_pretrained(output_dir, save_function=accelerator.save)

Configuration saved in bert-finetuned-squad-accelerate/config.json
Model weights saved in bert-finetuned-squad-accelerate/pytorch_model.bin


# Use the pretrained model to inference

In [58]:
# from transformers import pipeline

# # Replace this with your own checkpoint
# model_checkpoint = "huggingface-course/bert-finetuned-squad"
# question_answerer = pipeline("question-answering", model=model_checkpoint)

# context = """
# 🤗 Transformers is backed by the three most popular deep learning libraries — Jax, PyTorch and TensorFlow — with a seamless integration
# between them. It's straightforward to train your models with one before loading them for inference with the other.
# """
# question = "Which deep learning libraries back 🤗 Transformers?"
# question_answerer(question=question, context=context)

loading configuration file config.json from cache at /home/aoyuli/.cache/huggingface/hub/models--huggingface-course--bert-finetuned-squad/snapshots/cdce6f8f43121716ec99d2d2a28ff06ddbefa2e0/config.json
Model config BertConfig {
  "_name_or_path": "huggingface-course/bert-finetuned-squad",
  "architectures": [
    "BertForQuestionAnswering"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "torch_dtype": "float32",
  "transformers_version": "4.25.1",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 28996
}

loading configuration file config.json from cache at /home/aoyuli/.cache/huggin

{'score': 0.9978997707366943,
 'start': 78,
 'end': 105,
 'answer': 'Jax, PyTorch and TensorFlow'}