Final Project - Gregory LeMasurier and Mojtaba Talaei Khoei

The training script is a Jupyter notebook for now so that it is easy to debug.

In [None]:
# Install Dependencies
# Installs the extra packages this notebook needs into the interpreter that
# runs the kernel (sys.executable), so later imports resolve in that same
# environment.
# NOTE: `!` is IPython shell-escape syntax; this cell only runs inside
# Jupyter/IPython, not as a plain Python script.
import sys
!{sys.executable} -m pip install rouge-score nltk sentencepiece

In [None]:
# Common Imports
import os
import random

import transformers
from transformers import PegasusTokenizer, PegasusConfig
from transformers import PegasusForConditionalGeneration

import datasets
from datasets import load_dataset

import torch
from torch.utils.data import DataLoader

import wandb
from packaging import version
from tqdm.auto import tqdm


# Automatically reload edited modules (e.g. transformer_mt) before executing
# each cell, so changes are picked up without restarting the kernel.
# NOTE: these are IPython magics and only work inside Jupyter/IPython.
%load_ext autoreload
%autoreload 2

# Make the parent directory importable so the project package
# `transformer_mt` below can be found.
sys.path.append('../')

from transformer_mt import utils

In [None]:
# Setup logging
import logging

# Root handler: INFO and above, each record prefixed with a timestamp,
# the level and the logger name.
logging.basicConfig(
    level=logging.INFO,
    datefmt="%m/%d/%Y %H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
)

# Named logger used by main() for progress messages.
logger = logging.getLogger("Summarization")

# Keep the Hugging Face libraries quiet: only their warnings and errors show.
transformers.utils.logging.set_verbosity_warning()
datasets.utils.logging.set_verbosity_warning()

In [None]:
# ROUGE metric, used by main() to score generated summaries against the references.
rouge = datasets.load_metric(path="rouge")

In [None]:
# Run on the CPU even when a GPU is available (set to False to allow CUDA).
cpu_only = True

# Data, experiment-tracking and device settings.
dataset_name = 'cnn_dailymail'
dataset_version = '3.0.0'
wandb_project = "PegasusSummarization"
output_dir = "output_dir/"
device = 'cuda' if (torch.cuda.is_available() and not cpu_only) else 'cpu'

# Free cached, unused GPU memory (e.g. left over from re-running earlier cells).
# `is_available` must be *called*: the bare function object is always truthy.
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# Checkpoints. main() currently builds a small randomly initialised model,
# so only tokenizer_name is actually loaded from the Hub.
model_name = 'google/pegasus-xsum'
tokenizer_name = 'google/pegasus-cnn_dailymail'

# Training hyperparameters.
seq_len = 1024           # max tokens per article/summary; also the model's max_position_embeddings
batch_size = 8
learning_rate = 5e-5
weight_decay = 0.0
num_train_epochs = 10
lr_scheduler_type = "linear"
num_warmup_steps = 0
eval_every_steps = 2000  # evaluate ROUGE and save a checkpoint every N optimizer steps
k = int(seq_len * 0.3)   # top-k size: 30% of seq_len

# Set to True to train on a small sample of the dataset
# (utils.sample_small_debug_dataset) for quick debugging runs.
debug = False

In [None]:
def _generate_summaries(model, tokenizer, dataloader):
    """Generate a summary for every example in ``dataloader`` and decode both sides.

    Leaves ``model`` in eval mode; callers that keep training must switch it
    back with ``model.train()``.

    Returns:
        A ``(predictions, references)`` pair of equally long lists of decoded
        strings: the generated summaries and the reference labels.
    """
    model.eval()
    predictions = []
    references = []
    with torch.no_grad():
        for batch in dataloader:
            generated = model.generate(
                input_ids=batch["input_ids"].to(device),
                attention_mask=batch["attention_mask"].to(device),
            )
            predictions.extend(
                tokenizer.batch_decode(generated, skip_special_tokens=True, clean_up_tokenization_spaces=False)
            )
            # main() pads labels with -100 (ignored by the loss). That is not a
            # valid token id, so map it back to the pad token before decoding.
            labels = batch["labels"].clone()
            labels[labels == -100] = tokenizer.pad_token_id
            references.extend(
                tokenizer.batch_decode(labels, skip_special_tokens=True, clean_up_tokenization_spaces=False)
            )
    return predictions, references


def _rouge_metrics(predictions, references, prefix):
    """Score decoded ``predictions`` against decoded ``references`` with ROUGE.

    Returns a flat dict keyed ``"<prefix>/<rouge type>/<precision|recall|f1-score>"``
    that can be passed straight to ``wandb.log``.
    """
    scores = rouge.compute(predictions=predictions, references=references)
    metrics = {}
    for rouge_type, aggregate in scores.items():
        # Each value is a bootstrap AggregateScore(low, mid, high); report the
        # mid (point) estimate rather than the lower confidence bound.
        mid = aggregate.mid
        metrics[f"{prefix}/{rouge_type}/precision"] = mid.precision
        metrics[f"{prefix}/{rouge_type}/recall"] = mid.recall
        metrics[f"{prefix}/{rouge_type}/f1-score"] = mid.fmeasure
    return metrics


