<a href="https://colab.research.google.com/github/vasudevgupta7/bigbird/blob/main/notebooks/evaluate_nq.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Show which GPU the Colab runtime assigned (this run got a Tesla T4).
!nvidia-smi

Wed May 12 01:42:38 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 465.19.01    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   45C    P8    10W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [2]:
%%capture
# Clone the BigBird repo and install the natural-questions example's requirements;
# %%capture (must stay the first line) hides the long clone/pip logs.
# Note: the `cd` below only affects that single shell command; the notebook's
# working directory is changed in the next cell.
!git clone https://github.com/vasudevgupta7/bigbird
!cd bigbird/natural-questions && pip3 install -r requirements.txt

In [3]:
# Use the explicit %cd magic: a bare `cd` only works while IPython automagic is on.
# The new working directory persists for the rest of the notebook; the relative
# data folder below and the local `train_nq` / `params` imports rely on it.
# The path is relative, so re-running this cell from inside the folder fails.
%cd bigbird/natural-questions

/content/bigbird/natural-questions


In [4]:
# Download the pre-processed NQ validation split (a ~2.2 GB .arrow file plus its
# dataset_info.json / state.json metadata) into the folder read by load_from_disk.
# `mkdir -p` and `wget -nc` make the cell safe to re-run: without -nc, wget would
# save another 2.2 GB copy as `*.arrow.1` instead of skipping the existing file.
!mkdir -p natural-questions-validation
!wget -nc https://huggingface.co/datasets/vasudevgupta/natural-questions-validation/resolve/main/natural_questions-validation.arrow -P natural-questions-validation
!wget -nc https://huggingface.co/datasets/vasudevgupta/natural-questions-validation/resolve/main/dataset_info.json -P natural-questions-validation
!wget -nc https://huggingface.co/datasets/vasudevgupta/natural-questions-validation/resolve/main/state.json -P natural-questions-validation

--2021-05-12 01:43:14--  https://huggingface.co/datasets/vasudevgupta/natural-questions-validation/resolve/main/natural_questions-validation.arrow
Resolving huggingface.co (huggingface.co)... 34.201.172.85
Connecting to huggingface.co (huggingface.co)|34.201.172.85|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs.huggingface.co/datasets/vasudevgupta/natural-questions-validation/80201e1a5434183f43f73f1b7f57dfe9ea6c50da4f45007aeb192a2ec6556e70 [following]
--2021-05-12 01:43:14--  https://cdn-lfs.huggingface.co/datasets/vasudevgupta/natural-questions-validation/80201e1a5434183f43f73f1b7f57dfe9ea6c50da4f45007aeb192a2ec6556e70
Resolving cdn-lfs.huggingface.co (cdn-lfs.huggingface.co)... 52.84.169.23, 52.84.169.48, 52.84.169.13, ...
Connecting to cdn-lfs.huggingface.co (cdn-lfs.huggingface.co)|52.84.169.23|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2353975312 (2.2G) [application/octet-stream]
Saving to: ‘natural-que

In [42]:
from datasets import load_dataset, load_from_disk

# Folder filled by the download cell above (arrow data + datasets metadata files).
NQ_VALIDATION_DIR = "natural-questions-validation"

dataset = load_from_disk(NQ_VALIDATION_DIR)
dataset

Dataset({
    features: ['annotations', 'document', 'id', 'question'],
    num_rows: 7830
})

In [43]:
def _join_text_tokens(tokens, is_html):
  """Join the tokens whose is_html flag is falsy with single spaces."""
  return " ".join(tok for tok, html in zip(tokens, is_html) if not html)


def format_dataset(sample):
  """Flatten one raw Natural Questions sample into question/context/answers.

  Args:
    sample: a raw record with 'question', 'document' and 'annotations' fields.

  Returns:
    dict with keys:
      question: the question text.
      context:  the document tokens with HTML tokens removed, space-joined.
      short:    de-duplicated short-answer strings, in first-seen order.
      long:     de-duplicated long-answer strings (HTML tokens removed),
                in first-seen order.
      category: "yes" / "no" if any annotation has yes_no_answer 1 / 0 (the
                first such annotation wins and both answer lists are empty);
                otherwise "long_short" if any short or long answer exists,
                else "null".
  """
  question = sample['question']['text']
  tokens = sample['document']['tokens']['token']
  is_html = sample['document']['tokens']['is_html']
  annotations = sample['annotations']

  context_string = _join_text_tokens(tokens, is_html)

  # Only 0 ("no") and 1 ("yes") count as a yes/no answer.
  for answer in annotations['yes_no_answer']:
    if answer in (0, 1):
      return {"question": question, "context": context_string, "short": [], "long": [],
              "category": "no" if answer == 0 else "yes"}

  # dict.fromkeys de-duplicates while keeping first-seen order; list(set(...))
  # would make the order depend on the per-process string hash seed.
  short_targets = list(dict.fromkeys(
      text for s in annotations['short_answers'] for text in s['text']))

  # Spans with start_token == -1 carry no long answer and are skipped.
  long_targets = list(dict.fromkeys(
      _join_text_tokens(tokens[s['start_token']:s['end_token']],
                        is_html[s['start_token']:s['end_token']])
      for s in annotations['long_answer']
      if s['start_token'] != -1))

  category = "long_short" if short_targets or long_targets else "null"

  return {"question": question, "context": context_string, "short": short_targets,
          "long": long_targets, "category": category}

In [44]:
# Flatten every raw record with format_dataset, then drop the raw nested columns.
# The guard keeps the cell re-runnable: after the first run the raw columns are
# gone and format_dataset would raise KeyError on sample['annotations'].
if "annotations" in dataset.column_names:
  dataset = dataset.map(format_dataset).remove_columns(["annotations", "document", "id"])

Loading cached processed dataset at natural-questions-validation/cache-28cfe42f021c8118.arrow


In [45]:
# Character budget for question + context: 4 * 4096, i.e. roughly the model's
# 4096-token input assuming ~4 characters per token (a heuristic, not exact).
MAX_CHARS = 4 * 4096

def _is_short_answerable(example):
  """Keep examples under the character budget that have a gold answer."""
  fits = len(example['question']) + len(example['context']) < MAX_CHARS
  return fits and example['category'] != "null"

short_validation_dataset = dataset.filter(_is_short_answerable)
short_validation_dataset

Loading cached processed dataset at natural-questions-validation/cache-65dc9965e3ec71e4.arrow


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))




Dataset({
    features: ['category', 'context', 'long', 'question', 'short'],
    num_rows: 1823
})

In [46]:
# Characters replaced by spaces when normalising answers for comparison.
PUNCTUATION_SET_TO_EXCLUDE = set('‘’´`.,-"')


def get_sub_answers(answers, begin=0, end=None):
  """Return the words[begin:end] slice of every multi-word answer.

  Answers are split on single spaces; one-word answers are skipped.
  """
  sub_answers = []
  for answer in answers:
    words = answer.split(" ")
    if len(words) <= 1:
      continue
    sub_answers.append(" ".join(words[begin:end]))
  return sub_answers


def expand_to_aliases(given_answers, make_sub_answers=False):
  """Normalise answers into a set of comparable aliases.

  Each alias is lower-cased and has '_' and every character in
  PUNCTUATION_SET_TO_EXCLUDE turned into a space. Runs of whitespace are
  then collapsed to one space.

  With make_sub_answers=True, multi-word answers also contribute their
  variants without the first word ([1:]) and without the last word ([:-1]).
  A prediction then still matches when, e.g., the gold answer carries an
  extra leading "the" or "a".
  """
  candidates = given_answers
  if make_sub_answers:
    candidates = (given_answers
                  + get_sub_answers(given_answers, begin=1)
                  + get_sub_answers(given_answers, end=-1))

  def normalise(text):
    text = text.replace('_', ' ').lower()
    text = ''.join(' ' if ch in PUNCTUATION_SET_TO_EXCLUDE else ch for ch in text)
    return ' '.join(text.split())

  return {normalise(answer) for answer in candidates}

In [47]:
def get_best_valid_start_end_idx(start_scores, end_scores, top_k=1, max_size=100):
    """Pick the highest-scoring valid (start, end) span from 1-D score vectors.

    Only the top_k best start positions and top_k best end positions are
    paired up. A pair is valid when 0 <= end - start <= max_size. Width 0 is a
    single-token span, so end is inclusive. Invalid pairs are penalised by 1e8
    so any valid pair beats them; if no pair is valid, the best-scoring invalid
    pair is returned.

    Args:
        start_scores: 1-D tensor of start-position scores.
        end_scores: 1-D tensor of end-position scores.
        top_k: number of start/end candidates. It is clamped to the sequence
            length so short (unpadded) inputs don't make torch.topk raise.
        max_size: maximum allowed end - start distance, in positions.

    Returns:
        (start_idx, end_idx) as 0-d tensors.
    """
    top_k = min(top_k, start_scores.size(-1), end_scores.size(-1))
    best_start_scores, best_start_idx = torch.topk(start_scores, top_k)
    best_end_scores, best_end_idx = torch.topk(end_scores, top_k)

    # (top_k, top_k) grid: rows index end candidates, columns index start candidates.
    widths = best_end_idx[:, None] - best_start_idx[None, :]
    invalid = torch.logical_or(widths < 0, widths > max_size)
    scores = (best_end_scores[:, None] + best_start_scores[None, :]) - (1e8 * invalid)

    # Unravel the flat argmax back into (row, column) of the grid.
    end_row, start_col = divmod(torch.argmax(scores).item(), top_k)
    return best_start_idx[start_col], best_end_idx[end_row]

In [48]:
def evaluate(example, top_k=8, max_short_size=38, max_long_size=1024, max_length=4096):
    """Run the model on one formatted example and record prediction and match.

    Relies on the notebook globals model, tokenizer, DEVICE and the inverted
    CATEGORY_MAPPING (class id -> category name).

    Args:
        example: dict with "question", "context", "short", "long" and
            "category", as produced by format_dataset.
        top_k: start/end candidates passed to get_best_valid_start_end_idx.
        max_short_size: max span width (end - start, in tokens) when the
            predicted category is "short".
        max_long_size: max span width for any other span prediction.
        max_length: tokenizer truncation/padding length.

    Returns:
        example, with these keys added:
          targets:  ["yes"] / ["no"] / ["null"] for those gold categories,
                    otherwise the long + short answer strings.
          has_tgt:  True unless the gold category is "null".
          output:   one-element list holding the predicted label or span text.
          match:    exact match between prediction and targets.
          has_pred: whether a non-null prediction was made.
    """
    # Encode question and context as a pair (separated by the tokenizer's sep
    # token), truncated and padded to max_length.
    inputs = tokenizer(example["question"], example["context"], return_tensors="pt",
                       max_length=max_length, padding="max_length", truncation=True)
    inputs = {k: v.to(DEVICE) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)
        start_scores = outputs['start_logits']
        end_scores = outputs['end_logits']
        _, category = outputs["cls_out"].max(dim=-1)

    predicted_category = CATEGORY_MAPPING[category.item()]

    # Gold targets: label categories compare by label, others by answer strings.
    if example['category'] in ['yes', 'no', 'null']:
        example['targets'] = [example['category']]
    else:
        example['targets'] = example['long'] + example['short']
    example['has_tgt'] = example['category'] != 'null'

    # Label predictions are scored by exact label equality.
    if predicted_category in ['yes', 'no', 'null']:
        example['output'] = [predicted_category]
        example['match'] = example['output'] == example['targets']
        example['has_pred'] = predicted_category != 'null'
        return example

    # Span prediction: short answers are restricted to much narrower spans.
    max_size = max_short_size if predicted_category == "short" else max_long_size
    start_idx, end_idx = get_best_valid_start_end_idx(
        start_scores[0], end_scores[0], top_k=top_k, max_size=max_size)

    input_ids = inputs["input_ids"][0].cpu().tolist()
    example["output"] = [tokenizer.decode(input_ids[start_idx: end_idx + 1])]

    answers = expand_to_aliases(example["targets"], make_sub_answers=True)
    predictions = expand_to_aliases(example["output"])

    # Exact match if the normalised gold and predicted aliases share any element.
    example["match"] = not answers.isdisjoint(predictions)
    # predicted_category cannot be "null" here; that case returned above.
    example["has_pred"] = len(predictions) > 0

    return example

In [13]:
import torch
import numpy as np
from train_nq import BigBirdForNaturalQuestions
from params import CATEGORY_MAPPING
from transformers import BigBirdTokenizer

# params.CATEGORY_MAPPING maps category name -> class id; invert it so the
# classifier's predicted id can be decoded back into a category name.
CATEGORY_MAPPING = dict((class_id, name) for name, class_id in CATEGORY_MAPPING.items())
CATEGORY_MAPPING

{0: 'null', 1: 'short', 2: 'long', 3: 'yes', 4: 'no'}

In [14]:
# Fine-tuned BigBird checkpoint on the Hugging Face Hub, pinned to one commit
# of the model repo via `revision`.
MODEL_ID = "vasudevgupta/bigbird-roberta-natural-questions"
revision = "b962e30f2367cbc5e35b2c0d64faa9bad469e2e2"
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

tokenizer = BigBirdTokenizer.from_pretrained(MODEL_ID, revision=revision)
model = BigBirdForNaturalQuestions.from_pretrained(MODEL_ID, revision=revision).to(DEVICE)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=837.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=528910842.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=845731.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=775.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1243.0, style=ProgressStyle(description…




In [49]:
def evaluate_print(example, verbose=0):
  """Evaluate one example and, if verbose is non-zero, print target and prediction.

  The printed target is the short answers when present, otherwise the long
  answers. Yes/no examples have neither, so their label from
  example["targets"] is printed instead of an empty list.
  """
  example = evaluate(example)
  if verbose != 0:
    target = example["short"] or example["long"] or example["targets"]
    print("TARGET", target)
    print("PREDICTION", example["output"], end="\n\n")
  return example

short_validation_dataset = short_validation_dataset.map(lambda x: evaluate_print(x, 0))

HBox(children=(FloatProgress(value=0.0, max=1823.0), HTML(value='')))




In [50]:
# `match` is a boolean column: summing it counts exact matches directly, without
# materialising a filtered (and cached) copy of the dataset as .filter() would.
total = len(short_validation_dataset)
matched = sum(short_validation_dataset["match"])
em_score = (matched / total) * 100 if total else float("nan")
print("EM score:", em_score, "%")

HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))


EM score: 47.44925946242458 %


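The model reaches an **exact match (EM) of ~47.45%** on the 1,823 validation examples that have a gold answer and whose question + context is shorter than 4 × 4096 characters. 💥💥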
<!-- # this f1 is as per official nq script from here (https://github.com/google-research-datasets/natural-questions/blob/master/nq_eval.py)
has_pred = len(short_validation_dataset.filter(lambda x: x["has_pred"]))
has_tgt = len(short_validation_dataset.filter(lambda x: x["has_tgt"]))
matched = len(short_validation_dataset.filter(lambda x: x["match"]))
precision = matched / has_pred
recall = matched / has_tgt
print("F1 score:", 2*precision*recall / (precision + recall)) -->