## BigBird Finetuning

In [None]:
%%capture
!git clone https://abhinav-bohra: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@github.com/abhinav-bohra/BigBird.git
%cd BigBird
!pip install git+https://github.com/google-research/bigbird.git -q

## Set options

In [None]:
from bigbird.core import flags
from bigbird.core import modeling
from bigbird.core import utils
from bigbird.summarization import run_summarization
import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds
import tensorflow_text as tft
from tqdm import tqdm
import sys

FLAGS = flags.FLAGS

# Register a dummy string flag named "f" unless one already exists, so that
# re-running this cell does not define it twice.
# NOTE(review): presumably this absorbs the `-f <connection file>` argument
# that Jupyter passes to the kernel in sys.argv — confirm.
if not hasattr(FLAGS, "f"):
  flags.DEFINE_string("f", "", "")

# Parse the command line so that flag values can be read and overridden below.
FLAGS(sys.argv)

tf.enable_v2_behavior()

In [None]:
# --- Data and vocabulary ---
FLAGS.data_dir = "pubmed"
FLAGS.vocab_model_file = "gpt2"

# --- Attention and sequence lengths ---
FLAGS.attention_type = "block_sparse"
FLAGS.block_size = 64
FLAGS.max_encoder_length = 3072  # reduce for quicker demo on free colab
FLAGS.max_decoder_length = 256
FLAGS.couple_encoder_decoder = True

# --- Regularisation (dropout disabled) ---
FLAGS.attention_probs_dropout_prob = 0.0
FLAGS.hidden_dropout_prob = 0.0

# --- Optimisation ---
FLAGS.learning_rate = 1e-5
FLAGS.num_train_steps = 10000
FLAGS.use_gradient_checkpointing = True

In [None]:
# Collect the current flag values into a mapping. It is passed to the model
# constructor and indexed by string keys ("label_smoothing", "vocab_size")
# in the cells below.
transformer_config = flags.as_dictionary()

## Summarization model

In [None]:
# NOTE(review): `tensorflow.python.*` is an internal module path rather than
# part of the public `tf.*` API, so this import may break between TensorFlow
# releases — confirm against the installed version.
from tensorflow.python.ops.variable_scope import EagerVariableStore
# The model is constructed inside this store's `as_default()` context in the
# next cell.
container = EagerVariableStore()

In [None]:
# Build the transformer from the flag-derived config while `container` is the
# default variable store.
with container.as_default():
  model = modeling.TransformerModel(transformer_config)

In [None]:
@tf.function(experimental_compile=True)
def fwd_bwd(features, labels):
  """Run one forward and backward pass of the global `model` on a batch.

  Args:
    features: Batch passed to `model` as its input.
    labels: Batch passed to `model` as `target_ids` and used as the targets
      of the padded cross-entropy loss.

  Returns:
    A tuple `(loss, llh, logits, pred_ids, grads)`, where `grads` holds the
    gradients of `loss` with respect to `model.trainable_weights`, in the
    same order as that list.
  """
  with tf.GradientTape() as tape:
    model_outputs, _ = model(features, target_ids=labels, training=True)
    llh, logits, pred_ids = model_outputs
    loss = run_summarization.padded_cross_entropy_loss(
        logits,
        labels,
        transformer_config["label_smoothing"],
        transformer_config["vocab_size"])
  # The tape has recorded the forward pass above; differentiate once we are
  # outside of its context.
  grads = tape.gradient(loss, model.trainable_weights)
  return loss, llh, logits, pred_ids, grads

## Dataset pipeline

In [None]:
# Build the training input function from the flag values, then create the
# training dataset with a batch size of 8.
train_input_fn = run_summarization.input_fn_builder(
    is_training=True,
    data_dir=FLAGS.data_dir,
    substitute_newline=FLAGS.substitute_newline,
    vocab_model_file=FLAGS.vocab_model_file,
    max_encoder_length=FLAGS.max_encoder_length,
    max_decoder_length=FLAGS.max_decoder_length,
)
dataset = train_input_fn(dict(batch_size=8))

In [None]:
# Display the dataset repr to check the element spec (shapes and dtypes of the
# batched features and labels).
dataset

<PaddedBatchDataset element_spec=(TensorSpec(shape=(8, 3072), dtype=tf.int32, name=None), TensorSpec(shape=(8, 256), dtype=tf.int32, name=None))>

In [None]:
# Inspect one batch from the training dataset (`take(1)` yields a single
# element). `ex` stays defined afterwards and is referenced by later cells.
for ex in dataset.take(1):
  print(ex)

In [None]:
import tensorflow as tf

# `tf.io.tf_record_iterator` does not exist in the TF2 API (the original call
# raised AttributeError). `tf.data.TFRecordDataset` yields the serialized
# records as scalar string tensors, which can be parsed as `tf.train.Example`.
tfrecord_path = "/content/BigBird/pubmed/test.tfrecord-00000-of-00001"

