<a href="https://colab.research.google.com/github/vasudevgupta7/bigbird/blob/main/notebooks/bigbird_tpu.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 🤗's `BigBird` on TPUs

## Basic Setup for accessing colab-TPU

In [1]:
%%capture
!pip3 install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.8-cp37-cp37m-linux_x86_64.whl
!pip3 install transformers
!pip3 install sentencepiece

In [2]:
from transformers import BigBirdForQuestionAnswering, BigBirdTokenizer
import torch
import torch_xla.core.xla_model as xm
import numpy as np

device = xm.xla_device()



## Inference `BigBirdForQuestionAnswering` on TPU


In [3]:
tokenizer = BigBirdTokenizer.from_pretrained("google/bigbird-base-trivia-itc")
model = BigBirdForQuestionAnswering.from_pretrained("google/bigbird-base-trivia-itc", block_size=16).to(device)
model.device

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=845731.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=775.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=943.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=790.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=526574331.0, style=ProgressStyle(descri…




device(type='xla', index=1)

In [4]:
context = "The BigBird model was proposed in Big Bird: Transformers for Longer Sequences by Zaheer, Manzil and Guruganesh, Guru and Dubey, Kumar Avinava and Ainslie, Joshua and Alberti, Chris and Ontanon, Santiago and Pham, Philip and Ravula, Anirudh and Wang, Qifan and Yang, Li and others. BigBird, is a sparse-attention based transformer which extends Transformer based models, such as BERT to much longer sequences. In addition to sparse attention, BigBird also applies global attention as well as random attention to the input sequence. Theoretically, it has been shown that applying sparse, global, and random attention approximates full attention, while being computationally much more efficient for longer sequences. As a consequence of the capability to handle longer context, BigBird has shown improved performance on various long document NLP tasks, such as question answering and summarization, compared to BERT or RoBERTa."
question = ["Which is better for longer sequences- BigBird or BERT?", "What is the benefit of using BigBird over BERT?"]

In [5]:
inputs = tokenizer(
    question,
    [context, context],
    padding="max_length",
    return_tensors="pt",
    add_special_tokens=True,
    max_length=512,
    truncation=True,
)
inputs = {k: v.to(device) for k, v in inputs.items()}

In [6]:
with torch.no_grad():
  start_logits, end_logits = model(inputs["input_ids"]).to_tuple()

In [7]:
start_logits, end_logits

(tensor([[-1.0000e+06, -1.0000e+06, -1.0000e+06,  ..., -1.3038e+01,
          -1.3341e+01, -1.3470e+01],
         [-1.0000e+06, -1.0000e+06, -1.0000e+06,  ..., -1.2755e+01,
          -1.2914e+01, -1.3003e+01]], device='xla:1'),
 tensor([[-1.0000e+06, -1.0000e+06, -1.0000e+06,  ..., -1.5755e+01,
          -1.5940e+01, -1.5738e+01],
         [-1.0000e+06, -1.0000e+06, -1.0000e+06,  ..., -1.5083e+01,
          -1.5300e+01, -1.5231e+01]], device='xla:1'))

In [8]:
input_ids = inputs["input_ids"].tolist()
start = np.argmax(start_logits.cpu().detach().numpy(), axis=-1)
end = np.argmax(end_logits.cpu().detach().numpy(), axis=-1)
answer = [input_ids[i][start[i] : end[i] + 1] for i in range(len(input_ids))]
answer = tokenizer.batch_decode(answer)
answer

['BigBird', 'global attention']

## Training `BigBirdForQuestionAnswering` on TPU