In [None]:
'''extractive QA
BERT - squad2.0'''

In [1]:
!pip install transformers -q

In [2]:
from tqdm import tqdm

from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
import os

# Create the Drive folder that will hold the fine-tuned model.
# makedirs(..., exist_ok=True) also creates a missing parent ("Projects") and
# is a no-op when the folder already exists, so the cell is safe to re-run.
os.makedirs('/content/drive/MyDrive/Projects/BERT QA', exist_ok=True)

In [4]:
!wget -nc https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json
!wget -nc https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json

--2024-02-13 01:03:43--  https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json
Resolving rajpurkar.github.io (rajpurkar.github.io)... 185.199.109.153, 185.199.110.153, 185.199.111.153, ...
Connecting to rajpurkar.github.io (rajpurkar.github.io)|185.199.109.153|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 42123633 (40M) [application/json]
Saving to: ‘train-v2.0.json’


2024-02-13 01:03:46 (366 MB/s) - ‘train-v2.0.json’ saved [42123633/42123633]

--2024-02-13 01:03:46--  https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json
Resolving rajpurkar.github.io (rajpurkar.github.io)... 185.199.109.153, 185.199.110.153, 185.199.111.153, ...
Connecting to rajpurkar.github.io (rajpurkar.github.io)|185.199.109.153|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 4370528 (4.2M) [application/json]
Saving to: ‘dev-v2.0.json’


2024-02-13 01:03:47 (283 MB/s) - ‘dev-v2.0.json’ saved [4370528/4370528]



In [5]:
import torch
import json
import requests
from torch.utils.data import DataLoader
from transformers import BertTokenizerFast, BertForQuestionAnswering


In [6]:
# Load the raw training JSON into `squad` so its structure can be inspected
# interactively. read_data() below opens the file itself and does not use
# this object.
with open('train-v2.0.json', 'rb') as f:
    squad = json.load(f)

In [None]:
#squad['data'][0].keys() -> dict_keys(['title', 'paragraphs']), squad['data'][0], squad['data'][0]['paragraphs'][0]['context']

In [7]:
def read_data(path):
    """
    Flatten a SQuAD-format JSON file into three parallel lists.

    One row is produced per gold answer: a question with several answers
    contributes several rows (with the context and question repeated), and a
    question whose 'answers' list is empty contributes none.

    Parameters:
    - path: Path to the JSON file containing SQuAD data

    Returns:
    - contexts: List of contexts (passages), one per answer
    - questions: List of questions, one per answer
    - answers: List of answer dicts exactly as stored in the file
    """
    with open(path, 'r', encoding='utf-8') as handle:
        payload = json.load(handle)

    contexts, questions, answers = [], [], []

    for article in payload.get('data', []):
        for paragraph in article.get('paragraphs', []):
            passage_text = paragraph.get('context', '')
            for qa_entry in paragraph.get('qas', []):
                question_text = qa_entry.get('question', '')
                gold_answers = qa_entry.get('answers', [])
                # Repeat the passage/question once per gold answer so the
                # three lists stay index-aligned.
                contexts.extend([passage_text] * len(gold_answers))
                questions.extend([question_text] * len(gold_answers))
                answers.extend(gold_answers)

    return contexts, questions, answers

# Read training data (one row per gold answer; questions whose 'answers'
# list is empty produce no rows)
train_contexts, train_questions, train_answers = read_data('train-v2.0.json')
# Read validation data (same flattening, so a question with several gold
# answers appears several times)
valid_contexts, valid_questions, valid_answers = read_data('dev-v2.0.json')

In [8]:
def add_end_index(answers, contexts):
    """
    Add an 'answer_end' character index to every answer dict (in place).

    The stored 'answer_start' may be off by one or two characters, so the
    span is tried at offsets 0, -1 and -2 and the first one whose context
    slice equals the gold text wins; 'answer_start' is corrected accordingly.

    Parameters:
    - answers: List of answer dicts with 'text' and 'answer_start' keys
    - contexts: List of context strings, index-aligned with `answers`

    Returns:
    None (modifies the answer dicts in place)
    """
    for answer, context in zip(answers, contexts):
        gold_text = answer['text']
        start_idx = answer['answer_start']
        end_idx = start_idx + len(gold_text)

        for offset in (0, -1, -2):
            shifted_start = start_idx + offset
            # A negative start would wrap around to the end of the context
            # via Python slicing and could "match" at a bogus position.
            if shifted_start < 0:
                continue
            if context[shifted_start:end_idx + offset] == gold_text:
                answer['answer_start'] = shifted_start
                answer['answer_end'] = end_idx + offset
                break
        else:
            # No offset matched: keep the annotated start and still record an
            # end index so downstream code that reads answer['answer_end']
            # does not fail with KeyError.
            answer['answer_end'] = end_idx