# Only the first record is printed; raise the `take` count to inspect more
# without dumping the entire file into the cell output.
for raw_record in tf.data.TFRecordDataset(tfrecord_path).take(1):
    print(tf.train.Example.FromString(raw_record.numpy()))

AttributeError: ignored

## Check outputs

In [None]:
# loss, llh, logits, pred_ids, grads = fwd_bwd(ex[0], ex[1])
# print('Loss: ', loss)

## Load pretrained model

In [None]:
# Initialise the model from the published pubmed summarization checkpoint.
ckpt_path = 'gs://bigbird-transformer/summarization/pubmed/roberta/model.ckpt-300000'
ckpt_reader = tf.compat.v1.train.NewCheckpointReader(ckpt_path)

loaded_weights = []
missing_in_ckpt = []
for v in tqdm(model.trainable_weights, position=0):
  # Drop the last two characters of the variable name to get the checkpoint
  # key (presumably the ":0" output suffix — TODO confirm).
  ckpt_key = v.name[:-2]
  if ckpt_reader.has_tensor(ckpt_key):
    val = ckpt_reader.get_tensor(ckpt_key)
  else:
    # Keep the variable's current value, but remember the name so that the
    # fallback is visible. The previous bare `except:` hid this case as well
    # as every unrelated error (e.g. I/O failures while reading the
    # checkpoint), silently leaving weights at their initial values.
    missing_in_ckpt.append(ckpt_key)
    val = v.numpy()
  loaded_weights.append(val)

if missing_in_ckpt:
  print('{} of {} variables were not found in the checkpoint and keep their '
        'current values: {}'.format(
            len(missing_in_ckpt), len(loaded_weights), missing_in_ckpt))

model.set_weights(loaded_weights)

100%|██████████| 316/316 [02:25<00:00,  2.17it/s]


## Train (Fine-tune)

In [None]:
# Adam optimizer at the configured learning rate, plus a running mean of the
# training loss over all steps seen so far.
opt = tf.keras.optimizers.Adam(FLAGS.learning_rate)
train_loss = tf.keras.metrics.Mean(name='train_loss')

train_batches = tqdm(dataset.take(FLAGS.num_train_steps), position=0)
for i, ex in enumerate(train_batches):
  loss, llh, logits, pred_ids, grads = fwd_bwd(ex[0], ex[1])
  # Gradients come back in the same order as `model.trainable_weights`.
  opt.apply_gradients(zip(grads, model.trainable_weights))
  train_loss(loss)
  # Report the running mean every 10th step only.
  if i % 10 != 0:
    continue
  print('Loss = {} '.format(train_loss.result().numpy()))

  0%|          | 0/10000 [00:00<?, ?it/s]


InvalidArgumentError: ignored

## Eval

In [None]:
@tf.function(experimental_compile=True)
def fwd_only(features, labels):
  """Run a forward pass of the global `model` in inference mode.

  Args:
    features: Batch passed to `model` as its input.
    labels: Batch passed to `model` as `target_ids`.

  Returns:
    A tuple `(llh, logits, pred_ids)` taken from the model's first output.
  """
  model_outputs, _ = model(features, target_ids=labels, training=False)
  llh, logits, pred_ids = model_outputs
  return llh, logits, pred_ids

In [None]:
# Same builder arguments as the training pipeline, but with is_training=False;
# the evaluation dataset also uses a batch size of 8.
eval_input_fn = run_summarization.input_fn_builder(
    is_training=False,
    data_dir=FLAGS.data_dir,
    substitute_newline=FLAGS.substitute_newline,
    vocab_model_file=FLAGS.vocab_model_file,
    max_encoder_length=FLAGS.max_encoder_length,
    max_decoder_length=FLAGS.max_decoder_length,
)
eval_dataset = eval_input_fn(dict(batch_size=8))

In [None]:
# Running mean of the log-likelihood over every evaluation batch.
eval_llh = tf.keras.metrics.Mean(name='eval_llh')

eval_batches = tqdm(eval_dataset, position=0)
# `ex` is deliberately left bound to the last batch: the prediction cells
# further down read it.
for ex in eval_batches:
  llh, logits, pred_ids = fwd_only(ex[0], ex[1])
  eval_llh(llh)

mean_llh = eval_llh.result().numpy()
print('Log Likelihood = {}'.format(mean_llh))

### Print predictions

In [None]:
# Read the serialized SentencePiece model through a context manager so the
# file handle is closed once the bytes are loaded (the original left the
# GFile open).
with tf.io.gfile.GFile(FLAGS.vocab_model_file, "rb") as vocab_file:
  tokenizer = tft.SentencepieceTokenizer(model=vocab_file.read())

In [None]:
# Re-run the forward pass on `ex` and keep only the predicted token ids.
# When the cells are run in order, `ex` is the last batch of the eval loop.
_, _, pred_ids = fwd_only(ex[0], ex[1])

