TRAINING

This notebook fine-tunes T5 to summarize BillSum documents. It loads the train / validation data (created in the load-data notebook) and then follows the low-RAM guidelines from the class notebook so the model can be trained without crashing Colab. The model weights are saved to Google Drive so they can be loaded later for inference.

In [None]:
# Mount Google Drive inside the Colab runtime at /content/drive.
# Later cells read the BillSum CSVs and write model checkpoints under
# "drive/MyDrive/266project/"; those are relative paths, so they presumably
# rely on the working directory being /content — TODO confirm.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!pip install datasets
!pip install transformers
!pip install sentencepiece

Installing collected packages: sentencepiece
Successfully installed sentencepiece-0.1.99


In [None]:
import os
import re
import numpy as np
import pandas as pd

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

from transformers import AutoTokenizer  , TFAutoModel
from transformers import T5Tokenizer, TFT5ForConditionalGeneration
from pprint import pprint

In [None]:
# Locations of the train / validation splits on the mounted Drive.
# The paths are kept in variables because the batch generator re-reads
# these files from disk later on.
train_file = "drive/MyDrive/266project/billsum_train.csv"
valid_file = "drive/MyDrive/266project/billsum_valid.csv"

# Read both splits into DataFrames (train first, then validation).
train_data, valid_data = (pd.read_csv(csv_path) for csv_path in (train_file, valid_file))

In [None]:
# Remove the stray 'Unnamed: 0' column (presumably the index that was written
# out when the CSVs were created — TODO confirm).
# DataFrame.drop returns a new frame, so the result must be assigned back for
# the drop to have any effect. errors='ignore' keeps this cell safe to re-run
# once the column is already gone.
train_data = train_data.drop(columns='Unnamed: 0', errors='ignore')
valid_data = valid_data.drop(columns='Unnamed: 0', errors='ignore')

# Preview the cleaned validation frame.
valid_data.head()


Unnamed: 0,orig,target,title
0,SECTION 1. ENVIRONMENTAL INFRASTRUCTURE.\n\n ...,Amends the Water Resources Development Act of ...,To make technical corrections to the Water Res...
1,That this Act may be cited as the ``Federal Fo...,Federal Forage Fee Act of 1993 - Subjects graz...,Federal Forage Fee Act of 1993
2,SECTION 1. SHORT TITLE.\n\n This Act may be...,. Merchant Marine of World War II Congression...,Merchant Marine of World War II Congressional ...
3,SECTION 1. SHORT TITLE.\n\n This Act may be...,Small Business Modernization Act of 2004 - Ame...,To amend the Internal Revenue Code of 1986 to ...
4,SECTION 1. SHORT TITLE.\n\n This Act may be...,Fair Access to Investment Research Act of 2016...,Fair Access to Investment Research Act of 2016
...,...,...,...
3264,SECTION 1. PLACEMENT PROGRAMS FOR FEDERAL EMPL...,Public Servant Priority Placement Act of 1995 ...,Public Servant Priority Placement Act of 1995
3265,SECTION 1. SHORT TITLE.\n\n This Act may be...,Sportsmanship in Hunting Act of 2008 - Amends ...,"A bill to amend title 18, United States Code, ..."
3266,SECTION 1. SHORT TITLE.\n\n This Act may be...,Helping College Students Cross the Finish Line...,Helping College Students Cross the Finish Line...
3267,SECTION 1. SHORT TITLE.\n\n This Act may be...,Makes proceeds from such conveyances available...,Texas National Forests Improvement Act of 2000


In [None]:
# Preprocessing settings and the batch tokenization helper.
# max_length: token length every source and target sequence is padded/truncated to.
# prefix: text prepended to every source document before tokenization.
max_length = 168
prefix = 'summarize: '