# Mutates the answer dicts in place: matched answers gain 'answer_end' and
# may have 'answer_start' shifted by up to two characters.
add_end_index(train_answers, train_contexts)
add_end_index(valid_answers, valid_contexts)

In [9]:
# Initialize tokenizer
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
# Encode each pair with the context as the FIRST sequence and the question as
# the second. add_token_positions() later maps answer character offsets with
# char_to_token(i, char_idx), which presumably resolves against the first
# sequence by default — TODO confirm before changing the argument order here.
train_encodings = tokenizer(train_contexts, train_questions, truncation=True, padding=True)
valid_encodings = tokenizer(valid_contexts, valid_questions, truncation=True, padding=True)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

In [None]:
# # Initialize tokenizer
# tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')

# # Check if saved encodings exist
# if os.path.exists("train_encodings.pt") and os.path.exists("valid_encodings.pt"):
#     # Load saved encodings
#     train_encodings = torch.load("train_encodings.pt")
#     valid_encodings = torch.load("valid_encodings.pt")
# else:
#     # Generate encodings and save them
#     train_encodings = tokenizer(train_contexts, train_questions, truncation=True, padding=True)
#     valid_encodings = tokenizer(valid_contexts, valid_questions, truncation=True, padding=True)
#     torch.save(train_encodings, "train_encodings.pt")
#     torch.save(valid_encodings, "valid_encodings.pt")

In [None]:
'''
train_encodings.keys()
#dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])
no_of_encodings = len(train_encodings['input_ids'])
print(f'{no_of_encodings} context-question pairs count')
tokenizer.decode(train_encodings['input_ids'][0])
'''

In [11]:
def add_token_positions(encodings, answers):
    """
    Attach token-level answer spans to a batch of encodings.

    Each answer's character span ('answer_start' .. 'answer_end' - 1) is mapped
    to token indices with encodings.char_to_token. Whenever that lookup yields
    None (the case the notebook attributes to the answer having been truncated
    away), tokenizer.model_max_length is stored instead.

    Parameters:
    - encodings: Encodings object containing tokenized inputs
    - answers: List of dictionaries containing answer positions,
      index-aligned with the encodings

    Returns:
    None (adds 'start_positions' and 'end_positions' to encodings in place)
    """
    starts, ends = [], []

    for row, span in enumerate(answers):
        first_token = encodings.char_to_token(row, span['answer_start'])
        # 'answer_end' is exclusive, so the last answer character is one before it
        last_token = encodings.char_to_token(row, span['answer_end'] - 1)

        starts.append(tokenizer.model_max_length if first_token is None else first_token)
        ends.append(tokenizer.model_max_length if last_token is None else last_token)

    encodings.update({'start_positions': starts, 'end_positions': ends})

# Add token positions for training data (updates train_encodings in place
# with 'start_positions' / 'end_positions')
add_token_positions(train_encodings, train_answers)
# Add token positions for validation data (same in-place update)
add_token_positions(valid_encodings, valid_answers)

In [12]:
class SQuAD_Dataset(torch.utils.data.Dataset):
    """
    Thin torch Dataset over a tokenized SQuAD encodings object.

    Parameters:
    - encodings: Encodings object containing tokenized inputs; it must offer
      .items() yielding (field name, per-example sequence) pairs and an
      .input_ids attribute
    """
    def __init__(self, encodings):
        self.encodings = encodings

    def __getitem__(self, idx):
        """
        Build one training example.

        Parameters:
        - idx: Index of the item to retrieve

        Returns:
        Dictionary mapping every field in the encodings to a tensor holding
        that field's value for example `idx`
        """
        sample = {}
        for field, column in self.encodings.items():
            sample[field] = torch.tensor(column[idx])
        return sample

    def __len__(self):
        """
        Number of examples, taken from the length of encodings.input_ids.

        Returns:
        Integer representing the length of the dataset
        """
        input_ids = self.encodings.input_ids
        return len(input_ids)

# Create training dataset (wraps train_encodings; tensors are built lazily
# per item in __getitem__)
train_dataset = SQuAD_Dataset(train_encodings)
# Create validation dataset
valid_dataset = SQuAD_Dataset(valid_encodings)

