
<a href="https://colab.research.google.com/github/google-research/bigbird/blob/master/bigbird/classifier/imdb.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

##### Copyright 2020 The BigBird Authors

Licensed under the Apache License, Version 2.0 (the "License");

In [None]:
# Copyright 2020 The BigBird Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

## Set Up

In [None]:
!pip install git+https://github.com/google-research/bigbird.git -q

In [None]:
from bigbird.core import flags
from bigbird.core import modeling
from bigbird.core import utils
from bigbird.classifier import run_classifier
import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds
from tqdm import tqdm
import sys

# Shared flag registry from `bigbird.core.flags`.
FLAGS = flags.FLAGS

# Register a string flag named "f" unless one is already defined, then parse
# the process arguments.
# NOTE(review): presumably needed because the notebook kernel is started with
# an "-f <connection file>" argument -- confirm.
if not hasattr(FLAGS, "f"):
  flags.DEFINE_string("f", "", "")
FLAGS(sys.argv)

tf.enable_v2_behavior()

## Set options

In [None]:
# --- Input data and vocabulary ---
FLAGS.data_dir = "tfds://imdb_reviews/plain_text"
FLAGS.vocab_model_file = "gpt2"
FLAGS.max_encoder_length = 4096  # reduce for quicker demo on free colab

# --- Model behaviour ---
FLAGS.attention_type = "block_sparse"
FLAGS.use_gradient_checkpointing = True
FLAGS.hidden_dropout_prob = 0.0
FLAGS.attention_probs_dropout_prob = 0.0

# --- Optimization ---
FLAGS.num_train_steps = 2000
FLAGS.learning_rate = 1e-5

In [None]:
# Collect the flag values into `bert_config`; the cells below index it by key
# (e.g. "hidden_size", "num_labels") to configure the model and classifier.
bert_config = flags.as_dictionary()

## Define classification model

In [None]:
# Encoder built from the flag-derived configuration.
model = modeling.BertModel(bert_config)

# Classification head that is applied to the encoder's pooled output and
# returns a (loss, log_probs) pair.
classifier_initializer = utils.create_initializer(
    bert_config["initializer_range"])
headl = run_classifier.ClassifierLossLayer(
    bert_config["hidden_size"],
    bert_config["num_labels"],
    bert_config["hidden_dropout_prob"],
    classifier_initializer,
    name=bert_config["scope"] + "/classifier")

In [None]:
@tf.function(experimental_compile=True)
def fwd_bwd(features, labels):
  """Runs a training-mode forward pass and back-propagates the loss.

  Args:
    features: Inputs handed unchanged to the global `model`.
    labels: Targets handed to the global classification head `headl`.

  Returns:
    A `(loss, log_probs, grads)` tuple. `grads` holds the gradients of `loss`
    with respect to the trainable weights of `model` followed by those of
    `headl`, in that order.
  """
  with tf.GradientTape() as tape:
    _, pooled_output = model(features, training=True)
    loss, log_probs = headl(pooled_output, labels, True)
  # The weight lists are read only after the forward pass has run.
  variables = model.trainable_weights + headl.trainable_weights
  grads = tape.gradient(loss, variables)
  return loss, log_probs, grads

## Dataset pipeline

In [None]:
# Build the training input function from the flag settings, then call it with
# a batch size of 8 to obtain the training dataset.
train_input_fn = run_classifier.input_fn_builder(
    data_dir=FLAGS.data_dir,
    vocab_model_file=FLAGS.vocab_model_file,
    max_encoder_length=FLAGS.max_encoder_length,
    substitute_newline=FLAGS.substitute_newline,
    is_training=True,
)
train_params = {'batch_size': 8}
dataset = train_input_fn(train_params)




In [None]:
# Inspect a few examples: print the first three batches of the training
# dataset. Note that the loop variable `ex` stays defined after the loop ends.
for ex in dataset.take(3):
  print(ex)

(<tf.Tensor: shape=(8, 4096), dtype=int32, numpy=
array([[   65,   733,   474, ...,     0,     0,     0],
       [   65,   415, 26500, ...,     0,     0,     0],
       [   65,   484, 20677, ...,     0,     0,     0],
       ...,
       [   65,   418,  1150, ...,     0,     0,     0],
       [   65,  9271,  5714, ...,     0,     0,     0],
       [   65,  8301,   113, ...,     0,     0,     0]], dtype=int32)>, <tf.Tensor: shape=(8,), dtype=int32, numpy=array([0, 1, 1, 1, 1, 0, 1, 0], dtype=int32)>)
(<tf.Tensor: shape=(8, 4096), dtype=int32, numpy=
array([[  65, 1182,  358, ...,    0,    0,    0],
       [  65,  871,  419, ...,    0,    0,    0],
       [  65,  415, 1908, ...,    0,    0,    0],
       ...,
       [  65,  484, 1722, ...,    0,    0,    0],
       [  65,  876, 1154, ...,    0,    0,    0],
       [  65,  415, 1092, ...,    0,    0,    0]], dtype=int32)>, <tf.Tensor: shape=(8,), dtype=int32, numpy=array([0, 1, 0, 0, 1, 0, 0, 1], dtype=int32)>)
