# **Question Answering❓**
with fine-tuned BERT on newsQA.  

Question answering comes in many forms. We’ll look at the particular type of extractive QA that involves answering a question about a passage by highlighting the segment of the passage that answers the question. This involves fine-tuning a model which predicts a start position and an end position in the passage. More specifically, we will fine tune the [bert-base-uncased](https://huggingface.co/bert-base-uncased) model on the [NewsQA](https://huggingface.co/datasets/lucadiliello/newsqa) dataset.

I have followed [this tutorial](https://github.com/angelosps/Question-Answering) from the for how to fine tune BERT on SQuAD 2.0 which in our case is a custom newsQA dataset

In [1]:
!pip install transformers



In [2]:
import torch
from transformers import AutoTokenizer, AutoModelForQuestionAnswering, AdamW
from torch.utils.data import DataLoader, Dataset
import json
from tqdm import tqdm



In [3]:
!pip install datasets



In [4]:
# Load the NewsQA dataset
from datasets import load_dataset
newsqa_dataset = load_dataset('lucadiliello/newsqa')

Downloading and preparing dataset parquet/lucadiliello--newsqa to /root/.cache/huggingface/datasets/parquet/lucadiliello--newsqa-206550e86bcc3ded/0.0.0/0b6d5799bb726b24ad7fc7be720c170d8e497f575d02d47537de9a5bac074901...


Downloading data files:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/29.7M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.63M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/2 [00:00<?, ?it/s]

Dataset parquet downloaded and prepared to /root/.cache/huggingface/datasets/parquet/lucadiliello--newsqa-206550e86bcc3ded/0.0.0/0b6d5799bb726b24ad7fc7be720c170d8e497f575d02d47537de9a5bac074901. Subsequent calls will reuse this data.


  0%|          | 0/2 [00:00<?, ?it/s]

### **Get data 📁**

Let's extract our data and store them into some data structures.

In [5]:
def read_newsqa_data(dataset):
    contexts = []
    questions = []
    answers = []
    string_ans = []

    for item in dataset:
        context = item['context']
        question = item['question']
        answer = {'answer_start': item['labels'][0]['start'][0], 'answer_end': item['labels'][0]['end'][0]}  # Assuming there's only one answer
        string_answer = item['answers'][0]
        
        contexts.append(context)
        questions.append(question)
        answers.append(answer)
        string_ans.append(string_answer)
    return contexts, questions, answers, string_ans

In [6]:
train_contexts, train_questions, train_answers, train_str_ans = read_newsqa_data(newsqa_dataset['train'].select(list(range(5000))))
valid_contexts, valid_questions, valid_answers, valid_str_ans = read_newsqa_data(newsqa_dataset['validation'].select(list(range(1000))))

In [7]:
train_str_ans[:5]

['19',
 'February.',
 'rape and murder',
 'Moninder Singh Pandher',
 'Moninder Singh Pandher']

### **Tokenization 🔢**

In [8]:
# Initialize the RoBERTa tokenizer
tokenizer = AutoTokenizer.from_pretrained('deepset/roberta-base-squad2')
train_encodings = tokenizer(train_contexts, train_questions, truncation=True, padding=True)
valid_encodings = tokenizer(valid_contexts, valid_questions, truncation=True, padding=True)

Downloading (…)okenizer_config.json:   0%|          | 0.00/79.0 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/571 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/772 [00:00<?, ?B/s]

Next we need to convert our character start/end positions to token start/end positions. Why is that? Because our words converted into tokens, so the answer start/end needs to show the index of start/end token which contains the answer and not the specific characters in the context.

In [9]:
# Convert character start/end positions to token start/end positions
def add_token_positions(encodings, answers):
    start_positions = []
    end_positions = []
    for i in range(len(answers)):
        char_start = answers[i]['answer_start']
        char_end = answers[i]['answer_end']

        token_start = encodings.char_to_token(i, char_start)
        token_end = encodings.char_to_token(i, char_end)

        start_positions.append(token_start)
        end_positions.append(token_end)

        if token_start is None:
            start_positions[-1] = tokenizer.model_max_length
        if token_end is None:
            end_positions[-1] = tokenizer.model_max_length

    encodings.update({'start_positions': start_positions, 'end_positions': end_positions})

In [10]:
add_token_positions(train_encodings, train_answers)

In [11]:
add_token_positions(valid_encodings, valid_answers)

In [12]:
class NewsQA_Dataset(Dataset):
    def __init__(self, encodings):
        self.encodings = encodings

    def __getitem__(self, idx):
        return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}

    def __len__(self):
        return len(self.encodings.input_ids)

## Creating the dataset using the class

In [13]:
train_dataset = NewsQA_Dataset(train_encodings)
valid_dataset = NewsQA_Dataset(valid_encodings)

In [14]:
# Create dataloaders for training and validation
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=4)