def main():
    """Train a Pegasus summarizer on ``dataset_name`` and report ROUGE.

    All settings come from the module-level configuration (``dataset_name``,
    ``seq_len``, ``batch_size``, ``learning_rate``, ...).

    Behavior:
    - Logs the training loss and learning rate to Weights & Biases every step.
    - Every ``eval_every_steps`` steps, and at the end of training, logs
      validation ROUGE and saves a checkpoint to ``output_dir``.
    - Finally summarizes the test split, prints the summaries and logs test ROUGE.
    """
    logger.info("Starting training")

    logger.info("Loading dataset")

    wandb.init(project=wandb_project)  # TODO: pass the hyperparameters via `config=`

    os.makedirs(output_dir, exist_ok=True)

    raw_datasets = load_dataset(dataset_name, dataset_version)

    # Make a small dataset for proof of concept
    if debug:
        raw_datasets = utils.sample_small_debug_dataset(raw_datasets)

    ## TOKENIZER
    tokenizer = PegasusTokenizer.from_pretrained(tokenizer_name)

    ## MODEL
    # The pretrained Pegasus model is too large to test on a laptop, so train a
    # small randomly initialised model with the same vocabulary for now:
    #   model = PegasusForConditionalGeneration.from_pretrained(model_name).to(device)
    config = PegasusConfig(
        encoder_layers=2,
        decoder_layers=2,
        encoder_attention_heads=8,
        decoder_attention_heads=8,
        decoder_ffn_dim=1024,
        encoder_ffn_dim=1024,
        max_position_embeddings=seq_len,
        vocab_size=tokenizer.vocab_size,
    )
    model = PegasusForConditionalGeneration(config).to(device)

    column_names = raw_datasets["train"].column_names

    def tokenize_function(examples):
        # Articles are the encoder inputs; highlights are the target summaries.
        model_inputs = tokenizer(examples["article"], max_length=seq_len, truncation=True)
        model_inputs["labels"] = tokenizer(examples["highlights"], max_length=seq_len, truncation=True)["input_ids"]
        return model_inputs

    tokenized_datasets = raw_datasets.map(
        tokenize_function,
        batched=True,
        num_proc=8,
        remove_columns=column_names,
        load_from_cache_file=True,
        desc="Tokenizing the dataset",
    )

    train_dataset = tokenized_datasets["train"]
    # Evaluate on the validation split when there is one; otherwise fall back to test.
    eval_dataset = tokenized_datasets["validation"] if "validation" in tokenized_datasets else tokenized_datasets["test"]
    test_dataset = tokenized_datasets["test"]

    for index in random.sample(range(len(train_dataset)), min(2, len(train_dataset))):
        logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
        logger.info(f"Sample {index} of the training set input ids: {train_dataset[index]['input_ids']}.")
        logger.info(f"Decoded input_ids: {tokenizer.decode(train_dataset[index]['input_ids'])}")
        logger.info(f"Decoded labels: {tokenizer.decode(train_dataset[index]['labels'])}")
        logger.info("\n")

    # Pad labels with -100 so padded positions are ignored by the cross-entropy
    # loss. Padding with the real pad id (0) would train the model to predict
    # padding. Passing `model` lets the collator build decoder_input_ids from
    # the labels.
    collator = transformers.DataCollatorForSeq2Seq(
        tokenizer=tokenizer,
        model=model,
        max_length=seq_len,
        padding='max_length',
        label_pad_token_id=-100,
    )

    train_dataloader = DataLoader(train_dataset, shuffle=True, collate_fn=collator, batch_size=batch_size)
    eval_dataloader = DataLoader(eval_dataset, collate_fn=collator, batch_size=batch_size)
    test_dataloader = DataLoader(test_dataset, collate_fn=collator, batch_size=batch_size)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=learning_rate,
        weight_decay=weight_decay,
    )

    num_update_steps_per_epoch = len(train_dataloader)
    max_train_steps = num_train_epochs * num_update_steps_per_epoch

    lr_scheduler = transformers.get_scheduler(
        name=lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=max_train_steps,
    )

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {num_train_epochs}")
    logger.info(f"  Total optimization steps = {max_train_steps}")
    progress_bar = tqdm(range(max_train_steps))

    global_step = 0
    for epoch in range(num_train_epochs):
        model.train()
        for batch in train_dataloader:
            out = model(
                input_ids=batch["input_ids"].to(device),
                attention_mask=batch["attention_mask"].to(device),
                labels=batch["labels"].to(device),
            )
            loss = out["loss"]

            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

            progress_bar.update(1)
            global_step += 1

            wandb.log(
                {
                    "train_loss": loss.item(),  # plain float, detached from the graph
                    "learning_rate": optimizer.param_groups[0]["lr"],
                    "epoch": epoch,
                },
                step=global_step,
            )

            if (global_step % eval_every_steps == 0) or (global_step >= max_train_steps):
                predictions, references = _generate_summaries(model, tokenizer, eval_dataloader)
                wandb.log(_rouge_metrics(predictions, references, prefix="eval"), step=global_step)

                logger.info("Saving model checkpoint to %s", output_dir)
                model.save_pretrained(output_dir)

                model.train()

            if global_step >= max_train_steps:
                break
        if global_step >= max_train_steps:
            break
    progress_bar.close()

    test_predictions, test_references = _generate_summaries(model, tokenizer, test_dataloader)
    for summary in test_predictions:
        print("Summary: " + summary)
    wandb.log(_rouge_metrics(test_predictions, test_references, prefix="test"), step=global_step)
        

In [None]:
if __name__ == "__main__":
    # Fail fast when the installed datasets package is older than 1.18.0.
    if not version.parse(datasets.__version__) >= version.parse("1.18.0"):
        raise RuntimeError("This script requires Datasets 1.18.0 or higher. Please update via pip install -U datasets.")
    main()