(<tf.Tensor: shape=(8, 4096)

## (Optionally) Check outputs

In [None]:
# Fetch one batch explicitly rather than reusing the `ex` loop variable left
# behind by the inspection cell, so this cell also runs when that cell has
# been skipped.
features, labels = next(iter(dataset))

# Single forward/backward pass as a smoke test; only the loss is displayed.
loss, log_probs, grads = fwd_bwd(features, labels)
print('Loss: ', loss.numpy())


Loss:  0.6977416





## (Optionally) Load pretrained model

In [None]:
# Location of the pretrained checkpoint and a reader to fetch tensors from it.
ckpt_path = 'gs://bigbird-transformer/pretrain/bigbr_base/model.ckpt-0'
ckpt_reader = tf.compat.v1.train.NewCheckpointReader(ckpt_path)

# Look up one checkpoint tensor per trainable weight, keyed by the weight's
# name with its last two characters removed, and assign them all at once.
# NOTE(review): the slice presumably drops a ":0" suffix -- confirm.
# NOTE(review): this iterates whatever `model.trainable_weights` contains at
# this point; verify the model's variables already exist when this cell runs.
model.set_weights([
    ckpt_reader.get_tensor(weight.name[:-2])
    for weight in tqdm(model.trainable_weights, position=0)
])

100%|██████████| 199/199 [00:34<00:00,  4.94it/s]


## Train

In [None]:
# Optimizer and running metrics. The metrics are never reset inside the loop,
# so the printed values are averages over all steps seen so far.
opt = tf.keras.optimizers.Adam(FLAGS.learning_rate)
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.CategoricalAccuracy(name='train_accuracy')

# The one-hot depth comes from the same config entry that sizes the classifier
# head, so labels and predictions always have matching width.
num_labels = bert_config["num_labels"]
log_every = 200  # print the running metrics every `log_every` steps

for i, ex in enumerate(tqdm(dataset.take(FLAGS.num_train_steps), position=0)):
  loss, log_probs, grads = fwd_bwd(ex[0], ex[1])
  # Same variable order as used to compute `grads` inside `fwd_bwd`.
  opt.apply_gradients(zip(grads, model.trainable_weights+headl.trainable_weights))
  train_loss(loss)
  train_accuracy(tf.one_hot(ex[1], num_labels), log_probs)
  if i % log_every == 0:
    print('Loss = {}  Accuracy = {}'.format(train_loss.result().numpy(), train_accuracy.result().numpy()))

Loss = 0.7094929218292236  Accuracy = 0.5

  0%|          | 0/2000 [00:06<1:59:12,  3.57it/s]


Loss = 0.47779741883277893  Accuracy = 0.7558900713920593

 10%|█         | 200/2000 [11:26<1:48:08,  3.60it/s]


Loss = 0.3703668415546417  Accuracy = 0.8318414092063904

 20%|██        | 400/2000 [23:52<1:35:17,  3.58it/s]


Loss = 0.3130376636981964  Accuracy = 0.8654822111129761

 30%|███       | 600/2000 [35:18<1:24:58,  3.60it/s]


Loss = 0.2806303799152374  Accuracy = 0.8822692632675171

 40%|████      | 800/2000 [47:44<1:12:41,  3.60it/s]


Loss = 0.2649693191051483  Accuracy = 0.8901362419128418

 50%|█████     | 1000/2000 [59:10<59:03,  3.58it/s]


Loss = 0.25240564346313477  Accuracy = 0.8967254161834717

 60%|██████    | 1200/2000 [1:11:36<47:43,  3.60it/s]


Loss = 0.24363534152507782  Accuracy = 0.901509702205658

 70%|███████   | 1400/2000 [1:23:02<35:20,  3.58it/s]


Loss = 0.23414449393749237  Accuracy = 0.9062696695327759

 80%|████████  | 1600/2000 [1:35:30<23:23,  3.60it/s]


Loss = 0.22541514039039612  Accuracy = 0.9101060628890991

 90%|█████████ | 1800/2000 [1:46:05<11:34,  3.60it/s]


Loss = 0.2210962176322937  Accuracy = 0.9125439524650574

100%|██████████| 2000/2000 [1:59:39<00:00,  3.58it/s]







## Eval

In [None]:
@tf.function(experimental_compile=True)
def fwd_only(features, labels):
  """Runs an inference-mode forward pass through the model and the head.

  Args:
    features: Inputs handed unchanged to the global `model`.
    labels: Targets handed to the global classification head `headl`.

  Returns:
    The `(loss, log_probs)` pair produced by `headl`.
  """
  encoder_outputs = model(features, training=False)
  _, pooled_output = encoder_outputs
  head_outputs = headl(pooled_output, labels, False)
  loss, log_probs = head_outputs
  return loss, log_probs

In [None]:
# Build the evaluation input function from the flag settings (is_training is
# off), then call it with a batch size of 8 to obtain the evaluation dataset.
eval_input_fn = run_classifier.input_fn_builder(
    data_dir=FLAGS.data_dir,
    vocab_model_file=FLAGS.vocab_model_file,
    max_encoder_length=FLAGS.max_encoder_length,
    substitute_newline=FLAGS.substitute_newline,
    is_training=False,
)
eval_params = {'batch_size': 8}
eval_dataset = eval_input_fn(eval_params)

In [None]:
# Running metrics accumulated over the whole evaluation dataset.
eval_loss = tf.keras.metrics.Mean(name='eval_loss')
eval_accuracy = tf.keras.metrics.CategoricalAccuracy(name='eval_accuracy')

# The one-hot depth comes from the same config entry that sizes the classifier
# head, so labels and predictions always have matching width.
num_labels = bert_config["num_labels"]

for ex in tqdm(eval_dataset, position=0):
  loss, log_probs = fwd_only(ex[0], ex[1])
  eval_loss(loss)
  eval_accuracy(tf.one_hot(ex[1], num_labels), log_probs)
# Report the aggregate once, after every batch has been consumed.
print('Loss = {}  Accuracy = {}'.format(eval_loss.result().numpy(), eval_accuracy.result().numpy()))


Loss = 0.16173037886619568  Accuracy = 0.9459513425827026100