def preprocess_data(text_pairs, tokenizer, model, max_length=max_length):
    """Tokenize a batch of (orig, target) text pairs into model inputs and labels.

    Args:
        text_pairs: iterable of 2-item (orig, target) string pairs.
        tokenizer: object exposing ``batch_encode_plus``; it is asked for
            ``return_tensors='tf'`` and fixed-length padding/truncation.
        model: object exposing ``_shift_right``, used to derive the decoder
            inputs from the label ids.
        max_length: length both sources and targets are padded/truncated to.
            The default is the module-level ``max_length`` captured when this
            function was defined.

    Returns:
        A ``(features, label_ids)`` tuple where ``features`` is the list
        ``[input_ids, attention_mask, decoder_input_ids]``. ``input_ids`` and
        ``attention_mask`` are int32 NumPy arrays; ``label_ids`` is a NumPy
        array of the target token ids; ``decoder_input_ids`` is whatever
        ``model._shift_right(label_ids)`` returns.
    """
    # Encoding options shared by the source and the target side.
    encode_options = dict(
        max_length=max_length,
        padding='max_length',
        truncation=True,
        return_tensors='tf',
    )

    # Source side: prefixed documents; the attention mask is requested as well.
    sources = [prefix + orig for orig, _ in text_pairs]
    source_enc = tokenizer.batch_encode_plus(
        sources, return_attention_mask=True, **encode_options
    )
    encoder_ids = np.array(source_enc["input_ids"], dtype="int32")
    encoder_mask = np.array(source_enc["attention_mask"], dtype="int32")

    # Target side: reference summaries, padded to the same fixed length.
    # NOTE(review): padded positions remain in the labels as-is.
    targets = [target for _, target in text_pairs]
    target_enc = tokenizer.batch_encode_plus(targets, **encode_options)
    label_ids = np.array(target_enc['input_ids'])

    # Decoder inputs are the labels shifted one position to the right.
    features = [encoder_ids, encoder_mask, model._shift_right(label_ids)]
    return features, label_ids

####################


In [None]:

class DataGenerator(tf.keras.utils.Sequence):
    """Keras ``Sequence`` that reads and tokenizes one batch of the CSV at a time.

    Instead of holding the tokenized dataset in memory, each ``__getitem__``
    call re-reads ``data_filename`` with every row outside the requested batch
    skipped, then runs ``preprocess_data`` on the rows that remain.

    Args:
        tokenizer: tokenizer forwarded to ``preprocess_data``.
        model: model forwarded to ``preprocess_data`` (used for ``_shift_right``).
        n_examples: number of data rows in the CSV (header excluded).
        data_filename: path of a CSV file with ``orig`` and ``target`` columns.
        max_length: padded/truncated token length of each sequence.
        batch_size: number of rows per batch.
        shuffle: if True, the row order is re-permuted at construction time
            and at the end of every epoch.

    Note:
        ``__len__`` uses floor division, so when ``n_examples`` is not a
        multiple of ``batch_size`` the trailing partial batch is never served.
    """

    def __init__(self,
                 tokenizer,
                 model,
                 n_examples,
                 data_filename,
                 max_length=128,
                 batch_size=16,
                 shuffle=True):
        # Run the base-class initializer (newer Keras versions expect this;
        # it is a no-op where the base class defines no __init__).
        super().__init__()

        self.tokenizer = tokenizer
        self.model = model
        self.n_examples = n_examples
        self.data_filename = data_filename
        self.max_length = max_length
        self.batch_size = batch_size
        self.shuffle = shuffle

        # File line numbers of the data rows. Numbering starts at 1 so that
        # line 0 (the CSV header) is never placed in a skip list.
        # Kept as a plain Python list in every code path: `+` on lists
        # concatenates, whereas `+` on NumPy arrays adds element-wise, which
        # previously broke __getitem__ whenever shuffle=False.
        self.row_order = list(range(1, self.n_examples + 1))
        self.on_epoch_end()

    def __len__(self):
        """Return the number of full batches in the dataset."""
        return self.n_examples // self.batch_size

    def __getitem__(self, idx):
        """Load and preprocess batch number ``idx``.

        Returns:
            The ``(inputs, labels)`` pair produced by ``preprocess_data``.
        """
        batch_start = idx * self.batch_size
        batch_end = (idx + 1) * self.batch_size

        # Rows to skip are the ones in the (possibly shuffled) row order
        # before and after the chunk used for this batch. list() guards
        # against row_order having been replaced by an array externally.
        order = list(self.row_order)
        batch_idx_skip = order[:batch_start] + order[batch_end:]
        df = pd.read_csv(self.data_filename, skiprows=batch_idx_skip)

        text_pairs = df[['orig', 'target']].values.astype(str).tolist()

        batch_data = preprocess_data(
            text_pairs,
            self.tokenizer,
            self.model,
            self.max_length
        )

        return batch_data

    def __call__(self):
        """Yield every batch once, reshuffling after the last one."""
        n_batches = len(self)
        for i in range(n_batches):
            yield self[i]

            if i == n_batches - 1:
                self.on_epoch_end()

    def on_epoch_end(self):
        """Re-permute the row order when shuffling is enabled."""
        if self.shuffle:
            self.row_order = np.random.permutation(self.row_order).tolist()


