<a href="https://colab.research.google.com/github/SRDdev/QABERT-small/blob/master/QuAC_BERT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### ⬇️ Imports & Installations

In this section we will import and install all the required libraries.

In [None]:
! pip install transformers datasets

### Load Dataset
We will use the QuAC dataset from the Hugging Face Datasets library ([link](https://huggingface.co/datasets/quac)).

In [1]:
from datasets import load_dataset, Dataset
raw_datasets = load_dataset("quac")



  0%|          | 0/2 [00:00<?, ?it/s]

In [2]:
raw_datasets 

DatasetDict({
    train: Dataset({
        features: ['dialogue_id', 'wikipedia_page_title', 'background', 'section_title', 'context', 'turn_ids', 'questions', 'followups', 'yesnos', 'answers', 'orig_answers'],
        num_rows: 11567
    })
    validation: Dataset({
        features: ['dialogue_id', 'wikipedia_page_title', 'background', 'section_title', 'context', 'turn_ids', 'questions', 'followups', 'yesnos', 'answers', 'orig_answers'],
        num_rows: 1000
    })
})

Let's check what each column (feature) contains.

In [4]:
print("id: ", raw_datasets["train"][0]["dialogue_id"][-1:])  # last character of the dialogue id
print("Context: ", raw_datasets["train"][0]["context"])
print("Question: ", raw_datasets["train"][0]["questions"])
print("Answer: ", raw_datasets["train"][0]["answers"]['texts'])

id:  1
Context:  According to the Indian census of 2001, there were 30,803,747 speakers of Malayalam in Kerala, making up 93.2% of the total number of Malayalam speakers in India, and 96.7% of the total population of the state. There were a further 701,673 (2.1% of the total number) in Karnataka, 557,705 (1.7%) in Tamil Nadu and 406,358 (1.2%) in Maharashtra. The number of Malayalam speakers in Lakshadweep is 51,100, which is only 0.15% of the total number, but is as much as about 84% of the population of Lakshadweep. In all, Malayalis made up 3.22% of the total Indian population in 2001. Of the total 33,066,392 Malayalam speakers in India in 2001, 33,015,420 spoke the standard dialects, 19,643 spoke the Yerava dialect and 31,329 spoke non-standard regional variations like Eranadan. As per the 1991 census data, 28.85% of all Malayalam speakers in India spoke a second language and 19.64% of the total knew three or more languages.  Large numbers of Malayalis have settled in Bangalore, Ma

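We will now build our own dataset and drop the columns we don't need, such as `"wikipedia_page_title"`. Each QuAC example is a dialogue with several questions, so we keep only the first question and its answers.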

In [5]:
# Loop through each example in the 'train' split
question_train = []
answer_train = []
answer_starts_t = []
context_train = []
for example in raw_datasets['train']:
    context_train.append(example['context'])
    answer_train.append(example['answers']['texts'][0])
    answer_starts_t.append(example["answers"]['answer_starts'][0])
    question_train.append(example['questions'][0])
# Loop through each example in the 'val' split
question_val = []
answer_val = []
answer_starts_v=[]
context_val = []
for example in raw_datasets['validation']:
    context_val.append(example['context'])
    answer_val.append(example['answers']['texts'][0])  # first question's answers, matching the train loop
    answer_starts_v.append(example["answers"]['answer_starts'][0])
    question_val.append(example['questions'][0])

In [6]:
# Use each row's position as a simple integer id
id_t = list(range(len(raw_datasets['train'])))
id_v = list(range(len(raw_datasets['validation'])))

In [7]:
answer_t = Dataset.from_dict({"texts":answer_train,"answer_starts":answer_starts_t})
answer_v = Dataset.from_dict({"texts":answer_val,"answer_starts":answer_starts_v})

train_data = {"question": question_train,"answer":answer_t,"context": context_train,"id":id_t}
val_data = {"question": question_val,"answer":answer_v,"context": context_val,"id":id_v}

train_dataset = Dataset.from_dict(train_data)
val_dataset = Dataset.from_dict(val_data)

data = {
    "train":train_dataset,
    "val":val_dataset
}
data

{'train': Dataset({
     features: ['question', 'answer', 'context', 'id'],
     num_rows: 11567
 }),
 'val': Dataset({
     features: ['question', 'answer', 'context', 'id'],
     num_rows: 1000
 })}
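
As a rough illustration of what the cells above do, the sketch below flattens one QuAC-style record down to its first turn. The toy record and the helper name `first_turn` are assumptions made for this example; only the nested field layout mirrors the dataset loaded earlier.

```python
def first_turn(example):
    """Keep only the first question of a dialogue, together with its answers."""
    return {
        "question": example["questions"][0],
        "context": example["context"],
        # Answers are nested per turn: texts[i] and answer_starts[i] belong to questions[i].
        "answer": {
            "texts": example["answers"]["texts"][0],
            "answer_starts": example["answers"]["answer_starts"][0],
        },
    }


# Hypothetical record that imitates the nested layout of a QuAC example.
toy_example = {
    "context": "Kerala is a state in India. CANNOTANSWER",
    "questions": ["Where is Kerala?", "How large is it?"],
    "answers": {
        "texts": [["a state in India"], ["CANNOTANSWER"]],
        "answer_starts": [[10], [28]],
    },
}

flat = first_turn(toy_example)
print(flat["question"], flat["answer"])
```

Note that every later turn of the dialogue is discarded, so the resulting dataset has one question per dialogue.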

### ⚒️ Preprocessing
Now that we have our HF dataset, we can move on to preprocessing. We will tokenize the dataset using a pretrained tokenizer from HF.

The tokenizer we will use belongs to the `bert-base-cased` checkpoint.

In [8]:
from transformers import AutoTokenizer

model_checkpoint = "bert-base-cased"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

In [9]:
context = data["train"][0]["context"]
question = data["train"][0]["question"]

inputs = tokenizer(question, context)
tokenizer.decode(inputs["input_ids"])

'[CLS] Where is Malayali located? [SEP] According to the Indian census of 2001, there were 30, 803, 747 speakers of Malayalam in Kerala, making up 93. 2 % of the total number of Malayalam speakers in India, and 96. 7 % of the total population of the state. There were a further 701, 673 ( 2. 1 % of the total number ) in Karnataka, 557, 705 ( 1. 7 % ) in Tamil Nadu and 406, 358 ( 1. 2 % ) in Maharashtra. The number of Malayalam speakers in Lakshadweep is 51, 100, which is only 0. 15 % of the total number, but is as much as about 84 % of the population of Lakshadweep. In all, Malayalis made up 3. 22 % of the total Indian population in 2001. Of the total 33, 066, 392 Malayalam speakers in India in 2001, 33, 015, 420 spoke the standard dialects, 19, 643 spoke the Yerava dialect and 31, 329 spoke non - standard regional variations like Eranadan. As per the 1991 census data, 28. 85 % of all Malayalam speakers in India spoke a second language and 19. 64 % of the total knew three or more langua

In [10]:
inputs = tokenizer(
    question,
    context,
    max_length=100,
    truncation="only_second",
    stride=50,
    return_overflowing_tokens=True,
)

for ids in inputs["input_ids"]:
    print(tokenizer.decode(ids))

[CLS] Where is Malayali located? [SEP] According to the Indian census of 2001, there were 30, 803, 747 speakers of Malayalam in Kerala, making up 93. 2 % of the total number of Malayalam speakers in India, and 96. 7 % of the total population of the state. There were a further 701, 673 ( 2. 1 % of the total number ) in Karnataka, 557, 705 ( 1. 7 % ) in Tamil Nadu and 406 [SEP]
[CLS] Where is Malayali located? [SEP]. 7 % of the total population of the state. There were a further 701, 673 ( 2. 1 % of the total number ) in Karnataka, 557, 705 ( 1. 7 % ) in Tamil Nadu and 406, 358 ( 1. 2 % ) in Maharashtra. The number of Malayalam speakers in Lakshadweep is 51, 100, which is only 0. 15 % of the total number, but [SEP]
[CLS] Where is Malayali located? [SEP] 7 % ) in Tamil Nadu and 406, 358 ( 1. 2 % ) in Maharashtra. The number of Malayalam speakers in Lakshadweep is 51, 100, which is only 0. 15 % of the total number, but is as much as about 84 % of the population of Lakshadweep. In all, Mala
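
The overlapping chunks printed above come from the `max_length` and `stride` arguments. The sketch below imitates that windowing on a plain list of integers. It is a simplified stand-in, not the tokenizer's implementation: the function name `sliding_windows` is hypothetical, and it ignores the question and special tokens, which in the real call also count toward `max_length`.

```python
def sliding_windows(tokens, window, stride):
    """Split tokens into chunks of at most `window` items; neighbours share `stride` items."""
    if not 0 <= stride < window:
        raise ValueError("stride must be non-negative and smaller than window")
    chunks = []
    start = 0
    while True:
        chunks.append(tokens[start:start + window])
        # Stop once the current chunk reaches the end of the sequence.
        if start + window >= len(tokens):
            break
        # Advance so that the next chunk repeats the last `stride` tokens.
        start += window - stride
    return chunks


context_tokens = list(range(10))  # toy "token ids"
chunks = sliding_windows(context_tokens, window=4, stride=2)
for chunk in chunks:
    print(chunk)
```

A larger `stride` makes it more likely that an answer falls entirely inside at least one chunk, at the cost of producing more features per example.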

In [11]:
inputs = tokenizer(
    question,
    context,
    max_length=100,
    truncation="only_second",
    stride=50,
    return_overflowing_tokens=True,
    return_offsets_mapping=True,
)
inputs.keys()

dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'offset_mapping', 'overflow_to_sample_mapping'])

In [12]:
inputs = tokenizer(
    data["train"][2:6]["question"],
    data["train"][2:6]["context"],
    max_length=100,
    truncation="only_second",
    stride=50,
    return_overflowing_tokens=True,
    return_offsets_mapping=True,
)

print(f"The 4 examples gave {len(inputs['input_ids'])} features.")
print(f"Here is where each comes from: {inputs['overflow_to_sample_mapping']}.")

The 4 examples gave 43 features.
Here is where each comes from: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3].
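
To make the mapping above easier to read, here is a small sketch that counts how many features each example produced and groups feature indices by example. The list is the printed mapping written compactly; the variable names are chosen for this illustration only.

```python
from collections import Counter, defaultdict

# The mapping printed above, written compactly: 10, 9, 9 and 15 features.
overflow_to_sample_mapping = [0] * 10 + [1] * 9 + [2] * 9 + [3] * 15

# Number of features generated by each of the 4 examples.
features_per_example = Counter(overflow_to_sample_mapping)

# Feature indices grouped by the example they were cut from.
example_to_feature_indices = defaultdict(list)
for feature_idx, sample_idx in enumerate(overflow_to_sample_mapping):
    example_to_feature_indices[sample_idx].append(feature_idx)

print(dict(features_per_example))
print(example_to_feature_indices[1])
```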


In [13]:
train_data = data["train"]
answers = train_data[2:6]["answer"]
start_positions = []
end_positions = []

for i, offset in enumerate(inputs["offset_mapping"]):
    sample_idx = inputs["overflow_to_sample_mapping"][i]
    answer = answers[sample_idx]
    start_char = answer["answer_starts"][0]  # character index where the first answer begins
    end_char = answer["answer_starts"][0] + len(answer["texts"][0])  # character index just past its end
    sequence_ids = inputs.sequence_ids(i)

    # Find the start and end of the context
    idx = 0
    while sequence_ids[idx] != 1:
        idx += 1
    context_start = idx
    while sequence_ids[idx] == 1:
        idx += 1
    context_end = idx - 1

    # If the answer is not fully inside the context, label is (0, 0)
    if offset[context_start][0] > start_char or offset[context_end][1] < end_char:
        start_positions.append(0)
        end_positions.append(0)
    else:
        # Otherwise it's the start and end token positions
        idx = context_start
        while idx <= context_end and offset[idx][0] <= start_char:
            idx += 1
        start_positions.append(idx - 1)

        idx = context_end
        while idx >= context_start and offset[idx][1] >= end_char:
            idx -= 1
        end_positions.append(idx + 1)

start_positions, end_positions


([25, 0, 0, 0, 0, 0, 0, 0, 0, 0, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 67, 28, 0, 0, 0, 0, 22, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [41, 0, 0, 0, 0, 0, 0, 0, 0, 0, 30, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 74, 35, 0, 0, 0, 0, 47, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

In [14]:
idx = 0
sample_idx = inputs["overflow_to_sample_mapping"][idx]
answer = answers[sample_idx]["texts"][0]

start = start_positions[idx]
end = end_positions[idx]
labeled_answer = tokenizer.decode(inputs["input_ids"][idx][start : end + 1])

print(f"Theoretical answer: {answer}, labels give: {labeled_answer}")

Theoretical answer: 19-year-old Cove Reber was announced as their new permanent lead singer., labels give: 19 - year - old Cove Reber was announced as their new permanent lead singer.
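
The labelling logic used above can be isolated into a small function and tried on toy offsets. This is a sketch under assumptions: the function name `char_span_to_token_span` and the offsets are made up for illustration, with `(0, 0)` entries standing in for special tokens.

```python
def char_span_to_token_span(offsets, context_start, context_end, start_char, end_char):
    """Map a character span to (start_token, end_token), or (0, 0) if it is not fully inside."""
    # The answer must lie entirely within the characters covered by this window.
    if offsets[context_start][0] > start_char or offsets[context_end][1] < end_char:
        return 0, 0

    # Walk forward to the last token that starts at or before start_char.
    idx = context_start
    while idx <= context_end and offsets[idx][0] <= start_char:
        idx += 1
    start_token = idx - 1

    # Walk backward to the first token that ends at or after end_char.
    idx = context_end
    while idx >= context_start and offsets[idx][1] >= end_char:
        idx -= 1
    end_token = idx + 1
    return start_token, end_token


# Hypothetical offsets for the context "Hello big world"; indices 0 and 4 mimic special tokens.
toy_offsets = [(0, 0), (0, 5), (6, 9), (10, 15), (0, 0)]

inside = char_span_to_token_span(toy_offsets, 1, 3, start_char=6, end_char=15)    # "big world"
outside = char_span_to_token_span(toy_offsets, 1, 3, start_char=20, end_char=25)  # beyond the window
print(inside, outside)
```

The `(0, 0)` label points at the `[CLS]` token, which is why most features in the output above carry zeros: the answer sits in a different window of the same context.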


In [15]:
idx = 4
sample_idx = inputs["overflow_to_sample_mapping"][idx]
answer = answers[sample_idx]["texts"][0]

decoded_example = tokenizer.decode(inputs["input_ids"][idx])
print(f"Theoretical answer: {answer}, decoded example: {decoded_example}")

Theoretical answer: 19-year-old Cove Reber was announced as their new permanent lead singer., decoded example: [CLS] What do we know about Cove Reber? [SEP] their lives... Saosin is a band on a completely different level. All these dudes are freaks about music. " Reber's addition to the band was difficult, for the more experienced Green was the center piece of the band in the eyes of Saosin's fans. Many fans consider the time with Green to be something entirely different from the time with Reber. There are still distinct fans of both [SEP]


### Preprocessing Training Data

In [17]:
max_length = 384
stride = 128

def preprocess_training_examples(examples):
    questions = [q.strip() for q in examples["question"]]
    inputs = tokenizer(
        questions,
        examples["context"],
        max_length=max_length,
        truncation="only_second",
        stride=stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    offset_mapping = inputs.pop("offset_mapping")
    sample_map = inputs.pop("overflow_to_sample_mapping")
    answers = examples["answer"]
    start_positions = []
    end_positions = []

    for i, offset in enumerate(offset_mapping):
        sample_idx = sample_map[i]
        answer = answers[sample_idx]
        start_char = answer["answer_starts"][0]
        end_char = answer["answer_starts"][0] + len(answer["texts"][0])
        sequence_ids = inputs.sequence_ids(i)

        # Find the start and end of the context
        idx = 0
        while sequence_ids[idx] != 1:
            idx += 1
        context_start = idx
        while sequence_ids[idx] == 1:
            idx += 1
        context_end = idx - 1

        # If the answer is not fully inside the context, label is (0, 0)
        if offset[context_start][0] > start_char or offset[context_end][1] < end_char:
            start_positions.append(0)
            end_positions.append(0)
        else:
            # Otherwise it's the start and end token positions
            idx = context_start
            while idx <= context_end and offset[idx][0] <= start_char:
                idx += 1
            start_positions.append(idx - 1)

            idx = context_end
            while idx >= context_start and offset[idx][1] >= end_char:
                idx -= 1
            end_positions.append(idx + 1)

    inputs["start_positions"] = start_positions
    inputs["end_positions"] = end_positions
    return inputs


In [18]:
train_dataset = data["train"].map(
    preprocess_training_examples,
    batched=True,
    remove_columns=data["train"].column_names,
)
len(data["train"]), len(train_dataset)

Map:   0%|          | 0/11567 [00:00<?, ? examples/s]

(11567, 25362)

### Preprocessing Validation Data

In [19]:
def preprocess_validation_examples(examples):
    questions = [q.strip() for q in examples["question"]]
    inputs = tokenizer(
        questions,
        examples["context"],
        max_length=max_length,
        truncation="only_second",
        stride=stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    sample_map = inputs.pop("overflow_to_sample_mapping")
    example_ids = []

    for i in range(len(inputs["input_ids"])):
        sample_idx = sample_map[i]
        example_ids.append(examples["id"][sample_idx])

        sequence_ids = inputs.sequence_ids(i)
        offset = inputs["offset_mapping"][i]
        inputs["offset_mapping"][i] = [
            o if sequence_ids[k] == 1 else None for k, o in enumerate(offset)
        ]

    inputs["example_id"] = example_ids
    return inputs
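
The offset masking inside `preprocess_validation_examples` can be seen on a toy feature. In this sketch the sequence ids and offsets are hypothetical values, and `mask_non_context_offsets` is a name introduced only for the example.

```python
def mask_non_context_offsets(offsets, sequence_ids):
    """Keep offsets for context tokens (sequence id 1); replace all others with None."""
    return [o if sid == 1 else None for o, sid in zip(offsets, sequence_ids)]


# None marks special tokens, 0 the question, 1 the context.
toy_sequence_ids = [None, 0, 0, None, 1, 1, None]
toy_offsets = [(0, 0), (0, 5), (6, 8), (0, 0), (0, 5), (6, 9), (0, 0)]

masked = mask_non_context_offsets(toy_offsets, toy_sequence_ids)
print(masked)
```

Setting question and special-token offsets to `None` lets the post-processing step reject any predicted span that starts or ends outside the context.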

In [20]:
validation_dataset = data["val"].map(
    preprocess_validation_examples,
    batched=True,
    remove_columns=data["val"].column_names,
)
len(data["val"]), len(validation_dataset)

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

(1000, 2452)

In [22]:
small_eval_set = data["val"].select(range(100))
trained_checkpoint = "distilbert-base-cased-distilled-squad"

tokenizer = AutoTokenizer.from_pretrained(trained_checkpoint)
eval_set = small_eval_set.map(
    preprocess_validation_examples,
    batched=True,
    remove_columns=data["val"].column_names,
)

Map:   0%|          | 0/100 [00:00<?, ? examples/s]

In [23]:
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

In [24]:
import torch
from transformers import AutoModelForQuestionAnswering

eval_set_for_model = eval_set.remove_columns(["example_id", "offset_mapping"])
eval_set_for_model.set_format("torch")

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
batch = {k: eval_set_for_model[k].to(device) for k in eval_set_for_model.column_names}
trained_model = AutoModelForQuestionAnswering.from_pretrained(trained_checkpoint).to(
    device
)

with torch.no_grad():
    outputs = trained_model(**batch)

Downloading pytorch_model.bin:   0%|          | 0.00/261M [00:00<?, ?B/s]

In [25]:
start_logits = outputs.start_logits.cpu().numpy()
end_logits = outputs.end_logits.cpu().numpy()

In [26]:
import collections

example_to_features = collections.defaultdict(list)
for idx, feature in enumerate(eval_set):
    example_to_features[feature["example_id"]].append(idx)

In [27]:
import numpy as np

n_best = 20
max_answer_length = 30
predicted_answers = []

for example in small_eval_set:
    example_id = example["id"]
    context = example["context"]
    answers = []

    for feature_index in example_to_features[example_id]:
        start_logit = start_logits[feature_index]
        end_logit = end_logits[feature_index]
        offsets = eval_set["offset_mapping"][feature_index]

        start_indexes = np.argsort(start_logit)[-1 : -n_best - 1 : -1].tolist()
        end_indexes = np.argsort(end_logit)[-1 : -n_best - 1 : -1].tolist()
        for start_index in start_indexes:
            for end_index in end_indexes:
                # Skip answers that are not fully in the context
                if offsets[start_index] is None or offsets[end_index] is None:
                    continue
                # Skip answers with a length that is either < 0 or > max_answer_length.
                if (
                    end_index < start_index
                    or end_index - start_index + 1 > max_answer_length
                ):
                    continue

                answers.append(
                    {
                        "text": context[offsets[start_index][0] : offsets[end_index][1]],
                        "logit_score": start_logit[start_index] + end_logit[end_index],
                    }
                )

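    # Pick the highest-scoring span; fall back to an empty string if no candidate was valid
    if answers:
        best_answer = max(answers, key=lambda x: x["logit_score"])
        predicted_answers.append({"id": example_id, "prediction_text": best_answer["text"]})
    else:
        predicted_answers.append({"id": example_id, "prediction_text": ""})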
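
The span search in the loop above can be tried on toy logits. The sketch below is a pure-Python stand-in, assuming plain lists instead of NumPy arrays: `top_indices` replaces the `np.argsort` slice (ties may be ordered differently), the offset check is left out, and all names and values are hypothetical.

```python
def top_indices(scores, n_best):
    """Indices of the n_best highest scores, best first."""
    return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:n_best]


def best_span(start_logit, end_logit, n_best=20, max_answer_length=30):
    """Return (start, end, score) of the best valid span, or None if there is none."""
    best = None
    for s in top_indices(start_logit, n_best):
        for e in top_indices(end_logit, n_best):
            # Skip spans that end before they start or that are too long.
            if e < s or e - s + 1 > max_answer_length:
                continue
            score = start_logit[s] + end_logit[e]
            if best is None or score > best[2]:
                best = (s, e, score)
    return best


toy_start_logit = [0.1, 2.0, 0.3, 0.2]
toy_end_logit = [0.0, 0.5, 3.0, 0.1]

unrestricted = best_span(toy_start_logit, toy_end_logit)
single_token = best_span(toy_start_logit, toy_end_logit, max_answer_length=1)
print(unrestricted[:2], single_token[:2])
```

Restricting the search to the `n_best` start and end candidates keeps the number of pairs small (at most 400 per feature with `n_best = 20`) instead of scoring every possible pair of positions.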

### Fine Tune `BERT` 

In [29]:
model = AutoModelForQuestionAnswering.from_pretrained(model_checkpoint)

Downloading pytorch_model.bin:   0%|          | 0.00/436M [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForQuestionAnswering: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at bert-base-cased and a

In [30]:
from huggingface_hub import notebook_login

notebook_login()

Token is valid.
Your token has been saved to /root/.cache/huggingface/token
Login successful


In [31]:
from transformers import TrainingArguments

args = TrainingArguments(
    "QuAC-QA-BERT",
    evaluation_strategy="no",
    save_strategy="epoch",
    learning_rate=2e-5,
    num_train_epochs=3,
    weight_decay=0.01,
    fp16=True,
    push_to_hub=True,
)

In [32]:
from transformers import Trainer

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=validation_dataset,
    tokenizer=tokenizer,
)
trainer.train()

Cloning https://huggingface.co/SRDdev/QuAC-QA-BERT into local empty directory.
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step,Training Loss
500,2.213
1000,1.8118
1500,1.7026
2000,1.656
2500,1.5926
3000,1.5578
3500,1.3543
4000,1.3224
4500,1.2725
5000,1.3272


TrainOutput(global_step=9513, training_loss=1.337987121642983, metrics={'train_runtime': 2416.4456, 'train_samples_per_second': 31.487, 'train_steps_per_second': 3.937, 'total_flos': 1.4910768774761472e+16, 'train_loss': 1.337987121642983, 'epoch': 3.0})
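
The `global_step` value in the output above can be checked with a little arithmetic. The batch size of 8 is an assumption: it is the default `per_device_train_batch_size`, and the sketch further assumes a single device and no gradient accumulation, since `TrainingArguments` above does not override any of these.

```python
import math

num_features = 25362  # len(train_dataset) reported after preprocessing
num_epochs = 3        # num_train_epochs in TrainingArguments
batch_size = 8        # assumption: default per-device batch size on one device

# The last batch of each epoch is smaller, hence the ceiling.
steps_per_epoch = math.ceil(num_features / batch_size)
total_steps = steps_per_epoch * num_epochs
print(steps_per_epoch, total_steps)
```

Under these assumptions the total matches the reported `global_step=9513`.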

In [38]:
# Name of the metric intended for scoring; this string is not used by the calls below
compute_metrics = "squad"
predictions, _, _ = trainer.predict(validation_dataset)
start_logits, end_logits = predictions
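
The logits returned here still have to be turned into text spans before they can be scored. As a hedged illustration of what a SQuAD-style exact-match score measures, the sketch below compares already decoded strings. It is a simplified stand-in written for this example, not the metric implementation from any library, and the sample strings are hypothetical.

```python
import re
import string

_PUNCTUATION = set(string.punctuation)


def normalize_answer(text):
    """Lowercase, drop punctuation and English articles, and collapse whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in _PUNCTUATION)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def exact_match(prediction, references):
    """1.0 if the normalized prediction equals any normalized reference, else 0.0."""
    normalized = normalize_answer(prediction)
    return float(any(normalized == normalize_answer(ref) for ref in references))


print(exact_match("The state of Kerala.", ["state of kerala"]))
print(exact_match("London", ["Paris", "Berlin"]))
```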

In [34]:
trainer.push_to_hub(commit_message="Training complete")

To https://huggingface.co/SRDdev/QuAC-QA-BERT
   8dff18b..a1ffe17  main -> main

To https://huggingface.co/SRDdev/QuAC-QA-BERT
   a1ffe17..5559501  main -> main



'https://huggingface.co/SRDdev/QuAC-QA-BERT/commit/a1ffe17a71c8dbf5bdb0b41fd7af55e63c3fee47'

### Inference 🤗

In [49]:
from transformers import AutoTokenizer, AutoModelForQuestionAnswering

tokenizer = AutoTokenizer.from_pretrained("SRDdev/QuAC-QA-BERT")
model = AutoModelForQuestionAnswering.from_pretrained("SRDdev/QuAC-QA-BERT")

In [50]:
context = """My name is Sarah and I live in London"""

In [51]:
from transformers import pipeline

ask = pipeline("question-answering", model=model, tokenizer=tokenizer)
result = ask(question="Where do I live?", context=context)
print(f"Answer: '{result['answer']}'")

Answer: 'My name is Sarah and I live in London'
