### Pegasus

#### Further pre-train Pegasus using Gap Sentences Generation (GSG) with process data
- Experiment:
    - Pegasus is further pre-trained on *process data* with its original pre-training objective, GSG, in which the model learns to regenerate sentences that were masked out of the input document
- Process data:
    - Document (masked process text) and summary (masked sentences)
- Input model (one of the following):
    - Further pre-trained Pegasus (Pegasus-TML)
    - Pre-trained Pegasus loaded from Hugging Face
- Outline:
    - Track the experiment and its results with WandB (Weights & Biases)
    - Define the experiment, data loading, training and validation 
    - Validation loss is tracked to apply early stopping and prevent overfitting
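
As a rough illustration of the data format described above, the sketch below builds one `document`/`summary` pair by replacing chosen sentences with a sentence-level mask token. The helper name `make_gsg_pair`, the example sentences, and the choice of which sentence to mask are all hypothetical; the notebook loads pairs that were prepared beforehand, and how their masked sentences were selected is not shown here. The mask string `<mask_1>` is assumed to match the sentence mask token of the Pegasus tokenizer and should be checked against the tokenizer in use.

```python
# Sentence-level mask token (assumed to match the Pegasus tokenizer; verify before use)
MASK_SENT = "<mask_1>"

def make_gsg_pair(sentences, masked_idx):
    """Build a (document, summary) pair by masking the selected sentences."""
    masked = set(masked_idx)
    # the document keeps unmasked sentences and puts a mask token in place of masked ones
    document = " ".join(
        MASK_SENT if i in masked else s for i, s in enumerate(sentences)
    )
    # the summary (training target) is the concatenation of the masked sentences
    summary = " ".join(sentences[i] for i in sorted(masked))
    return {"document": document, "summary": summary}

# hypothetical process description, split into sentences
sentences = [
    "The customer submits an order.",
    "The clerk checks the stock.",
    "The warehouse ships the goods.",
]
pair = make_gsg_pair(sentences, masked_idx=[1])
print(pair["document"])
print(pair["summary"])
```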
    
#### Reference
- Pegasus Hugging Face: 
https://huggingface.co/docs/transformers/model_doc/pegasus
- Hugging Face Fine-tuning Transformer tutorial:
https://huggingface.co/docs/transformers/training
- WandB pipeline:
https://colab.research.google.com/github/wandb/examples/blob/master/colabs/pytorch/Simple_PyTorch_Integration.ipynb#scrollTo=FH61NWlVR_SL
- Early stopping:
https://wandb.ai/ayush-thakur/huggingface/reports/Early-Stopping-in-HuggingFace-Examples--Vmlldzo0MzE2MTM

#### Environment Setup 
- Google Colab
- Amazon SageMaker Studio (Kernel: Python 3 Data Science)

In [1]:
# %%capture
# !pip3 install torch==1.10.1+cu113 torchvision==0.11.2+cu113 torchaudio==0.10.1+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html
# !pip3 install transformers
# !pip3 install sentencepiece
# !pip3 install wandb --upgrade

#### Import Libraries

In [2]:
import json
import random
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import PegasusForConditionalGeneration, PegasusTokenizerFast
from transformers.optimization import Adafactor
from tqdm.auto import tqdm

import wandb


#### WandB Setup

In [3]:
# Ensure deterministic behavior
# Use a fixed integer seed: hash() of a string differs between Python processes
# unless PYTHONHASHSEED is set, so hash-derived seeds are not repeatable.
seed = 42  # arbitrary fixed value
torch.backends.cudnn.deterministic = True
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

In [4]:
# Device configuration
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [5]:
# set up the project in your WandB account first
wandb.login()


True

#### Define the Experiment and Pipeline 

In [6]:
# Define the configuration of the experiment
config = dict(
    epochs = 10,
    batch_size = 2,
    optimizer = "adafactor",
    es_patience = 5, # early stopping patience: number of consecutive epochs without validation improvement before stopping
    loss_function = "maskedSent-loss", # Masked Sentences loss from Gap Sentences Generation (GSG) learning
    dataset = "bpmai-29-10-2019",
    architecture = "seq2seq-pegasus",
    retrain = False, # True to continue training from a checkpoint of a previous iteration
    input_model = "", # path of the input checkpoint if continuing training; otherwise leave blank
    output_model= ""  # path template for saved checkpoints, e.g. "./model_maskedSent/maskedSent_{}_epoch.pth" ("{}" receives the epoch number)
)
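
The `output_model` entry is used as a template that the training loop fills in with the epoch number via `str.format`. A small sketch of how the per-epoch checkpoint paths come out, using the example path from the comment above and an assumed epoch count of 3:

```python
# Example template from the config comment; "{}" receives the 1-based epoch number
output_model = "./model_maskedSent/maskedSent_{}_epoch.pth"
epochs = 3  # assumed value, for illustration only

# one checkpoint path per epoch, numbered from 1
paths = [output_model.format(epoch + 1) for epoch in range(epochs)]
for path in paths:
    print(path)
```

Note that the target directory has to exist before the first checkpoint is written, and an empty `output_model` string gives the save step no valid path.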


##### Track metadata and hyperparameters with wandb.init

In [7]:
# Define the training pipeline
def model_pipeline(hyperparameters):
    with wandb.init(project="wandb-project-name", entity="wandb-entity-name", config=hyperparameters):
        config = wandb.config
        # set model, data loaders, optimizer, and early stopping with defined config
        model, train_loader, val_loader, optimizer = make(config)
        es = EarlyStopping(patience = config.es_patience)
        # train and validate with early stopping applied
        train_and_val(model, train_loader, val_loader, optimizer, es, config)

    return model

##### Set model, data loaders and optimizer with defined configuration

In [8]:
def make(config):
    # set pretrained tokenizer, model and optimizer
    model_name = 'google/pegasus-large' # 'google/pegasus-xsum'
    tokenizer = PegasusTokenizerFast.from_pretrained(model_name)
    model = PegasusForConditionalGeneration.from_pretrained(model_name, return_dict=True)
    optimizer = Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
    
    # move the model to the device first so that restored optimizer state ends up on the same device
    model.to(device)

    # if continuing training from a checkpoint of a previous iteration
    if config.retrain:
        # restore model weights and optimizer state; re-creating the model after
        # loading would discard the restored weights
        load(model, optimizer, config.input_model)
        if torch.cuda.device_count() > 1:
            print("Let's use", torch.cuda.device_count(), "GPUs!")
            model = nn.DataParallel(model)
    
    # set data loaders
    train_loader = make_loader(train_data, tokenizer, shuffle=True, batch_size=config.batch_size)
    val_loader = make_loader(val_data, tokenizer, shuffle=False, batch_size=config.batch_size) # validation does not need shuffling
    # test print data
    for batch in train_loader:
        break
    print({k: v.shape for k, v in batch.items()})
    
    return model, train_loader, val_loader, optimizer

#### Define Data Loading
##### Load data

In [9]:
with open('./data/masked_sent_train.json', 'r') as f:
    train_data = json.load(f)
with open('./data/masked_sent_val.json', 'r') as f:
    val_data = json.load(f)

##### Define Process Dataset

In [10]:
class ProcessDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels
        
    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels['input_ids'][idx])
        return item # input_ids, attention_mask, labels
    
    def __len__(self):
        return len(self.labels['input_ids'])

##### Define Process Data Loader

In [11]:
def make_loader(dataset, tokenizer, shuffle, batch_size):
    texts = [x.lower() for x in dataset['document']]
    labels = [x.lower() for x in dataset['summary']]
    process_dataset = process_data(texts, labels, tokenizer)
    process_dataloader = DataLoader(
        dataset=process_dataset, shuffle=shuffle, batch_size=batch_size
    )
    return process_dataloader

In [12]:
# Define function needed to tokenize texts and labels
def process_data(texts, labels, tokenizer):
    encodings = tokenizer(texts, truncation=True, padding=True)
    decodings = tokenizer(labels, truncation=True, padding=True)
    process_dataset = ProcessDataset(encodings, decodings)
    return process_dataset
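
Because the labels are tokenized with `padding=True`, the label sequences contain pad token ids, and those positions count toward the loss unless they are replaced with `-100`, the index that the Hugging Face seq2seq loss ignores. The following is a minimal sketch of that replacement on plain lists of ids; it is not applied in this notebook, the helper name `mask_padding` is hypothetical, and the token ids are toy values (in practice `tokenizer.pad_token_id` would be passed in).

```python
IGNORE_INDEX = -100  # label value ignored by the cross-entropy loss

def mask_padding(label_ids, pad_token_id):
    """Replace pad token ids in each label sequence with IGNORE_INDEX."""
    return [
        [tok if tok != pad_token_id else IGNORE_INDEX for tok in seq]
        for seq in label_ids
    ]

# toy padded label ids (0 stands in for the pad token id)
padded = [[57, 12, 1, 0, 0], [8, 9, 10, 11, 1]]
masked = mask_padding(padded, pad_token_id=0)
print(masked)
```

Applied to `decodings['input_ids']` before the dataset is built, this would change the reported loss values, since padded positions would no longer contribute to them.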

#### Define Early Stopping 


In [13]:
class EarlyStopping(object):
    def __init__(self, mode='min', min_delta=0, patience=10, percentage=False):
        self.mode = mode
        self.min_delta = min_delta
        self.patience = patience
        self.best = None
        self.num_bad_epochs = 0
        self.is_better = None
        self._init_is_better(mode, min_delta, percentage)

        if patience == 0:
            self.is_better = lambda a, b: True
            self.step = lambda a: False

    def step(self, metrics):
        if self.best is None:
            self.best = metrics
            return False

        if torch.isnan(metrics):
            return True

        if self.is_better(metrics, self.best):
            self.num_bad_epochs = 0
            self.best = metrics
        else:
            self.num_bad_epochs += 1

        if self.num_bad_epochs >= self.patience:
            return True

        return False

    def _init_is_better(self, mode, min_delta, percentage):
        if mode not in {'min', 'max'}:
            raise ValueError('mode ' + mode + ' is unknown!')
        if not percentage:
            if mode == 'min':
                self.is_better = lambda a, best: a < best - min_delta
            if mode == 'max':
                self.is_better = lambda a, best: a > best + min_delta
        else:
            if mode == 'min':
                self.is_better = lambda a, best: a < best - (
                            best * min_delta / 100)
            if mode == 'max':
                self.is_better = lambda a, best: a > best + (
                            best * min_delta / 100)

#### Define Training Logic
##### Track gradients and weights with wandb.watch and everything else, e.g. loss, with wandb.log

In [14]:
def train_and_val(model, train_loader, val_loader, optimizer, es, config):
    # track gradients and weights of the model with wandb
    wandb.watch(model, log="all", log_freq=10)

    # run training and track with wandb
    total_batches = len(train_loader) * config.epochs
    print('num_training_steps', total_batches)
    progress_bar = tqdm(range(total_batches))

    batch_ct = 0
    running_loss = 0.
    last_loss = 0.
    model_save_epoch = 0
    for epoch in range(config.epochs):
        model.train()
        for idx, process_batch in enumerate(train_loader):
            loss = train_batch(idx, process_batch, model, optimizer, progress_bar)
            batch_ct += 1
            # accumulate the loss and report its average every 25th batch
            running_loss += loss.item()
            if (batch_ct % 25) == 0:
                last_loss = running_loss / 25 # log loss in average term
                train_log(last_loss, batch_ct, epoch)
                running_loss = 0.

        # validate model after train at each epoch
        model.eval()
        val_loss = val(model, val_loader)
        val_log(val_loss, batch_ct, epoch) # log validation loss
        # save model after train each epoch
        output_model = config.output_model.format(epoch+1)
        save(model, optimizer, output_model)
        # stop early once the validation loss has not improved for es_patience consecutive epochs
        if es.step(val_loss):
            break
            

##### Define functions needed in the training loop

In [15]:
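def train_batch(idx, batch, model, optimizer, progress_bar):
    process_item = {k: v.to(device) for k, v in batch.items()}
    # forward pass
    model_output = model(**process_item)
    loss = model_output.loss
    # backward pass: gradients accumulate across calls until the optimizer steps
    loss.backward()
    # step with the optimizer every 2 batches (gradient accumulation),
    # and reset the gradients only after the step so both batches contribute
    if (idx+1) % 2 == 0:
        optimizer.step()
        optimizer.zero_grad()
    # the progress bar counts batches, matching total_batches
    progress_bar.update(1)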

    return loss
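
Gradient accumulation, as the comment in `train_batch` describes, sums the gradients of several batches before a single optimizer step, and the gradients are cleared only after that step. A framework-free sketch of the idea on a one-parameter model follows; the toy loss, the learning rate, and all names are assumptions for illustration and do not reflect how Adafactor updates parameters.

```python
def grad(w, x, y):
    """Gradient of the squared error (w * x - y) ** 2 with respect to w."""
    return 2 * (w * x - y) * x

def train(batches, w=0.0, lr=0.01, accumulation_steps=2):
    accumulated = 0.0
    for idx, (x, y) in enumerate(batches):
        # like loss.backward(): adds this batch's gradient to the stored one
        accumulated += grad(w, x, y)
        if (idx + 1) % accumulation_steps == 0:
            # like optimizer.step(): one update from the summed gradient
            w -= lr * accumulated
            # like optimizer.zero_grad(): reset only after the step
            accumulated = 0.0
    return w

# two toy batches (x, y); both gradients feed the single update
batches = [(1.0, 2.0), (2.0, 4.0)]
w = train(batches)
print(w)
```

Clearing the stored gradient before every backward pass instead would leave only the most recent batch in each update, so half of the batches would have no effect on the parameters.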

In [16]:
def val(model, val_loader):
    with torch.no_grad():
        loss = 0
        for _, process_batch in enumerate(val_loader):
            process_item = {k: v.to(device) for k, v in process_batch.items()}
            outputs = model(**process_item)
            loss += outputs.loss 
        # output loss in average
        loss /= len(val_loader)
    
    return loss

In [17]:
def train_log(loss, batch_num, epoch):
    wandb.log({"epoch": epoch, "loss": loss}, step=batch_num)
    print(f"Loss after " + str(batch_num).zfill(5) + f" steps: {loss:.3f}")

def val_log(loss, batch_num, epoch):
    wandb.log({"val_loss": loss})
    print(f"Validation Loss after " + str(batch_num).zfill(5) + f" training steps: {loss:.3f}")

In [18]:
def save(model, optimizer, output_model):
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()
    }, output_model)

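def load(model, optimizer, input_model):
    # map the checkpoint onto the current device so it also loads on a CPU-only machine
    checkpoint = torch.load(input_model, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])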

#### Build, train and analyze the model with the pipeline

In [None]:
model = model_pipeline(config)

{'input_ids': torch.Size([2, 809]), 'attention_mask': torch.Size([2, 809]), 'labels': torch.Size([2, 92])}
num_training_steps 8530


  0%|          | 0/8530 [00:00<?, ?it/s]

Loss after 00025 steps: 12.003
Loss after 00050 steps: 11.684
Loss after 00075 steps: 11.170
Loss after 00100 steps: 11.350
Loss after 00125 steps: 10.939
Loss after 00150 steps: 10.642
Loss after 00175 steps: 10.397
Loss after 00200 steps: 10.061
Loss after 00225 steps: 9.847
Loss after 00250 steps: 9.786
Loss after 00275 steps: 9.482
Loss after 00300 steps: 9.352
Loss after 00325 steps: 9.083
Loss after 00350 steps: 8.248
Loss after 00375 steps: 7.016
Loss after 00400 steps: 4.296
Loss after 00425 steps: 1.886
Loss after 00450 steps: 0.845
Loss after 00475 steps: 0.597
Loss after 00500 steps: 0.560
Loss after 00525 steps: 0.603
Loss after 00550 steps: 0.464
Loss after 00575 steps: 0.640
Loss after 00600 steps: 0.516
Loss after 00625 steps: 0.555
Loss after 00650 steps: 0.575
Loss after 00675 steps: 0.809
Loss after 00700 steps: 0.447
Loss after 00725 steps: 0.415
Loss after 00750 steps: 0.574
Loss after 00775 steps: 0.466
Loss after 00800 steps: 0.336
Loss after 00825 steps: 0.452
Lo