In [13]:
# Define the dataloaders: both use batches of 16; only the training loader
# is shuffled, so validation batches come out in dataset order.
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=16)

In [14]:
# Load the pretrained BERT encoder with a question-answering head. The warning
# printed below reports that qa_outputs.weight / qa_outputs.bias are newly
# initialized, i.e. the head must be trained before predictions are meaningful.
model = BertForQuestionAnswering.from_pretrained("bert-base-uncased")

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [15]:
# Prefer the GPU when CUDA is available; otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Report which device was selected
print(f'Working on {device}')

Working on cuda


In [None]:
# Number of passes over the training set (values in the 3-9 range were considered)
N_EPOCHS = 3

# AdamW over every model parameter
optim = torch.optim.AdamW(model.parameters(), lr=5e-5)

# Put the model on the selected device and switch it to training mode
model.to(device)
model.train()

# Batch fields fed to the model, in the order they are moved to the device
MODEL_INPUT_KEYS = ('input_ids', 'attention_mask', 'start_positions', 'end_positions')

for epoch in range(N_EPOCHS):
    # Progress bar over this epoch's batches
    loop = tqdm(train_loader, leave=True)
    for batch in loop:
        # Clear gradients accumulated by the previous step
        optim.zero_grad()

        # Move only the fields the model is given onto the device
        model_inputs = {key: batch[key].to(device) for key in MODEL_INPUT_KEYS}

        # Supplying start/end positions makes the model return the loss as
        # the first element of its output
        outputs = model(**model_inputs)
        loss = outputs[0]

        # Backpropagate and apply the parameter update
        loss.backward()
        optim.step()

        # Show the current epoch and the latest batch loss on the progress bar
        loop.set_description(f'Epoch {epoch+1}')
        loop.set_postfix(loss=loss.item())

# Folder on Drive that receives the fine-tuned model and its tokenizer
model_path = '/content/drive/MyDrive/Projects/BERT QA'

# Persist the model weights/configuration, then the tokenizer files, to that folder
model.save_pretrained(model_path)
tokenizer.save_pretrained(model_path)

In [None]:
# # Define the path where the pre-trained model and tokenizer are saved
# model_path = '/content/drive/MyDrive/Projects/BERT QA'

# # Load the pre-trained BERT model from the specified path
# model = BertForQuestionAnswering.from_pretrained(model_path)

# # Load the tokenizer from the specified path
# tokenizer = BertTokenizerFast.from_pretrained(model_path)

# # Check the available device and use GPU if available, otherwise use CPU
# device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# # Move the model to the appropriate device
# model = model.to(device)

# # Print the device being used
# print(f'Working on {device}')

In [None]:
# Set the model to evaluation mode
model.eval()

# Running counts of exactly-matching start/end positions and of all positions
# scored. Counting (instead of averaging per-batch means) keeps a short final
# batch from being over-weighted.
correct = 0
total = 0

# No gradients are needed for evaluation
with torch.no_grad():
    # Iterate over batches in the validation data
    for batch in tqdm(valid_loader):
        # Move input tensors to the appropriate device
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        start_true = batch['start_positions'].to(device)
        end_true = batch['end_positions'].to(device)

        # Forward pass through the model (same inputs as used during training)
        outputs = model(input_ids, attention_mask=attention_mask)

        # Most likely start and end token per example
        start_pred = torch.argmax(outputs['start_logits'], dim=1)
        end_pred = torch.argmax(outputs['end_logits'], dim=1)

        # Start and end positions are scored independently
        correct += (start_pred == start_true).sum().item()
        correct += (end_pred == end_true).sum().item()
        total += 2 * len(start_pred)

# Average position accuracy over the whole validation set (0.0 if it is empty)
acc = correct / total if total else 0.0
print(f"\nValidation start/end position accuracy: {acc:.4f}")

# Print the header for true and predicted answer positions
print("\n\nT/P\tanswer_start\tanswer_end\n")

# Show true vs. predicted positions for the LAST validation batch only,
# as plain integers rather than tensor reprs
if total:
    rows = zip(start_true.tolist(), end_true.tolist(),
               start_pred.tolist(), end_pred.tolist())
    for true_start, true_end, pred_start, pred_end in rows:
        print(f"true\t{true_start}\t{true_end}\n"
              f"pred\t{pred_start}\t{pred_end}\n")

