<a href="https://colab.research.google.com/github/vasudevgupta7/bigbird/blob/main/notebooks/bigbird_tpu.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 🤗's `BigBird` on TPUs

## Basic setup for accessing a Colab TPU

In [11]:
%%capture
!pip3 install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.8-cp37-cp37m-linux_x86_64.whl
!pip3 install sentencepiece

In [None]:
# !pip3 uninstall transformers
!pip3 install git+https://github.com/vasudevgupta7/transformers@bigbird-tpu

In [None]:
from transformers import BigBirdForQuestionAnswering, BigBirdTokenizer
import torch
import torch_xla.core.xla_model as xm
import numpy as np

device = xm.xla_device()



## Inference with `BigBirdForQuestionAnswering` on TPU


In [None]:
tokenizer = BigBirdTokenizer.from_pretrained("google/bigbird-base-trivia-itc")
model = BigBirdForQuestionAnswering.from_pretrained("google/bigbird-base-trivia-itc", block_size=16).to(device)
model.device

device(type='xla', index=1)

In [None]:
context = "The BigBird model was proposed in Big Bird: Transformers for Longer Sequences by Zaheer, Manzil and Guruganesh, Guru and Dubey, Kumar Avinava and Ainslie, Joshua and Alberti, Chris and Ontanon, Santiago and Pham, Philip and Ravula, Anirudh and Wang, Qifan and Yang, Li and others. BigBird, is a sparse-attention based transformer which extends Transformer based models, such as BERT to much longer sequences. In addition to sparse attention, BigBird also applies global attention as well as random attention to the input sequence. Theoretically, it has been shown that applying sparse, global, and random attention approximates full attention, while being computationally much more efficient for longer sequences. As a consequence of the capability to handle longer context, BigBird has shown improved performance on various long document NLP tasks, such as question answering and summarization, compared to BERT or RoBERTa."
question = ["Which is better for longer sequences- BigBird or BERT?", "What is the benefit of using BigBird over BERT?"]

In [None]:
inputs = tokenizer(
    question,
    [context, context],
    padding="max_length",
    return_tensors="pt",
    add_special_tokens=True,
    max_length=512,
    truncation=True,
)
inputs = {k: v.to(device) for k, v in inputs.items()}
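Padding every example to `max_length=512` keeps the sequence length a multiple of `block_size=16`, the unit BigBird's block-sparse attention works in. The sketch below illustrates the arithmetic, assuming the rule used by the released 🤗 Transformers `BigBirdModel` (default `num_random_blocks=3`): inputs are padded up to a multiple of `block_size`, and a sequence no longer than `(5 + 2 * num_random_blocks) * block_size` tokens falls back to full attention. The helper names are hypothetical, and the `bigbird-tpu` branch installed above may differ in detail.

```python
# Hypothetical helpers illustrating BigBird's length constraints; the rule is
# assumed to match the released Transformers BigBirdModel.


def padded_length(seq_len: int, block_size: int) -> int:
    """Round seq_len up to the next multiple of block_size."""
    return -(-seq_len // block_size) * block_size


def uses_block_sparse(seq_len: int, block_size: int, num_random_blocks: int = 3) -> bool:
    """True if the sequence is long enough for block-sparse attention.

    Shorter sequences make the model fall back to full ("original_full") attention.
    """
    max_tokens_to_attend = (5 + 2 * num_random_blocks) * block_size
    return seq_len > max_tokens_to_attend


print(padded_length(512, 16))      # 512: already a multiple of 16
print(padded_length(500, 16))      # 512
print(uses_block_sparse(512, 16))  # True: 512 > 11 * 16 = 176
print(uses_block_sparse(512, 64))  # False: 512 <= 11 * 64 = 704
```

Under this assumed rule, a 512-token input with a block size of 64 would use full attention, which may be one reason the model is loaded with `block_size=16` here.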

In [None]:
with torch.no_grad():
  start_logits, end_logits = model(**inputs).to_tuple()

In [None]:
start_logits, end_logits

(tensor([[-1.0000e+06, -1.0000e+06, -1.0000e+06,  ..., -3.5154e+00,
          -3.5228e+00, -4.2024e+00],
         [-1.0000e+06, -1.0000e+06, -1.0000e+06,  ..., -3.5154e+00,
          -3.5228e+00, -4.2024e+00]], device='xla:1'),
 tensor([[-1.0000e+06, -1.0000e+06, -1.0000e+06,  ..., -7.0817e+00,
          -6.8335e+00, -7.8566e+00],
         [-1.0000e+06, -1.0000e+06, -1.0000e+06,  ..., -7.0817e+00,
          -6.8335e+00, -7.8566e+00]], device='xla:1'))
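The `-1.0000e+06` entries at the start of each row look like question positions that have been masked out, so that `argmax` can only pick a span inside the context. A minimal NumPy sketch of that idea, assuming the input layout `[CLS] question [SEP] context [SEP]` and treating everything up to and including the first `[SEP]` as question; `question_mask`, `mask_logits` and the token ids in the example are illustrative, not the model's own code:

```python
import numpy as np


def question_mask(input_ids: np.ndarray, sep_token_id: int) -> np.ndarray:
    """Return a (batch, seq_len) float mask equal to 1.0 up to and including the first [SEP]."""
    seq_len = input_ids.shape[1]
    # argmax over a boolean array gives the index of the first True in each row
    question_lengths = np.argmax(input_ids == sep_token_id, axis=-1) + 1
    positions = np.arange(seq_len)[None, :]
    return (positions < question_lengths[:, None]).astype(np.float32)


def mask_logits(logits: np.ndarray, input_ids: np.ndarray, sep_token_id: int) -> np.ndarray:
    """Push question-token logits down by 1e6 so argmax lands in the context."""
    return logits - question_mask(input_ids, sep_token_id) * 1e6


# Toy batch: 65 and 66 are placeholder ids standing in for [CLS] and [SEP]
ids = np.array([[65, 10, 11, 66, 20, 21, 66]])
logits = np.zeros((1, 7), dtype=np.float32)
print(np.argmax(mask_logits(logits, ids, sep_token_id=66), axis=-1))  # [4]
```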

In [None]:
input_ids = inputs["input_ids"].tolist()
start = np.argmax(start_logits.detach().cpu().numpy(), axis=-1)
end = np.argmax(end_logits.detach().cpu().numpy(), axis=-1)
answer = [input_ids[i][start[i] : end[i] + 1] for i in range(len(input_ids))]
answer = tokenizer.batch_decode(answer)
answer

['BigBird', 'global attention']
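Taking the `argmax` of the start and end logits independently works for these two questions, but nothing stops it from returning `end < start` or an overly long span. A common alternative, sketched here in NumPy, scores every pair as `start_logits[i] + end_logits[j]` and keeps the best pair with `i <= j < i + max_answer_len`; the function name and the default `max_answer_len=30` are assumptions, not part of this notebook.

```python
import numpy as np


def best_span(start_logits: np.ndarray, end_logits: np.ndarray, max_answer_len: int = 30):
    """Return (start, end) maximizing start_logits[start] + end_logits[end]
    subject to start <= end < start + max_answer_len."""
    seq_len = start_logits.shape[-1]
    # scores[i, j] = score of the span that starts at i and ends at j
    scores = start_logits[:, None] + end_logits[None, :]
    i, j = np.indices((seq_len, seq_len))
    valid = (j >= i) & (j - i < max_answer_len)
    scores = np.where(valid, scores, -np.inf)
    start, end = np.unravel_index(np.argmax(scores), scores.shape)
    return int(start), int(end)


s = np.array([0.0, 1.0, 5.0, 0.0])
e = np.array([4.0, 0.0, 0.0, 1.0])
print(np.argmax(s), np.argmax(e))  # 2 0  -> independent argmax gives end < start
print(best_span(s, e))             # (2, 3)
```

Applied per example, e.g. `best_span(s, e)` for each row of `start_logits.detach().cpu().numpy()` and `end_logits.detach().cpu().numpy()`, the returned indices slice `input_ids` the same way as in the decoding cell above.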

## Training `BigBirdForQuestionAnswering` on TPU

In [1]:
!git clone https://github.com/vasudevgupta7/bigbird

Cloning into 'bigbird'...
remote: Enumerating objects: 254, done.
remote: Counting objects: 100% (254/254), done.
remote: Compressing objects: 100% (205/205), done.
remote: Total 254 (delta 127), reused 133 (delta 41), pack-reused 0
Receiving objects: 100% (254/254), 8.57 MiB | 25.58 MiB/s, done.
Resolving deltas: 100% (127/127), done.


In [4]:
%cd /content/bigbird/natural-questions

/content/bigbird/natural-questions


In [None]:
!mkdir data
# !wget https://huggingface.co/datasets/vasudevgupta/bigbird-tokenized-natural-questions/resolve/main/nq-train.zip -P data && unzip data/nq-train.zip -d data/
!wget https://huggingface.co/datasets/vasudevgupta/bigbird-tokenized-natural-questions/resolve/main/nq-val.zip -P data && unzip data/nq-val.zip -d data/

In [5]:
ls

[0m[01;34mdata[0m/           params.py      [01;34m__pycache__[0m/      train_nq.py
evaluate_nq.py  prepare_nq.py  requirements.txt


In [13]:
%%capture
!pip3 install datasets
!pip3 install wandb

In [None]:
# Train on validation.jsonl instead of training.jsonl so the run fits within Colab's limits

!TRAIN_ON_SMALL=True python3 train_nq_tpu.py
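Prefixing the command with `TRAIN_ON_SMALL=True` sets that environment variable for this one run of the script. How `train_nq_tpu.py` actually interprets it is not shown in this notebook; the following is only a hypothetical sketch of a common pattern for reading such a boolean flag, with the file names taken from the comment above.

```python
import os


def env_flag(name: str, default: bool = False) -> bool:
    """Read an environment variable such as TRAIN_ON_SMALL=True as a boolean."""
    value = os.environ.get(name)
    if value is None:
        return default
    # Accept the usual spellings of "true"; anything else counts as False
    return value.strip().lower() in {"1", "true", "yes", "y"}


# Hypothetical: train on the smaller validation file when the flag is set
train_file = "validation.jsonl" if env_flag("TRAIN_ON_SMALL") else "training.jsonl"
print(train_file)
```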