
<a href="https://colab.research.google.com/github/google-research/bigbird/blob/master/bigbird/classifier/imdb.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

##### Copyright 2020 The BigBird Authors

Licensed under the Apache License, Version 2.0 (the "License"); see the full license notice in the code cell below.

In [1]:
# Copyright 2020 The BigBird Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

## Setup

In [1]:
# Install BigBird (straight from the GitHub repo) and the Weights & Biases client.
# NOTE(review): neither install is version-pinned, so re-running this cell may
# pull newer releases than the ones the outputs below were produced with.
!pip install git+https://github.com/google-research/bigbird.git -q
!pip install wandb -qqq
import wandb

[K     |████████████████████████████████| 1.2 MB 10.5 MB/s 
[K     |████████████████████████████████| 4.3 MB 57.9 MB/s 
[K     |████████████████████████████████| 1.4 MB 58.9 MB/s 
[K     |████████████████████████████████| 4.0 MB 23.5 MB/s 
[K     |████████████████████████████████| 679 kB 74.7 MB/s 
[K     |████████████████████████████████| 79 kB 9.6 MB/s 
[K     |████████████████████████████████| 352 kB 76.5 MB/s 
[K     |████████████████████████████████| 5.8 MB 21.4 MB/s 
[K     |████████████████████████████████| 649 kB 66.6 MB/s 
[K     |████████████████████████████████| 981 kB 56.4 MB/s 
[K     |████████████████████████████████| 366 kB 79.8 MB/s 
[K     |████████████████████████████████| 191 kB 79.2 MB/s 
[K     |████████████████████████████████| 365 kB 72.9 MB/s 
[K     |████████████████████████████████| 251 kB 73.5 MB/s 
[K     |████████████████████████████████| 191 kB 73.6 MB/s 
[K     |████████████████████████████████| 178 kB 81.7 MB/s 
[?25h  Building wheel for

In [2]:
import os

# Start the Weights & Biases run that tracks this fine-tuning job.
# The entity defaults to the notebook author's account; set the WANDB_ENTITY
# environment variable to log the run to your own account instead.
wandb.init(project="uncategorized",
           entity=os.environ.get("WANDB_ENTITY", "notyeshwanthreddy"))
config = wandb.config

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize


wandb: Paste an API key from your profile and hit enter: ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [3]:
import sys

from bigbird.core import flags, modeling, utils
from bigbird.classifier import run_classifier
import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds
from tqdm import tqdm

FLAGS = flags.FLAGS

# Jupyter/Colab kernels are launched with an extra "-f <connection file>"
# argument; register a throwaway "f" flag (once) so that parsing sys.argv
# below does not reject it.
if not hasattr(FLAGS, "f"):
  flags.DEFINE_string("f", "", "")
FLAGS(sys.argv)

tf.enable_v2_behavior()

## Set options

In [11]:
# Demo-specific overrides on top of the BigBird flag defaults, applied in order.
# max_encoder_length: 4096 is the full-length setting; 1024 is used here to keep
# the demo quick on free Colab.
_flag_overrides = {
    "data_dir": "tfds://imdb_reviews/plain_text",
    "attention_type": "block_sparse",
    "max_encoder_length": 1024,
    "learning_rate": 1e-5,
    "num_train_steps": 2000,
    "attention_probs_dropout_prob": 0.0,
    "hidden_dropout_prob": 0.0,
    "use_gradient_checkpointing": True,
    "vocab_model_file": "gpt2",
}
for _flag_name, _flag_value in _flag_overrides.items():
  setattr(FLAGS, _flag_name, _flag_value)

In [12]:
# Collect the flag values (presumably including the overrides set above) into a
# config keyed by flag name; the model/head constructors below index it, e.g.
# bert_config["hidden_size"].
bert_config = flags.as_dictionary()

## Define classification model

In [13]:
# BigBird encoder plus a classification head on top of its pooled output.
# The head's variables are created under "<scope>/classifier".
model = modeling.BertModel(bert_config)
classifier_initializer = utils.create_initializer(bert_config["initializer_range"])
headl = run_classifier.ClassifierLossLayer(
    bert_config["hidden_size"],
    bert_config["num_labels"],
    bert_config["hidden_dropout_prob"],
    classifier_initializer,
    name=bert_config["scope"] + "/classifier")

In [14]:
@tf.function(experimental_compile=True)
def fwd_bwd(features, labels):
  """One XLA-compiled training step: forward pass plus gradients.

  Args:
    features: batch of model inputs passed to `model`.
    labels: batch of integer class labels passed to `headl`.

  Returns:
    (loss, log_probs, grads), where grads lines up element-wise with
    model.trainable_weights + headl.trainable_weights.
  """
  with tf.GradientTape() as tape:
    _, pooled = model(features, training=True)
    # Third positional arg is presumably the head's training flag.
    step_loss, step_log_probs = headl(pooled, labels, True)
  # Gather the variables only after the forward pass, in case layers create
  # their weights on first call.
  watched_vars = model.trainable_weights + headl.trainable_weights
  step_grads = tape.gradient(step_loss, watched_vars)
  return step_loss, step_log_probs, step_grads

## Dataset pipeline

In [15]:
# Training pipeline built from the configured flags. The inspection output below
# shows it yields (features, labels) batches of shape (8, 1024) and (8,).
train_input_fn = run_classifier.input_fn_builder(
    is_training=True,
    data_dir=FLAGS.data_dir,
    vocab_model_file=FLAGS.vocab_model_file,
    max_encoder_length=FLAGS.max_encoder_length,
    substitute_newline=FLAGS.substitute_newline)
train_params = dict(batch_size=8)
dataset = train_input_fn(train_params)

  deterministic=is_training)


In [16]:
# Inspect a few training batches; each is a (features, labels) pair, e.g.
# (8, 1024) int32 and (8,) int32 in the output below.
# NOTE: the loop variable `ex` stays bound to the last batch after this cell.
for ex in dataset.take(3):
  print(ex)

(<tf.Tensor: shape=(8, 1024), dtype=int32, numpy=
array([[   65,   733,   474, ...,     0,     0,     0],
       [   65,   415, 26500, ...,     0,     0,     0],
       [   65,   484, 20677, ...,     0,     0,     0],
       ...,
       [   65,   418,  1150, ...,     0,     0,     0],
       [   65,  9271,  5714, ...,     0,     0,     0],
       [   65,  8301,   113, ...,     0,     0,     0]], dtype=int32)>, <tf.Tensor: shape=(8,), dtype=int32, numpy=array([0, 1, 1, 1, 1, 0, 1, 0], dtype=int32)>)
(<tf.Tensor: shape=(8, 1024), dtype=int32, numpy=
array([[  65, 1182,  358, ...,    0,    0,    0],
       [  65,  871,  419, ...,    0,    0,    0],
       [  65,  415, 1908, ...,    0,    0,    0],
       ...,
       [  65,  484, 1722, ...,    0,    0,    0],
       [  65,  876, 1154, ...,    0,    0,    0],
       [  65,  415, 1092, ...,    0,    0,    0]], dtype=int32)>, <tf.Tensor: shape=(8,), dtype=int32, numpy=array([0, 1, 0, 0, 1, 0, 0, 1], dtype=int32)>)
(<tf.Tensor: shape=(8, 1024)

## (Optionally) Check outputs

In [17]:
# Sanity-check one compiled forward/backward pass on a training batch.
# The batch is taken explicitly instead of reusing the loop variable left over
# from the inspection cell, so this cell does not depend on hidden state.
# NOTE(review): this first call presumably also builds the model's variables,
# which the checkpoint-loading cell below needs.
sample_features, sample_labels = next(iter(dataset))
loss, log_probs, grads = fwd_bwd(sample_features, sample_labels)
print('Loss: ', loss.numpy())

Loss:  0.62117565


## (Optionally) Load pretrained model

In [18]:
# Restore pretrained BigBird weights into the encoder only. The classifier head
# (headl) keeps its fresh initialization because only `model`'s variables are loaded.
# NOTE(review): the model's variables must already exist (e.g. after the
# fwd_bwd sanity check above), otherwise there is nothing to assign.
ckpt_path = 'gs://bigbird-transformer/pretrain/bigbr_base/model.ckpt-0'
ckpt_reader = tf.compat.v1.train.NewCheckpointReader(ckpt_path)
for var in tqdm(model.trainable_weights, position=0):
  # Variable names carry an output suffix (":0") that checkpoint keys lack.
  ckpt_key = var.name.rsplit(':', 1)[0]
  # Assign each value to the exact variable it was looked up for. model.set_weights
  # would instead match values to variables by the order of model.weights.
  var.assign(ckpt_reader.get_tensor(ckpt_key))

100%|██████████| 199/199 [00:45<00:00,  4.40it/s]


## Train

In [19]:
# Fine-tune encoder + classifier head with Adam for FLAGS.num_train_steps batches.
# Loss/accuracy are running averages since the start of training.
opt = tf.keras.optimizers.Adam(FLAGS.learning_rate)
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.CategoricalAccuracy(name='train_accuracy')

LOG_EVERY = 200  # steps between progress prints / W&B logs
num_labels = bert_config["num_labels"]  # same label count the head was built with

for step, (features, labels) in enumerate(tqdm(dataset.take(FLAGS.num_train_steps), position=0)):
  loss, log_probs, grads = fwd_bwd(features, labels)
  # Same variable order as the gradient list produced inside fwd_bwd.
  opt.apply_gradients(zip(grads, model.trainable_weights + headl.trainable_weights))
  train_loss(loss)
  # CategoricalAccuracy compares the argmax of the one-hot labels with the argmax of log_probs.
  train_accuracy(tf.one_hot(labels, num_labels), log_probs)
  if step % LOG_EVERY == 0:
    loss_value = train_loss.result().numpy()
    accuracy_value = train_accuracy.result().numpy()
    print('Loss = {}  Accuracy = {}'.format(loss_value, accuracy_value))
    wandb.log({'train_loss': float(loss_value),
               'train_accuracy': float(accuracy_value)}, step=step)

  0%|          | 1/2000 [00:07<4:25:22,  7.97s/it]

Loss = 0.7722638845443726  Accuracy = 0.375


 10%|█         | 201/2000 [07:28<1:05:42,  2.19s/it]

Loss = 0.4163071811199188  Accuracy = 0.8252487778663635


 20%|██        | 401/2000 [14:48<58:16,  2.19s/it]

Loss = 0.34832972288131714  Accuracy = 0.8615959882736206


 30%|███       | 601/2000 [22:08<51:12,  2.20s/it]

Loss = 0.31350818276405334  Accuracy = 0.877703845500946


 40%|████      | 801/2000 [29:27<44:03,  2.20s/it]

Loss = 0.2841528058052063  Accuracy = 0.892166018486023


 50%|█████     | 1001/2000 [36:47<36:28,  2.19s/it]

Loss = 0.26684820652008057  Accuracy = 0.897602379322052


 60%|██████    | 1201/2000 [44:07<29:23,  2.21s/it]

Loss = 0.25541743636131287  Accuracy = 0.9019566774368286


 70%|███████   | 1401/2000 [51:26<22:01,  2.21s/it]

Loss = 0.24668189883232117  Accuracy = 0.9052462577819824


 80%|████████  | 1601/2000 [58:45<14:39,  2.20s/it]

Loss = 0.23745958507061005  Accuracy = 0.9099000692367554


 90%|█████████ | 1801/2000 [1:06:05<07:17,  2.20s/it]

Loss = 0.23041319847106934  Accuracy = 0.9128261804580688


100%|██████████| 2000/2000 [1:13:22<00:00,  2.20s/it]


## Eval

In [20]:
@tf.function(experimental_compile=True)
def fwd_only(features, labels):
  """XLA-compiled inference step (no gradients).

  Runs the encoder and classifier head with training disabled and returns
  (loss, log_probs) for the batch.
  """
  _, pooled = model(features, training=False)
  batch_loss, batch_log_probs = headl(pooled, labels, False)
  return batch_loss, batch_log_probs

In [21]:
# Evaluation pipeline: same preprocessing flags as training, with is_training=False.
eval_input_fn = run_classifier.input_fn_builder(
    is_training=False,
    data_dir=FLAGS.data_dir,
    vocab_model_file=FLAGS.vocab_model_file,
    max_encoder_length=FLAGS.max_encoder_length,
    substitute_newline=FLAGS.substitute_newline)
eval_params = dict(batch_size=8)
eval_dataset = eval_input_fn(eval_params)

  deterministic=is_training)


In [40]:
# Evaluate over the full evaluation dataset. Loss is the mean of per-batch losses.
eval_loss = tf.keras.metrics.Mean(name='eval_loss')
eval_accuracy = tf.keras.metrics.CategoricalAccuracy(name='eval_accuracy')
num_labels = bert_config["num_labels"]  # same label count the head was built with

for features, labels in tqdm(eval_dataset, position=0):
  loss, log_probs = fwd_only(features, labels)
  eval_loss(loss)
  eval_accuracy(tf.one_hot(labels, num_labels), log_probs)

eval_loss_value = eval_loss.result().numpy()
eval_accuracy_value = eval_accuracy.result().numpy()
print('Loss = {}  Accuracy = {}'.format(eval_loss_value, eval_accuracy_value))
wandb.log({'eval_loss': float(eval_loss_value),
           'eval_accuracy': float(eval_accuracy_value)})

100%|██████████| 3125/3125 [30:52<00:00,  1.69it/s]

Loss = 0.15144120156764984  Accuracy = 0.9464799761772156





In [41]:
# Peek at the first evaluation batch: a (features, labels) pair.
for eval_batch in eval_dataset.take(1):
  print(eval_batch)

(<tf.Tensor: shape=(8, 1024), dtype=int32, numpy=
array([[   65,  1419,   490, ...,     0,     0,     0],
       [   65,   418,  2143, ...,     0,     0,     0],
       [   65, 23777, 41003, ...,     0,     0,     0],
       ...,
       [   65,  1547,   661, ...,     0,     0,     0],
       [   65,   871,  2747, ...,     0,     0,     0],
       [   65,   484,  3451, ...,     0,     0,     0]], dtype=int32)>, <tf.Tensor: shape=(8,), dtype=int32, numpy=array([1, 1, 0, 1, 1, 0, 1, 1], dtype=int32)>)