In [None]:
def get_prediction(context, question):
    """
    Get the predicted answer for a given context and question.

    The inputs are prepared the same way as for training in this notebook:
    the context is the first sequence and the question the second, and only
    input_ids and attention_mask are passed to the model.

    Parameters:
    - context: The context in which the question is asked
    - question: The question to be answered

    Returns:
    - answer: The predicted answer to the question ('' when the predicted end
      token precedes the predicted start token)
    """
    # Context first, question second - the order used to build the training
    # encodings. truncation=True keeps over-long inputs within the model limit
    # instead of failing inside the model.
    inputs = tokenizer(context, question, truncation=True, return_tensors='pt').to(device)

    # Inference only: skip gradient tracking
    with torch.no_grad():
        outputs = model(inputs['input_ids'], attention_mask=inputs['attention_mask'])

    # Most likely start token and (exclusive) end token for the single example
    answer_start = int(torch.argmax(outputs.start_logits[0]))
    answer_end = int(torch.argmax(outputs.end_logits[0])) + 1

    # An inverted span selects no tokens
    if answer_end <= answer_start:
        return ''

    # Convert the selected token IDs back to a string
    span_ids = inputs['input_ids'][0][answer_start:answer_end].tolist()
    answer = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(span_ids))

    return answer

def normalize_text(s):
    """
    Canonicalize an answer string for comparison.

    Steps, in order: lowercase, strip ASCII punctuation, blank out the
    standalone words "a", "an" and "the", then collapse all whitespace runs
    to single spaces (also trimming both ends).

    Parameters:
    - s: Input text to be normalized

    Returns:
    - Normalized text
    """
    import string, re

    lowered = s.lower()

    # Delete every character listed in string.punctuation
    without_punctuation = lowered.translate(str.maketrans('', '', string.punctuation))

    # Replace whole-word articles with a space; the whitespace pass below
    # cleans up the gaps this leaves behind
    without_articles = re.sub(r"\b(a|an|the)\b", " ", without_punctuation, flags=re.UNICODE)

    return " ".join(without_articles.split())

def exact_match(prediction, truth):
    """
    Check whether two answers are identical after normalization.

    Parameters:
    - prediction: Predicted answer
    - truth: True answer

    Returns:
    - Boolean indicating whether the prediction exactly matches the truth
    """
    normalized_prediction = normalize_text(prediction)
    normalized_truth = normalize_text(truth)
    return normalized_prediction == normalized_truth

def compute_f1(prediction, truth):
    """
    Compute the token-level F1 score between predicted answer and true answer.

    Token overlap is counted as a multiset (each token counts as many times as
    it appears in both answers), as in the official SQuAD v2.0 evaluation
    script. Counting distinct tokens only would under-score answers with
    repeated words, e.g. "cat cat" vs "cat cat" would get 0.5 instead of 1.0.

    Parameters:
    - prediction: Predicted answer
    - truth: True answer

    Returns:
    - F1 score rounded to two decimals (the integer 0 or 1 when either answer
      normalizes to nothing, and 0 when there is no overlap)
    """
    from collections import Counter

    pred_tokens = normalize_text(prediction).split()
    truth_tokens = normalize_text(truth).split()

    # If either the prediction or the truth is no-answer then F1 score is 1 if they agree, 0 otherwise
    if len(pred_tokens) == 0 or len(truth_tokens) == 0:
        return int(pred_tokens == truth_tokens)

    # Multiset intersection: per-token minimum of the two counts
    overlap = Counter(pred_tokens) & Counter(truth_tokens)
    num_same = sum(overlap.values())

    # If there are no common tokens then F1 score is 0
    if num_same == 0:
        return 0

    prec = num_same / len(pred_tokens)
    rec = num_same / len(truth_tokens)

    return round(2 * (prec * rec) / (prec + rec), 2)

def question_answer(context, question, answer):
    """
    Run the model on one question and print how its answer compares to the
    expected one.

    Parameters:
    - context: The context in which the question is asked
    - question: The question to be answered
    - answer: The true answer to the question

    Returns:
    None (prints the question, prediction, true answer, exact-match flag and
    F1 score, followed by a blank line)
    """
    prediction = get_prediction(context, question)
    em_score = exact_match(prediction, answer)
    f1_score = compute_f1(prediction, answer)

    report_lines = [
        f'Question: {question}',
        f'Prediction: {prediction}',
        f'True Answer: {answer}',
        f'Exact match: {em_score}',
        f'F1 score: {f1_score}\n',
    ]
    for line in report_lines:
        print(line)

In [None]:
# Short passage used for a quick manual check of the fine-tuned model
context = """Albert Einstein was a German-born theoretical physicist who developed the theory of relativity, one of the two pillars of modern physics (alongside quantum mechanics)."""