In [None]:
# Turn token ids back into text for the input article, the model's prediction
# and the reference summary, then print them together.
article_text = tokenizer.detokenize(ex[0])
predicted_summary = tokenizer.detokenize(pred_ids)
reference_summary = tokenizer.detokenize(ex[1])

print('Article:\n {}\n\n Predicted summary:\n {}\n\n Ground truth summary:\n {}\n\n'.format(
    article_text, predicted_summary, reference_summary))

## Prepare Data for BigBird

In [None]:
import os, json

def preprocessArticles(a_):
    """Clean the raw lines of an article.

    Lines that start with "Speaker ::" and lines consisting of a single
    newline are dropped. In the remaining lines every newline character is
    removed and each non-overlapping pair of spaces is replaced by one space
    (a single pass, so longer runs of spaces are shortened but not fully
    collapsed).

    Args:
        a_: Iterable of raw lines, e.g. the result of `readlines()`.

    Returns:
        List of cleaned lines in their original order.
    """
    return [
        raw_line.replace('\n', "").replace("  ", " ")
        for raw_line in a_
        if not (raw_line.startswith("Speaker ::") or raw_line == "\n")
    ]

def preprocessSummaries(s_):
    """Clean the raw lines of a summary and wrap each one in sentence tags.

    Lines that start with "Speaker ::" and lines consisting of a single
    newline are dropped. In the remaining lines every newline character is
    removed and each non-overlapping pair of spaces is replaced by one space
    (a single pass); the result is then wrapped as "<S> ... </S>".

    Args:
        s_: Iterable of raw lines, e.g. the result of `readlines()`.

    Returns:
        List of tagged summary sentences in their original order.
    """
    kept_lines = (
        raw_line.replace('\n', "").replace("  ", " ")
        for raw_line in s_
        if not raw_line.startswith("Speaker ::") and raw_line != "\n"
    )
    return ["<S> " + str(sentence) + " </S>" for sentence in kept_lines]


# Convert each split of the reuters data into one JSON-lines text file with
# the fields "article_id", "article_text", "abstract_text", "section_names"
# and "labels".
splits = ["train", "test", "val"]

for split in splits:
  path_articles = f"/content/BigBird/data/reuters/original/{split}/ects"
  path_summaries = f"/content/BigBird/data/reuters/original/{split}/gt_summaries"
  # The source folder is named "val", but the output file is "validation".
  out_split = "validation" if split == "val" else split
  outfile = f"/content/BigBird/data/reuters/bigbird/{out_split}.txt"
  #outfile = f"/content/BigBird/pubmed/{out_split}.txt"
  #outfile = f"/root/tensorflow_datasets/downloads/extracted/ZIP.ucid_1lvsqvsFi3W-pE1SqNZI0s8NR_export_download1CQHRyal4p4gv4NAVf5-_pD4o3vOCitRLkq35IcBPAQ/pubmed-dataset/{out_split}.txt"
  os.makedirs(os.path.dirname(outfile), exist_ok=True)

  # Sorted so the record order is the same on every run (os.listdir returns
  # entries in arbitrary order).
  articles = sorted(os.listdir(path_articles))
  summaries = os.listdir(path_summaries)
  print(out_split, len(articles), len(summaries))

  # Start from an empty list for every split. The list used to be created
  # once before the loop, so test.txt also contained every train record and
  # validation.txt contained the train and test records as well.
  data = []
  for article in articles:
    # The summary file has the same file name as its article.
    with open(os.path.join(path_articles, article), 'r', encoding='utf-8') as article_file:
      a = preprocessArticles(article_file.readlines())
    with open(os.path.join(path_summaries, article), 'r', encoding='utf-8') as summary_file:
      s = preprocessSummaries(summary_file.readlines())
    data_point = {
        "article_id": article,
        "article_text": a,
        "abstract_text": s,
        "section_names": "null",
        "labels": "null",
    }
    data.append(json.dumps(data_point))

  # One JSON object per line; the `with` block closes the file.
  with open(outfile, mode='wt', encoding='utf-8') as myfile:
    myfile.write('\n'.join(data))
  print(f"{out_split} done")

train 1350 1350
train done
test 482 482
test done
validation 150 150
validation done


In [None]:
# %cd /content/pubmed
# !tfds build

In [None]:
# !cp /root/tensorflow_datasets/scientific_papers/pubmed/1.1.1/* /content/BigBird/pubmed/

In [None]:
# Stage the changes, set the committer identity, then commit and push.
# NOTE(review): `git add .` stages every new or modified file under the
# current directory — confirm that nothing unintended (large artefacts,
# credentials) gets committed.
!git add .
!git config --global user.email "abhinavbohra@iitkgp.ac.in"
!git config --global user.name "abhinav-bohra"
!git commit -m "Added reuters-bigbird dataset"
!git push

In [None]:
# Fetch and integrate the latest changes from the remote into this checkout.
!git pull