## Importing the model

In [15]:
# Initialize the RoBERTa model for question answering
model = AutoModelForQuestionAnswering.from_pretrained('deepset/roberta-base-squad2')

Downloading model.safetensors:   0%|          | 0.00/496M [00:00<?, ?B/s]

In [16]:
num_layers = model.config.num_hidden_layers
print(f"Number of layers: {num_layers}")

Number of layers: 12


### Fine tuning only the last 3 layers of the model

In [17]:
num_layers_to_freeze = 9
for param in model.roberta.embeddings.parameters():
    param.requires_grad = False
for layer in model.roberta.encoder.layer[:num_layers_to_freeze]:
    for param in layer.parameters():
        param.requires_grad = False

In [18]:
# Check if GPU is available and move the model accordingly
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)

RobertaForQuestionAnswering(
  (roberta): RobertaModel(
    (embeddings): RobertaEmbeddings(
      (word_embeddings): Embedding(50265, 768, padding_idx=1)
      (position_embeddings): Embedding(514, 768, padding_idx=1)
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): RobertaEncoder(
      (layer): ModuleList(
        (0-11): 12 x RobertaLayer(
          (attention): RobertaAttention(
            (self): RobertaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): RobertaSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (Lay

### Model Hyperparameters

In [19]:
# Initialize the optimizer
optimizer = AdamW(model.parameters(), lr=0.0001)
# Training loop
num_epochs = 100



## Training the Model

In [20]:
model.train()
# Training loop
for epoch in range(num_epochs):
    model.train()
    total_loss = 0

    for batch in tqdm(train_loader, desc=f'Epoch {epoch + 1}', dynamic_ncols=True):
        inputs = {key: value.to(device) for key, value in batch.items()}

        # Forward pass
        outputs = model(**inputs)
        loss = outputs.loss
        total_loss += loss.item()
        # Backward pass and optimization
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    # Calculate and print the average loss for this epoch
    avg_loss = total_loss / len(train_loader)
    print(f'Epoch {epoch + 1} - Avg Loss: {avg_loss:.4f}')

Epoch 1: 100%|██████████| 1250/1250 [02:16<00:00,  9.17it/s]


Epoch 1 - Avg Loss: 3.4383


Epoch 2: 100%|██████████| 1250/1250 [02:14<00:00,  9.30it/s]


Epoch 2 - Avg Loss: 3.0022


Epoch 3: 100%|██████████| 1250/1250 [02:14<00:00,  9.29it/s]


Epoch 3 - Avg Loss: 2.7062


Epoch 4: 100%|██████████| 1250/1250 [02:14<00:00,  9.29it/s]


Epoch 4 - Avg Loss: 2.4936


Epoch 5: 100%|██████████| 1250/1250 [02:14<00:00,  9.29it/s]


Epoch 5 - Avg Loss: 2.2517


Epoch 6: 100%|██████████| 1250/1250 [02:14<00:00,  9.29it/s]


Epoch 6 - Avg Loss: 2.0617


Epoch 7: 100%|██████████| 1250/1250 [02:14<00:00,  9.29it/s]


Epoch 7 - Avg Loss: 1.9097


Epoch 8: 100%|██████████| 1250/1250 [02:14<00:00,  9.29it/s]


Epoch 8 - Avg Loss: 1.7577


Epoch 9: 100%|██████████| 1250/1250 [02:14<00:00,  9.29it/s]


Epoch 9 - Avg Loss: 1.6726


Epoch 10: 100%|██████████| 1250/1250 [02:14<00:00,  9.29it/s]


Epoch 10 - Avg Loss: 1.4255


Epoch 11: 100%|██████████| 1250/1250 [02:14<00:00,  9.29it/s]


Epoch 11 - Avg Loss: 1.1824


Epoch 12: 100%|██████████| 1250/1250 [02:14<00:00,  9.29it/s]


Epoch 12 - Avg Loss: 1.0629


Epoch 13: 100%|██████████| 1250/1250 [02:14<00:00,  9.29it/s]


Epoch 13 - Avg Loss: nan


Epoch 14: 100%|██████████| 1250/1250 [02:14<00:00,  9.29it/s]


Epoch 14 - Avg Loss: 0.8650


Epoch 15: 100%|██████████| 1250/1250 [02:14<00:00,  9.29it/s]


Epoch 15 - Avg Loss: 0.8057


Epoch 16: 100%|██████████| 1250/1250 [02:14<00:00,  9.29it/s]


Epoch 16 - Avg Loss: 0.7628


Epoch 17: 100%|██████████| 1250/1250 [02:14<00:00,  9.29it/s]


Epoch 17 - Avg Loss: 0.7043


Epoch 18: 100%|██████████| 1250/1250 [02:14<00:00,  9.29it/s]


Epoch 18 - Avg Loss: 0.6661


Epoch 19: 100%|██████████| 1250/1250 [02:14<00:00,  9.29it/s]


Epoch 19 - Avg Loss: 0.6304


Epoch 20: 100%|██████████| 1250/1250 [02:14<00:00,  9.29it/s]


Epoch 20 - Avg Loss: 0.5955


Epoch 21: 100%|██████████| 1250/1250 [02:14<00:00,  9.29it/s]


Epoch 21 - Avg Loss: 0.5625


Epoch 22: 100%|██████████| 1250/1250 [02:14<00:00,  9.29it/s]


Epoch 22 - Avg Loss: 0.5554


Epoch 23: 100%|██████████| 1250/1250 [02:14<00:00,  9.29it/s]


Epoch 23 - Avg Loss: 0.5267


Epoch 24: 100%|██████████| 1250/1250 [02:14<00:00,  9.29it/s]


Epoch 24 - Avg Loss: 0.4931


Epoch 25: 100%|██████████| 1250/1250 [02:14<00:00,  9.29it/s]


Epoch 25 - Avg Loss: 0.5115


Epoch 26: 100%|██████████| 1250/1250 [02:14<00:00,  9.29it/s]


Epoch 26 - Avg Loss: 0.4578


Epoch 27: 100%|██████████| 1250/1250 [02:14<00:00,  9.29it/s]


Epoch 27 - Avg Loss: 0.4635


Epoch 28: 100%|██████████| 1250/1250 [02:14<00:00,  9.29it/s]


Epoch 28 - Avg Loss: 0.4405


Epoch 29: 100%|██████████| 1250/1250 [02:14<00:00,  9.29it/s]


Epoch 29 - Avg Loss: 0.4231


Epoch 30: 100%|██████████| 1250/1250 [02:14<00:00,  9.29it/s]


Epoch 30 - Avg Loss: nan


Epoch 31: 100%|██████████| 1250/1250 [02:14<00:00,  9.29it/s]


Epoch 31 - Avg Loss: 0.3982


Epoch 32: 100%|██████████| 1250/1250 [02:14<00:00,  9.29it/s]


Epoch 32 - Avg Loss: 0.4161


Epoch 33: 100%|██████████| 1250/1250 [02:14<00:00,  9.29it/s]


Epoch 33 - Avg Loss: 0.3930


Epoch 34: 100%|██████████| 1250/1250 [02:14<00:00,  9.29it/s]


Epoch 34 - Avg Loss: 0.3790


Epoch 35: 100%|██████████| 1250/1250 [02:14<00:00,  9.29it/s]


Epoch 35 - Avg Loss: 0.3695


Epoch 36: 100%|██████████| 1250/1250 [02:14<00:00,  9.29it/s]


Epoch 36 - Avg Loss: 0.3486


Epoch 37: 100%|██████████| 1250/1250 [02:14<00:00,  9.29it/s]


Epoch 37 - Avg Loss: 0.3866


Epoch 38: 100%|██████████| 1250/1250 [02:14<00:00,  9.29it/s]


Epoch 38 - Avg Loss: 0.3426


Epoch 39: 100%|██████████| 1250/1250 [02:14<00:00,  9.29it/s]


Epoch 39 - Avg Loss: nan


Epoch 40: 100%|██████████| 1250/1250 [02:14<00:00,  9.29it/s]


Epoch 40 - Avg Loss: 0.3356


Epoch 41: 100%|██████████| 1250/1250 [02:14<00:00,  9.29it/s]


Epoch 41 - Avg Loss: 0.3317


Epoch 42: 100%|██████████| 1250/1250 [02:14<00:00,  9.29it/s]


Epoch 42 - Avg Loss: 0.3275


Epoch 43: 100%|██████████| 1250/1250 [02:14<00:00,  9.29it/s]


Epoch 43 - Avg Loss: 0.3453


Epoch 44: 100%|██████████| 1250/1250 [02:14<00:00,  9.29it/s]


Epoch 44 - Avg Loss: 0.3174


Epoch 45: 100%|██████████| 1250/1250 [02:14<00:00,  9.29it/s]


Epoch 45 - Avg Loss: 0.3126


Epoch 46: 100%|██████████| 1250/1250 [02:14<00:00,  9.29it/s]


Epoch 46 - Avg Loss: 0.3102


Epoch 47: 100%|██████████| 1250/1250 [02:14<00:00,  9.29it/s]


Epoch 47 - Avg Loss: 0.3164


Epoch 48: 100%|██████████| 1250/1250 [02:14<00:00,  9.29it/s]


Epoch 48 - Avg Loss: 0.2965


Epoch 49: 100%|██████████| 1250/1250 [02:14<00:00,  9.29it/s]


Epoch 49 - Avg Loss: 0.3081


Epoch 50: 100%|██████████| 1250/1250 [02:14<00:00,  9.29it/s]


Epoch 50 - Avg Loss: 0.2842


Epoch 51: 100%|██████████| 1250/1250 [02:14<00:00,  9.29it/s]


Epoch 51 - Avg Loss: 0.2936


Epoch 52: 100%|██████████| 1250/1250 [02:14<00:00,  9.29it/s]


Epoch 52 - Avg Loss: 0.2937


Epoch 53: 100%|██████████| 1250/1250 [02:14<00:00,  9.29it/s]


Epoch 53 - Avg Loss: 0.2806


Epoch 54: 100%|██████████| 1250/1250 [02:14<00:00,  9.29it/s]


Epoch 54 - Avg Loss: 0.2756


Epoch 55: 100%|██████████| 1250/1250 [02:14<00:00,  9.28it/s]


Epoch 55 - Avg Loss: 0.2777


Epoch 56: 100%|██████████| 1250/1250 [02:14<00:00,  9.29it/s]


Epoch 56 - Avg Loss: 0.2650


Epoch 57: 100%|██████████| 1250/1250 [02:14<00:00,  9.29it/s]


Epoch 57 - Avg Loss: 0.2779


Epoch 58: 100%|██████████| 1250/1250 [02:14<00:00,  9.29it/s]


Epoch 58 - Avg Loss: 0.2768


Epoch 59: 100%|██████████| 1250/1250 [02:14<00:00,  9.29it/s]


Epoch 59 - Avg Loss: 0.2572


Epoch 60: 100%|██████████| 1250/1250 [02:14<00:00,  9.29it/s]


Epoch 60 - Avg Loss: 0.2685


Epoch 61: 100%|██████████| 1250/1250 [02:14<00:00,  9.28it/s]


Epoch 61 - Avg Loss: 0.2695


Epoch 62: 100%|██████████| 1250/1250 [02:14<00:00,  9.27it/s]


Epoch 62 - Avg Loss: 0.2614


Epoch 63: 100%|██████████| 1250/1250 [02:14<00:00,  9.27it/s]


Epoch 63 - Avg Loss: 0.2624


Epoch 64: 100%|██████████| 1250/1250 [02:14<00:00,  9.27it/s]


Epoch 64 - Avg Loss: 0.2526


Epoch 65: 100%|██████████| 1250/1250 [02:14<00:00,  9.27it/s]


Epoch 65 - Avg Loss: 0.2542


Epoch 66: 100%|██████████| 1250/1250 [02:14<00:00,  9.26it/s]


Epoch 66 - Avg Loss: 0.2390


Epoch 67: 100%|██████████| 1250/1250 [02:14<00:00,  9.27it/s]


Epoch 67 - Avg Loss: 0.2581


Epoch 68: 100%|██████████| 1250/1250 [02:14<00:00,  9.27it/s]


Epoch 68 - Avg Loss: 0.2558


Epoch 69: 100%|██████████| 1250/1250 [02:14<00:00,  9.27it/s]


Epoch 69 - Avg Loss: 0.2565


Epoch 70: 100%|██████████| 1250/1250 [02:14<00:00,  9.27it/s]


Epoch 70 - Avg Loss: 0.2253


Epoch 71: 100%|██████████| 1250/1250 [02:14<00:00,  9.27it/s]


Epoch 71 - Avg Loss: 0.2368


Epoch 72: 100%|██████████| 1250/1250 [02:14<00:00,  9.27it/s]


Epoch 72 - Avg Loss: 0.2596


Epoch 73: 100%|██████████| 1250/1250 [02:14<00:00,  9.27it/s]


Epoch 73 - Avg Loss: 0.2328


Epoch 74: 100%|██████████| 1250/1250 [02:14<00:00,  9.26it/s]


Epoch 74 - Avg Loss: 0.2433


Epoch 75: 100%|██████████| 1250/1250 [02:14<00:00,  9.27it/s]


Epoch 75 - Avg Loss: 0.2309


Epoch 76: 100%|██████████| 1250/1250 [02:14<00:00,  9.27it/s]


Epoch 76 - Avg Loss: 0.2334


Epoch 77: 100%|██████████| 1250/1250 [02:14<00:00,  9.27it/s]


Epoch 77 - Avg Loss: 0.2276


Epoch 78: 100%|██████████| 1250/1250 [02:14<00:00,  9.27it/s]


Epoch 78 - Avg Loss: 0.2202


Epoch 79: 100%|██████████| 1250/1250 [02:14<00:00,  9.27it/s]


Epoch 79 - Avg Loss: 0.2347


Epoch 80: 100%|██████████| 1250/1250 [02:14<00:00,  9.27it/s]


Epoch 80 - Avg Loss: nan


Epoch 81: 100%|██████████| 1250/1250 [02:14<00:00,  9.27it/s]


Epoch 81 - Avg Loss: 0.2413


Epoch 82: 100%|██████████| 1250/1250 [02:14<00:00,  9.27it/s]


Epoch 82 - Avg Loss: 0.2082


Epoch 83: 100%|██████████| 1250/1250 [02:14<00:00,  9.27it/s]


Epoch 83 - Avg Loss: 0.2238


Epoch 84: 100%|██████████| 1250/1250 [02:14<00:00,  9.28it/s]


Epoch 84 - Avg Loss: 0.2087


Epoch 85: 100%|██████████| 1250/1250 [02:14<00:00,  9.27it/s]


Epoch 85 - Avg Loss: 0.2236


Epoch 86: 100%|██████████| 1250/1250 [02:14<00:00,  9.27it/s]


Epoch 86 - Avg Loss: 0.2228


Epoch 87: 100%|██████████| 1250/1250 [02:14<00:00,  9.28it/s]


Epoch 87 - Avg Loss: 0.2307


Epoch 88: 100%|██████████| 1250/1250 [02:14<00:00,  9.27it/s]


Epoch 88 - Avg Loss: 0.2067


Epoch 89: 100%|██████████| 1250/1250 [02:14<00:00,  9.28it/s]


Epoch 89 - Avg Loss: 0.2122


Epoch 90: 100%|██████████| 1250/1250 [02:14<00:00,  9.28it/s]


Epoch 90 - Avg Loss: 0.2019


Epoch 91: 100%|██████████| 1250/1250 [02:14<00:00,  9.29it/s]


Epoch 91 - Avg Loss: 0.2039


Epoch 92: 100%|██████████| 1250/1250 [02:14<00:00,  9.28it/s]


Epoch 92 - Avg Loss: 0.2045


Epoch 93: 100%|██████████| 1250/1250 [02:14<00:00,  9.28it/s]


Epoch 93 - Avg Loss: 0.2078


Epoch 94: 100%|██████████| 1250/1250 [02:14<00:00,  9.28it/s]


Epoch 94 - Avg Loss: 0.2215


Epoch 95: 100%|██████████| 1250/1250 [02:14<00:00,  9.28it/s]


Epoch 95 - Avg Loss: 0.2069


Epoch 96: 100%|██████████| 1250/1250 [02:14<00:00,  9.29it/s]


Epoch 96 - Avg Loss: 0.1999


Epoch 97: 100%|██████████| 1250/1250 [02:14<00:00,  9.28it/s]


Epoch 97 - Avg Loss: nan


Epoch 98: 100%|██████████| 1250/1250 [02:14<00:00,  9.27it/s]


Epoch 98 - Avg Loss: 0.1961


Epoch 99: 100%|██████████| 1250/1250 [02:14<00:00,  9.28it/s]


Epoch 99 - Avg Loss: 0.2107


Epoch 100: 100%|██████████| 1250/1250 [02:14<00:00,  9.27it/s]

Epoch 100 - Avg Loss: 0.2155





## Saving the Model

In [21]:
# Save the fine-tuned model if needed
model.save_pretrained('local_fine_tuned_roberta_on_newsqa')
tokenizer.save_pretrained('local_fine_tuned_roberta_on_newsqa')

('local_fine_tuned_roberta_on_newsqa/tokenizer_config.json',
 'local_fine_tuned_roberta_on_newsqa/special_tokens_map.json',
 'local_fine_tuned_roberta_on_newsqa/vocab.json',
 'local_fine_tuned_roberta_on_newsqa/merges.txt',
 'local_fine_tuned_roberta_on_newsqa/added_tokens.json',
 'local_fine_tuned_roberta_on_newsqa/tokenizer.json')

In [22]:
 # Initialize the tokenizer and model
fine_tuned_tokenizer = AutoTokenizer.from_pretrained('local_fine_tuned_roberta_on_newsqa')
fine_tuned_model = AutoModelForQuestionAnswering.from_pretrained('local_fine_tuned_roberta_on_newsqa')

In [23]:
fine_tuned_model.to(device)

RobertaForQuestionAnswering(
  (roberta): RobertaModel(
    (embeddings): RobertaEmbeddings(
      (word_embeddings): Embedding(50265, 768, padding_idx=1)
      (position_embeddings): Embedding(514, 768, padding_idx=1)
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): RobertaEncoder(
      (layer): ModuleList(
        (0-11): 12 x RobertaLayer(
          (attention): RobertaAttention(
            (self): RobertaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): RobertaSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (Lay

# Inference

In [24]:
# Perform inference
question = "What war was the Iwo Jima battle a part of?"
context = "One of the Marines shown in a famous World War II photograph raising the U.S. flag on Iwo Jima was posthumously awarded a certificate of U.S. citizenship on Tuesday.\n\nThe Marine Corps War Memorial in Virginia depicts Strank and five others raising a flag on Iwo Jima.\n\nSgt. Michael Strank, who was born in Czechoslovakia and came to the United States when he was 3, derived U.S. citizenship when his father was naturalized in 1935. However, U.S. Citizenship and Immigration Services recently discovered that Strank never was given citizenship papers.\n\nAt a ceremony Tuesday at the Marine Corps Memorial -- which depicts the flag-raising -- in Arlington, Virginia, a certificate of citizenship was presented to Strank\'s younger sister, Mary Pero.\n\nStrank and five other men became national icons when an Associated Press photographer captured the image of them planting an American flag on top of Mount Suribachi on February 23, 1945.\n\nStrank was killed in action on the island on March 1, 1945, less than a month before the battle between Japanese and U.S. forces there ended.\n\nJonathan Scharfen, the acting director of CIS, presented the citizenship certificate Tuesday.\n\nHe hailed Strank as a true American hero and a wonderful example of the remarkable contribution and sacrifices that immigrants have made to our great republic throughout its history."

In [25]:
# Tokenize the passage and question
inputs = tokenizer(question, context, return_tensors="pt")
inputs.to(device)

# Perform inference
with torch.no_grad():
    outputs = fine_tuned_model(**inputs)
    start_idx = torch.argmax(outputs[0])
    end_idx = torch.argmax(outputs[1]) + 1

# Get the answer text from the passage
answer = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(inputs["input_ids"][0][start_idx:end_idx]))

print("Question:", question)
print("Answer:", answer)

Question: What war was the Iwo Jima battle a part of?
Answer:  II


In [26]:
def get_prediction(context, question):
  inputs = tokenizer.encode_plus(question, context, return_tensors='pt').to(device)
  outputs = model(**inputs)

  answer_start = torch.argmax(outputs[0])
  answer_end = torch.argmax(outputs[1]) + 1

  answer = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(inputs['input_ids'][0][answer_start:answer_end]))

  return answer

def normalize_text(s):
  """Removing articles and punctuation, and standardizing whitespace are all typical text processing steps."""
  import string, re
  def remove_articles(text):
    regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
    return re.sub(regex, " ", text)
  def white_space_fix(text):
    return " ".join(text.split())
  def remove_punc(text):
    exclude = set(string.punctuation)
    return "".join(ch for ch in text if ch not in exclude)
  def lower(text):
    return text.lower()

  return white_space_fix(remove_articles(remove_punc(lower(s))))

def exact_match(prediction, truth):
    return bool(normalize_text(prediction) == normalize_text(truth))

def compute_f1(prediction, truth):
  pred_tokens = normalize_text(prediction).split()
  truth_tokens = normalize_text(truth).split()

  # if either the prediction or the truth is no-answer then f1 = 1 if they agree, 0 otherwise
  if len(pred_tokens) == 0 or len(truth_tokens) == 0:
    return int(pred_tokens == truth_tokens)

  common_tokens = set(pred_tokens) & set(truth_tokens)

  # if there are no common tokens then f1 = 0
  if len(common_tokens) == 0:
    return 0

  prec = len(common_tokens) / len(pred_tokens)
  rec = len(common_tokens) / len(truth_tokens)

  return round(2 * (prec * rec) / (prec + rec), 2)

def question_answer(context, question,answer):
  prediction = get_prediction(context,question)
  em_score = exact_match(prediction, answer)
  f1_score = compute_f1(prediction, answer)

  print(f'Question: {question}')
  print(f'Prediction: {prediction}')
  print(f'True Answer: {answer}')
  print(f'Exact match: {em_score}')
  print(f'F1 score: {f1_score}\n')
    
  return f1_score

In [27]:
f1=0
for contexts, question, answer in zip(valid_contexts[:], valid_questions[:], valid_str_ans[:]):
    f1 += question_answer(context, question, answer)
avg_f1_score=f1/1000

Question: What will be nominated?
Prediction:  Strank
True Answer: three different videos
Exact match: False
F1 score: 0

Question: What does the Harrison Ford video feature?
Prediction: 
True Answer: getting his chest waxed,
Exact match: False
F1 score: 0

Question: What videos will you send?
Prediction:  videos
True Answer: environmental
Exact match: False
F1 score: 0

Question: What is Ford getting waxed?
Prediction: 
True Answer: his chest
Exact match: False
F1 score: 0

Question: Who got his chest waxed?
Prediction: One of the Marines
True Answer: Harrison Ford
Exact match: False
F1 score: 0

Question: How do you send in your video?
Prediction:  World War II photograph raising the U.S. flag
True Answer: Use the iReport form
Exact match: False
F1 score: 0

Question: What type of videos should you nominate?
Prediction:  videos
True Answer: think are the best.
Exact match: False
F1 score: 0

Question: What did Steve Bruce describe Amire Zaki as?
Prediction:  true American hero and
Tr

In [28]:
print(f"Average F1 score={avg_f1_score}")

Average F1 score=0.009529999999999995