# (question, expected answer) pairs about the passage above
qa_pairs = [
    ("What did Albert Einstein develop?", "theory of relativity"),
    ("Where was Albert Einstein born?", "german"),
]

questions = [pair[0] for pair in qa_pairs]
answers = [pair[1] for pair in qa_pairs]

# Print prediction, exact match and F1 for each pair
for question, answer in qa_pairs:
    question_answer(context, question, answer)

Question: What did Albert Einstein develop?
Prediction: theory of relativity
True Answer: theory of relativity
Exact match: True
F1 score: 1.0

Question: Where was Albert Einstein born?
Prediction: german
True Answer: german
Exact match: True
F1 score: 1.0



In [None]:
# https://worksheets.codalab.org/rest/bundles/0x6b567e1cf2e041ec80d7098f031c5c9e/contents/blob/
"""Official evaluation script for SQuAD version 2.0.

In addition to basic functionality, we also compute additional statistics and
plot precision-recall curves if an additional na_prob.json file is provided.
This file is expected to map question ID's to the model's predicted probability
that a question is unanswerable.
"""
'''
import argparse
import collections
import json
import numpy as np
import os
import re
import string
import sys

OPTS = None

def parse_args():
  parser = argparse.ArgumentParser('Official evaluation script for SQuAD version 2.0.')
  parser.add_argument('data_file', metavar='data.json', help='Input data JSON file.')
  parser.add_argument('pred_file', metavar='pred.json', help='Model predictions.')
  parser.add_argument('--out-file', '-o', metavar='eval.json',
                      help='Write accuracy metrics to file (default is stdout).')
  parser.add_argument('--na-prob-file', '-n', metavar='na_prob.json',
                      help='Model estimates of probability of no answer.')
  parser.add_argument('--na-prob-thresh', '-t', type=float, default=1.0,
                      help='Predict "" if no-answer probability exceeds this (default = 1.0).')
  parser.add_argument('--out-image-dir', '-p', metavar='out_images', default=None,
                      help='Save precision-recall curves to directory.')
  parser.add_argument('--verbose', '-v', action='store_true')
  if len(sys.argv) == 1:
    parser.print_help()
    sys.exit(1)
  return parser.parse_args()

def make_qid_to_has_ans(dataset):
  qid_to_has_ans = {}
  for article in dataset:
    for p in article['paragraphs']:
      for qa in p['qas']:
        qid_to_has_ans[qa['id']] = bool(qa['answers'])
  return qid_to_has_ans

def normalize_answer(s):
  """Lower text and remove punctuation, articles and extra whitespace."""
  def remove_articles(text):
    regex = re.compile(r'\b(a|an|the)\b', re.UNICODE)
    return re.sub(regex, ' ', text)
  def white_space_fix(text):
    return ' '.join(text.split())
  def remove_punc(text):
    exclude = set(string.punctuation)
    return ''.join(ch for ch in text if ch not in exclude)
  def lower(text):
    return text.lower()
  return white_space_fix(remove_articles(remove_punc(lower(s))))

def get_tokens(s):
  if not s: return []
  return normalize_answer(s).split()

def compute_exact(a_gold, a_pred):
  return int(normalize_answer(a_gold) == normalize_answer(a_pred))

def compute_f1(a_gold, a_pred):
  gold_toks = get_tokens(a_gold)
  pred_toks = get_tokens(a_pred)
  common = collections.Counter(gold_toks) & collections.Counter(pred_toks)
  num_same = sum(common.values())
  if len(gold_toks) == 0 or len(pred_toks) == 0:
    # If either is no-answer, then F1 is 1 if they agree, 0 otherwise
    return int(gold_toks == pred_toks)
  if num_same == 0:
    return 0
  precision = 1.0 * num_same / len(pred_toks)
  recall = 1.0 * num_same / len(gold_toks)
  f1 = (2 * precision * recall) / (precision + recall)
  return f1

def get_raw_scores(dataset, preds):
  exact_scores = {}
  f1_scores = {}
  for article in dataset:
    for p in article['paragraphs']:
      for qa in p['qas']:
        qid = qa['id']
        gold_answers = [a['text'] for a in qa['answers']
                        if normalize_answer(a['text'])]
        if not gold_answers:
          # For unanswerable questions, only correct answer is empty string
          gold_answers = ['']
        if qid not in preds:
          print('Missing prediction for %s' % qid)
          continue
        a_pred = preds[qid]
        # Take max over all gold answers
        exact_scores[qid] = max(compute_exact(a, a_pred) for a in gold_answers)
        f1_scores[qid] = max(compute_f1(a, a_pred) for a in gold_answers)
  return exact_scores, f1_scores

def apply_no_ans_threshold(scores, na_probs, qid_to_has_ans, na_prob_thresh):
  new_scores = {}
  for qid, s in scores.items():
    pred_na = na_probs[qid] > na_prob_thresh
    if pred_na:
      new_scores[qid] = float(not qid_to_has_ans[qid])
    else:
      new_scores[qid] = s
  return new_scores

def make_eval_dict(exact_scores, f1_scores, qid_list=None):
  if not qid_list:
    total = len(exact_scores)
    return collections.OrderedDict([
        ('exact', 100.0 * sum(exact_scores.values()) / total),
        ('f1', 100.0 * sum(f1_scores.values()) / total),
        ('total', total),
    ])
  else:
    total = len(qid_list)
    return collections.OrderedDict([
        ('exact', 100.0 * sum(exact_scores[k] for k in qid_list) / total),
        ('f1', 100.0 * sum(f1_scores[k] for k in qid_list) / total),
        ('total', total),
    ])

def merge_eval(main_eval, new_eval, prefix):
  for k in new_eval:
    main_eval['%s_%s' % (prefix, k)] = new_eval[k]

def plot_pr_curve(precisions, recalls, out_image, title):
  plt.step(recalls, precisions, color='b', alpha=0.2, where='post')
  plt.fill_between(recalls, precisions, step='post', alpha=0.2, color='b')
  plt.xlabel('Recall')
  plt.ylabel('Precision')
  plt.xlim([0.0, 1.05])
  plt.ylim([0.0, 1.05])
  plt.title(title)
  plt.savefig(out_image)
  plt.clf()

def make_precision_recall_eval(scores, na_probs, num_true_pos, qid_to_has_ans,
                               out_image=None, title=None):
  qid_list = sorted(na_probs, key=lambda k: na_probs[k])
  true_pos = 0.0
  cur_p = 1.0
  cur_r = 0.0
  precisions = [1.0]
  recalls = [0.0]
  avg_prec = 0.0
  for i, qid in enumerate(qid_list):
    if qid_to_has_ans[qid]:
      true_pos += scores[qid]
    cur_p = true_pos / float(i+1)
    cur_r = true_pos / float(num_true_pos)
    if i == len(qid_list) - 1 or na_probs[qid] != na_probs[qid_list[i+1]]:
      # i.e., if we can put a threshold after this point
      avg_prec += cur_p * (cur_r - recalls[-1])
      precisions.append(cur_p)
      recalls.append(cur_r)
  if out_image:
    plot_pr_curve(precisions, recalls, out_image, title)
  return {'ap': 100.0 * avg_prec}

def run_precision_recall_analysis(main_eval, exact_raw, f1_raw, na_probs,
                                  qid_to_has_ans, out_image_dir):
  if out_image_dir and not os.path.exists(out_image_dir):
    os.makedirs(out_image_dir)
  num_true_pos = sum(1 for v in qid_to_has_ans.values() if v)
  if num_true_pos == 0:
    return
  pr_exact = make_precision_recall_eval(
      exact_raw, na_probs, num_true_pos, qid_to_has_ans,
      out_image=os.path.join(out_image_dir, 'pr_exact.png'),
      title='Precision-Recall curve for Exact Match score')
  pr_f1 = make_precision_recall_eval(
      f1_raw, na_probs, num_true_pos, qid_to_has_ans,
      out_image=os.path.join(out_image_dir, 'pr_f1.png'),
      title='Precision-Recall curve for F1 score')
  oracle_scores = {k: float(v) for k, v in qid_to_has_ans.items()}
  pr_oracle = make_precision_recall_eval(
      oracle_scores, na_probs, num_true_pos, qid_to_has_ans,
      out_image=os.path.join(out_image_dir, 'pr_oracle.png'),
      title='Oracle Precision-Recall curve (binary task of HasAns vs. NoAns)')
  merge_eval(main_eval, pr_exact, 'pr_exact')
  merge_eval(main_eval, pr_f1, 'pr_f1')
  merge_eval(main_eval, pr_oracle, 'pr_oracle')

def histogram_na_prob(na_probs, qid_list, image_dir, name):
  if not qid_list:
    return
  x = [na_probs[k] for k in qid_list]
  weights = np.ones_like(x) / float(len(x))
  plt.hist(x, weights=weights, bins=20, range=(0.0, 1.0))
  plt.xlabel('Model probability of no-answer')
  plt.ylabel('Proportion of dataset')
  plt.title('Histogram of no-answer probability: %s' % name)
  plt.savefig(os.path.join(image_dir, 'na_prob_hist_%s.png' % name))
  plt.clf()

def find_best_thresh(preds, scores, na_probs, qid_to_has_ans):
  num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k])
  cur_score = num_no_ans
  best_score = cur_score
  best_thresh = 0.0
  qid_list = sorted(na_probs, key=lambda k: na_probs[k])
  for i, qid in enumerate(qid_list):
    if qid not in scores: continue
    if qid_to_has_ans[qid]:
      diff = scores[qid]
    else:
      if preds[qid]:
        diff = -1
      else:
        diff = 0
    cur_score += diff
    if cur_score > best_score:
      best_score = cur_score
      best_thresh = na_probs[qid]
  return 100.0 * best_score / len(scores), best_thresh

def find_all_best_thresh(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans):
  best_exact, exact_thresh = find_best_thresh(preds, exact_raw, na_probs, qid_to_has_ans)
  best_f1, f1_thresh = find_best_thresh(preds, f1_raw, na_probs, qid_to_has_ans)
  main_eval['best_exact'] = best_exact
  main_eval['best_exact_thresh'] = exact_thresh
  main_eval['best_f1'] = best_f1
  main_eval['best_f1_thresh'] = f1_thresh

def main():
  with open(OPTS.data_file) as f:
    dataset_json = json.load(f)
    dataset = dataset_json['data']
  with open(OPTS.pred_file) as f:
    preds = json.load(f)
  if OPTS.na_prob_file:
    with open(OPTS.na_prob_file) as f:
      na_probs = json.load(f)
  else:
    na_probs = {k: 0.0 for k in preds}
  qid_to_has_ans = make_qid_to_has_ans(dataset)  # maps qid to True/False
  has_ans_qids = [k for k, v in qid_to_has_ans.items() if v]
  no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v]
  exact_raw, f1_raw = get_raw_scores(dataset, preds)
  exact_thresh = apply_no_ans_threshold(exact_raw, na_probs, qid_to_has_ans,
                                        OPTS.na_prob_thresh)
  f1_thresh = apply_no_ans_threshold(f1_raw, na_probs, qid_to_has_ans,
                                     OPTS.na_prob_thresh)
  out_eval = make_eval_dict(exact_thresh, f1_thresh)
  if has_ans_qids:
    has_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=has_ans_qids)
    merge_eval(out_eval, has_ans_eval, 'HasAns')
  if no_ans_qids:
    no_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=no_ans_qids)
    merge_eval(out_eval, no_ans_eval, 'NoAns')
  if OPTS.na_prob_file:
    find_all_best_thresh(out_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans)
  if OPTS.na_prob_file and OPTS.out_image_dir:
    run_precision_recall_analysis(out_eval, exact_raw, f1_raw, na_probs,
                                  qid_to_has_ans, OPTS.out_image_dir)
    histogram_na_prob(na_probs, has_ans_qids, OPTS.out_image_dir, 'hasAns')
    histogram_na_prob(na_probs, no_ans_qids, OPTS.out_image_dir, 'noAns')
  if OPTS.out_file:
    with open(OPTS.out_file, 'w') as f:
      json.dump(out_eval, f)
  else:
    print(json.dumps(out_eval, indent=2))

if __name__ == '__main__':
  OPTS = parse_args()
  if OPTS.out_image_dir:
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt
  main()
'''