In [None]:
# Load the pretrained 't5-base' tokenizer and the TensorFlow T5
# conditional-generation model that will be fine-tuned below.
model_name = 't5-base'
tokenizer = AutoTokenizer.from_pretrained(model_name)
t5_model = TFT5ForConditionalGeneration.from_pretrained(model_name)

For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.
All PyTorch model weights were used when initializing TFT5ForConditionalGeneration.

All the weights of TFT5ForConditionalGeneration were initialized from the PyTorch model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFT5ForConditionalGeneration for predictions without further training.


In [None]:
batch_size = 8

# Settings that are identical for the train and validation generators.
shared_generator_args = dict(
    tokenizer=tokenizer,
    model=t5_model,
    max_length=max_length,
    batch_size=batch_size,
)

# Each generator streams batches straight from its CSV file; the in-memory
# DataFrames are only used to count the rows.
train_data_generator = DataGenerator(
    n_examples=len(train_data),
    data_filename=train_file,
    **shared_generator_args
)

valid_data_generator = DataGenerator(
    n_examples=len(valid_data),
    data_filename=valid_file,
    **shared_generator_args
)

In [None]:
def build_t5_training_wrapper_model(my_t5, max_length):
    """Wrap a TF T5 model in a compiled Keras functional model for ``fit``.

    Args:
        my_t5: callable T5 model that accepts ``input_ids`` positionally plus
            ``attention_mask`` and ``decoder_input_ids`` keywords, and whose
            first output is the logits tensor.
        max_length: fixed sequence length of all three integer inputs.

    Returns:
        A compiled ``tf.keras.Model`` that takes
        ``[input_ids, attention_mask, decoder_input_ids]`` and outputs the T5
        logits, trained with Adam (default settings) and sparse categorical
        cross-entropy computed from logits.
    """
    # `shape` must be a tuple: `(max_length)` is just an int, which newer
    # Keras versions reject, so the one-element tuple is spelled out.
    sequence_shape = (max_length,)

    input_ids = layers.Input(shape=sequence_shape, dtype=tf.int32, name='input_ids')
    attention_mask = layers.Input(shape=sequence_shape, dtype=tf.int32, name='attention_mask')
    # NOTE(review): this layer is named 'labels' but it carries the decoder
    # input ids; the name is left unchanged to keep the model identical.
    decoder_input_ids = layers.Input(shape=sequence_shape, dtype=tf.int32, name='labels')

    # Index 0 of the T5 output is used as the logits.
    t5_logits = my_t5(input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids)[0]

    model = tf.keras.models.Model(inputs=[input_ids, attention_mask, decoder_input_ids],
                                  outputs=[t5_logits])
    model.compile(optimizer=tf.keras.optimizers.Adam(),
                  loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True))

    return model

In [None]:
# Build the compiled Keras wrapper around the pretrained T5 model, using the
# same max_length that the data generators pad/truncate to.
model_wrapper = build_t5_training_wrapper_model(t5_model, max_length)

In [None]:
# Checkpoints go to Google Drive so they outlive the Colab runtime.
checkpoint_dir = 'drive/MyDrive/266project/model_checkpoints/'
# '{epoch:02d}' is a template that the checkpoint callback fills in with the
# zero-padded epoch number, giving one weights file per epoch.
checkpoint_filepath = checkpoint_dir + 't5_billsum_weights.{epoch:02d}.hdf5'
# save_weights_only=True stores just the weights, not the full model.
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_weights_only=True)

In [None]:
# Fine-tune for 3 epochs, streaming batches from the generators and
# evaluating on the validation generator; the checkpoint callback writes the
# weights to Drive during training.
model_wrapper.fit(x = train_data_generator,
                  validation_data=valid_data_generator,
                  epochs=3,
                  callbacks=[model_checkpoint_callback])

In [None]:
# Sanity check: confirm the saved weights can be loaded back into the wrapper.
# The '03' suffix matches the '{epoch:02d}' checkpoint template after the
# 3 training epochs above; update it if the number of epochs changes.
model_wrapper.load_weights(checkpoint_dir + 't5_billsum_weights.03.hdf5')