
<a href="https://colab.research.google.com/github/google-research/bigbird/blob/master/bigbird/summarization/pubmed.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

##### Copyright 2020 The BigBird Authors

Licensed under the Apache License, Version 2.0 (the "License");

In [1]:
# Copyright 2020 The BigBird Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

## Set Up

In [None]:
!pip install git+https://github.com/google-research/bigbird.git -q

In [6]:
import os


def enter_work_dir(path):
  """Switch the working directory to `path` when that directory exists.

  Args:
    path: Directory to change into.

  Returns:
    The working directory after the call. It is left unchanged when `path`
    is not an existing directory, instead of raising.
  """
  if os.path.isdir(path):
    os.chdir(path)
  return os.getcwd()


# The original checkout location stays the default; set BIGBIRD_WORK_DIR to
# point at a different directory on other machines.
enter_work_dir(os.environ.get("BIGBIRD_WORK_DIR", "/home/gitlib/longsumm"))

'/home/gitlib/longsumm'

In [8]:
from bigbird.core import flags
from bigbird.core import modeling
from bigbird.core import utils
from bigbird.summarization import run_summarization
import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds
import tensorflow_text as tft
from tqdm import tqdm
import sys

# Register a throwaway string flag named "f" (once) before parsing argv --
# presumably to absorb the `-f <connection file>` argument a notebook kernel
# is launched with; TODO confirm.
FLAGS = flags.FLAGS
if not hasattr(FLAGS, "f"):
  flags.DEFINE_string("f", "", "")
FLAGS(sys.argv)

# Turn on TensorFlow v2 behaviour before any model or dataset is created.
tf.enable_v2_behavior()

## Set options

In [9]:
# Option overrides for this run, applied to FLAGS in the order listed.
flag_overrides = {
    "data_dir": "tfds://scientific_papers/pubmed",
    "attention_type": "block_sparse",
    "couple_encoder_decoder": True,
    # on free colab only a lower-memory GPU like the T4 is available
    "max_encoder_length": 2048,
    "max_decoder_length": 256,
    "block_size": 64,
    "learning_rate": 1e-5,
    "num_train_steps": 10000,
    "attention_probs_dropout_prob": 0.0,
    "hidden_dropout_prob": 0.0,
    "vocab_model_file": "gpt2",
}
for flag_name, flag_value in flag_overrides.items():
  setattr(FLAGS, flag_name, flag_value)

In [10]:
# Config mapping produced from the flags module. It is handed to the model
# constructor, and fwd_bwd reads its "label_smoothing" and "vocab_size"
# entries when computing the loss.
transformer_config = flags.as_dictionary()

## Define summarization model

In [11]:
# EagerVariableStore comes from a TensorFlow-internal module
# (tensorflow.python.ops.variable_scope), not the public API.
# `container` is entered via `container.as_default()` around model
# construction and the first fwd_bwd call -- presumably so the variables
# created there are kept and reused; TODO confirm.
from tensorflow.python.ops.variable_scope import EagerVariableStore
container = EagerVariableStore()

In [12]:
# Build the summarization model from the flag-derived config while
# `container` is the default variable store.
with container.as_default():
  model = modeling.TransformerModel(transformer_config)

InternalError: CUDA runtime implicit initialization on GPU:0 failed. Status: all CUDA-capable devices are busy or unavailable

In [None]:
@tf.function(experimental_compile=True)
def fwd_bwd(features, labels):
  """Run one training-mode forward pass of the global `model` plus backprop.

  Args:
    features: Model inputs, passed to `model` as its first argument.
    labels: Target ids; passed to `model` as `target_ids` and to the loss.

  Returns:
    A tuple `(loss, llh, logits, pred_ids, grads)`, where `grads` holds the
    gradients of `loss` with respect to `model.trainable_weights`.
  """
  with tf.GradientTape() as tape:
    outputs, _ = model(features, target_ids=labels, training=True)
    llh, logits, pred_ids = outputs
    # Label smoothing and vocabulary size come from the notebook-level config.
    loss = run_summarization.padded_cross_entropy_loss(
        logits,
        labels,
        transformer_config["label_smoothing"],
        transformer_config["vocab_size"])
  grads = tape.gradient(loss, model.trainable_weights)
  return loss, llh, logits, pred_ids, grads

## Dataset pipeline

In [None]:
# Arguments of the training input pipeline, all taken from FLAGS.
train_input_kwargs = dict(
    data_dir=FLAGS.data_dir,
    vocab_model_file=FLAGS.vocab_model_file,
    max_encoder_length=FLAGS.max_encoder_length,
    max_decoder_length=FLAGS.max_decoder_length,
    substitute_newline=FLAGS.substitute_newline,
    is_training=True,
)
train_input_fn = run_summarization.input_fn_builder(**train_input_kwargs)
# The input function is called with a params mapping; batch size 2 is used here.
dataset = train_input_fn({'batch_size': 2})

In [None]:
# Inspect a few examples.
# NOTE: `ex` stays bound to the last example printed here, and the next code
# cell ("Check outputs") feeds that same `ex` to fwd_bwd.
for ex in dataset.take(3):
  print(ex)

## Check outputs

In [None]:
# Smoke test: one forward/backward pass on the current `ex`, executed with
# `container` as the default variable store.
with container.as_default():
  step_outputs = fwd_bwd(ex[0], ex[1])
loss, llh, logits, pred_ids, grads = step_outputs
print('Loss: ', loss)

## (Optionally) Load pretrained model