In [None]:
#!pip install datasets -q
#! pip install -U accelerate
#! pip install -U transformers

In [None]:
# # Print the structure of the loaded dataset
# print(squad["train"].features)
'''
{'id': Value(dtype='string', id=None), 'title': Value(dtype='string', id=None), 'context': Value(dtype='string', id=None), 'question': Value(dtype='string', id=None), 'answers': Sequence(feature={'text': Value(dtype='string', id=None), 'answer_start': Value(dtype='int32', id=None)}, length=-1, id=None)}
'''

In [None]:
'''# code version 1
import os
import torch
import pandas as pd
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForQuestionAnswering, TrainingArguments, Trainer, default_data_collator

# Check if GPU is available, else use CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load SQuAD dataset
squad = load_dataset("squad", split="train[:5000]") #change
squad = squad.train_test_split(test_size=0.2)

# Convert dataset to DataFrames
df_train = pd.DataFrame.from_dict(squad["train"].to_dict())
df_test = pd.DataFrame.from_dict(squad["test"].to_dict())

# Initialize tokenizer
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

def preprocess_data(data_frame):
    """
    Preprocesses data for question answering model.

    Args:
        data_frame (pandas.DataFrame): DataFrame containing columns 'question', 'context', and 'answers'.

    Returns:
        datasets.Dataset: Preprocessed dataset containing input_ids, attention_mask,
                          start_positions, and end_positions.
    """
    questions = [q.strip() for q in data_frame["question"]]
    contexts = [c.strip() for c in data_frame["context"]]
    answers = data_frame['answers']

    # Tokenize inputs
    inputs = tokenizer(
        questions,
        contexts,
        max_length=512,
        truncation="only_second",
        return_offsets_mapping=True,
        padding="max_length",
    )

    offset_mappings = inputs.pop("offset_mapping")

    start_positions = []
    end_positions = []

    # Determine start and end positions for answers
    for i, offset_mapping in enumerate(offset_mappings):
        answer = answers[i]
        start_char = answer["answer_start"][0]
        end_char = start_char + len(answer["text"][0])
        sequence_ids = inputs.sequence_ids(i)

        context_start = next(idx for idx, seq_id in enumerate(sequence_ids) if seq_id == 1)
        context_end = next(idx for idx, seq_id in enumerate(sequence_ids[context_start:], start=context_start) if seq_id != 1) - 1

        if offset_mapping[context_start][0] > end_char or offset_mapping[context_end][1] < start_char:
            start_positions.append(0)
            end_positions.append(0)
        else:
            start_positions.append(next(idx for idx, offset in enumerate(offset_mapping[context_start:], start=context_start) if offset[0] <= start_char) - 1)
            end_positions.append(next(idx for idx, offset in enumerate(offset_mapping[context_end:], start=context_end) if offset[1] >= end_char) + 1)

    data_frame["start_positions"] = start_positions
    data_frame["end_positions"] = end_positions

    # Create dataset
    data = {
        'input_ids': inputs['input_ids'],
        'attention_mask': inputs['attention_mask'],
        'start_positions': start_positions,
        'end_positions': end_positions,
    }
    data_frame = pd.DataFrame(data)
    dataset = Dataset.from_pandas(data_frame)

    return dataset

# Preprocess train and test datasets
train_dataset = preprocess_data(df_train)
eval_dataset = preprocess_data(df_test)

# Initialize model
model = AutoModelForQuestionAnswering.from_pretrained("distilbert-base-uncased").to(device)

# Initialize data collator
data_collator = default_data_collator

# Training arguments
output_dir = "./fine-tuned-model"
training_args = TrainingArguments(
    output_dir=output_dir,
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=5,
    weight_decay=0.01,
    report_to="none",  # Disable logging
    save_strategy="epoch",  # Save checkpoint after each epoch
    save_total_limit=3,  # Save only the last 3 checkpoints
)

# Initialize trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
)

# Evaluation function
def compute_metrics(p):
    predictions, labels = p
    start_logits, end_logits = predictions
    start_positions, end_positions = labels

    # Decode the predictions
    start_logits = torch.from_numpy(start_logits)
    end_logits = torch.from_numpy(end_logits)

    predictions_start = torch.argmax(start_logits).item()
    predictions_end = torch.argmax(end_logits).item()

    exact_match = ((predictions_start == start_positions) & (predictions_end == end_positions)).sum().item()

    return {"exact_match": exact_match}

trainer.compute_metrics = compute_metrics

# Train the model
trainer.train()

# Save the model after training
model.save_pretrained(output_dir)

# Test user input
# Load the fine-tuned model
model = AutoModelForQuestionAnswering.from_pretrained(output_dir)

# Function to get answer from the model
def get_answer(question, context):
    inputs = tokenizer(
        question,
        context,
        max_length=512,
        truncation="only_second",
        return_offsets_mapping=True,
        padding="max_length",
    )
    with torch.no_grad():
        outputs = model(**inputs)
    answer_start_scores = outputs.start_logits
    answer_end_scores = outputs.end_logits
    answer_start = torch.argmax(answer_start_scores)
    answer_end = torch.argmax(answer_end_scores) + 1
    answer = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(inputs["input_ids"][0][answer_start:answer_end]))
    return answer

# Test the model interactively
while True:
    # Get user input
    question = input("Enter a question (type 'quit' to exit): ")
    if question.lower() == 'quit':
        break
    context = input("Enter the context: ")

    # Get and print the answer
    answer = get_answer(question, context)
    print("Answer:", answer)
'''