In [None]:
# For training from scratch use
# ckpt_path = 'gs://bigbird-transformer/pretrain/bigbr_base/model.ckpt-0'
# For quick check continue from trained checkpoint
ckpt_path = 'gs://bigbird-transformer/summarization/pubmed/roberta/model.ckpt-300000'
ckpt_reader = tf.compat.v1.train.NewCheckpointReader(ckpt_path)
loaded_weights = []
missing_in_ckpt = []
for v in tqdm(model.trainable_weights, position=0):
  # Checkpoint key = variable name minus its last two characters
  # (presumably the ":0" output suffix).
  ckpt_key = v.name[:-2]
  if ckpt_reader.has_tensor(ckpt_key):
    # Read failures are no longer swallowed: a genuine error while reading an
    # existing tensor now surfaces instead of silently keeping the old value.
    val = ckpt_reader.get_tensor(ckpt_key)
  else:
    # Variable is absent from the checkpoint: keep its current value.
    missing_in_ckpt.append(v.name)
    val = v.numpy()
  loaded_weights.append(val)

if missing_in_ckpt:
  print('Kept current values for {} variable(s) not found in the checkpoint: {}'.format(
      len(missing_in_ckpt), missing_in_ckpt))

# NOTE(review): the list is built from model.trainable_weights; set_weights
# presumably expects one value per model weight -- confirm the model has no
# non-trainable weights.
model.set_weights(loaded_weights)

## Train

In [None]:
opt = tf.keras.optimizers.Adam(FLAGS.learning_rate)
train_loss = tf.keras.metrics.Mean(name='train_loss')

# Train for at most FLAGS.num_train_steps batches, tracking the running mean
# of the loss and printing it on every 10th step (starting with step 0).
train_batches = tqdm(dataset.take(FLAGS.num_train_steps), position=0)
for step, ex in enumerate(train_batches):
  loss, llh, logits, pred_ids, grads = fwd_bwd(ex[0], ex[1])
  opt.apply_gradients(zip(grads, model.trainable_weights))
  train_loss(loss)
  if step % 10 != 0:
    continue
  print('Loss = {} '.format(train_loss.result().numpy()))

## Eval

In [None]:
@tf.function(experimental_compile=True)
def fwd_only(features, labels):
  """Run an inference-mode (training=False) forward pass of the global `model`.

  Args:
    features: Model inputs, passed to `model` as its first argument.
    labels: Target ids, passed to `model` as `target_ids`.

  Returns:
    A tuple `(llh, logits, pred_ids)` taken from the first element of the
    model's output; the second element is discarded.
  """
  outputs, _ = model(features, target_ids=labels, training=False)
  llh, logits, pred_ids = outputs
  return llh, logits, pred_ids

In [None]:
# Arguments of the evaluation input pipeline, all taken from FLAGS; the same
# settings as the training pipeline except is_training=False.
eval_input_kwargs = dict(
    data_dir=FLAGS.data_dir,
    vocab_model_file=FLAGS.vocab_model_file,
    max_encoder_length=FLAGS.max_encoder_length,
    max_decoder_length=FLAGS.max_decoder_length,
    substitute_newline=FLAGS.substitute_newline,
    is_training=False,
)
eval_input_fn = run_summarization.input_fn_builder(**eval_input_kwargs)
# The input function is called with a params mapping; batch size 2 is used here.
eval_dataset = eval_input_fn({'batch_size': 2})

In [None]:
eval_llh = tf.keras.metrics.Mean(name='eval_llh')

# Average `llh` over every batch of the eval dataset. `ex` is left bound to
# the final batch, which the prediction cells further down reuse.
eval_batches = tqdm(eval_dataset, position=0)
for ex in eval_batches:
  llh, logits, pred_ids = fwd_only(ex[0], ex[1])
  eval_llh(llh)
mean_llh = eval_llh.result().numpy()
print('Log Likelihood = {}'.format(mean_llh))

### Print predictions

In [None]:
# Load the serialized vocabulary model from FLAGS.vocab_model_file; the
# `with` block closes the file handle as soon as the bytes have been read.
# NOTE(review): the options cell sets this flag to "gpt2"; presumably it has
# been resolved to a real file path by this point -- confirm.
with tf.io.gfile.GFile(FLAGS.vocab_model_file, "rb") as vocab_file:
  tokenizer = tft.SentencepieceTokenizer(model=vocab_file.read())

In [None]:
# Re-run inference on `ex` (the batch left bound by the most recent loop over
# the data) and keep only the predicted ids; the first two outputs are dropped.
_, _, pred_ids = fwd_only(ex[0], ex[1])

In [None]:
# Detokenize the inputs, the model predictions and the reference summaries of
# the current batch, then print them together.
article_text = tokenizer.detokenize(ex[0])
predicted_text = tokenizer.detokenize(pred_ids)
reference_text = tokenizer.detokenize(ex[1])
print('Article:\n {}\n\n Predicted summary:\n {}\n\n Ground truth summary:\n {}\n\n'.format(
    article_text, predicted_text, reference_text))

Article:
 [b'although injuries to the flexor tendons in the zone ii region look trivial , sustained commitment of the patient , the surgeon and the therapist is necessary to get a reasonable functional outcome . as our institute \xe2\x81\x87 is situated in an industrial corridor of the city , most of our patients are manual workers with poor compliance . conforming to the established practice , we tried various early mobilization protocols after zone ii flexor tendon repair . \xe2\x81\x87 as these protocols demanded a level of understanding and a degree of dedication from the patients , our results were suboptimal , with a high incidence of proximal interphalangeal joint ( pip ) joint flexion contractures and tendon ruptures . \xe2\x81\x87 hence , in our institute immobilization became the norm for patients who were not expected to be compliant . \xe2\x81\x87 to improve the results in such patients , we thought of implementing a new rehabilitation protocol that could entirely be under 