Final Project - Gregory LeMasurier and Mojtaba Talaei Khoei

Making the training file a jupyter notebook for the time being so I can easily debug it.

In [1]:
# Install Dependencies
import sys
!{sys.executable} -m pip install rouge-score nltk sentencepiece



In [2]:
# Common Imports
import os
import random

import transformers
from transformers import PegasusTokenizer, PegasusConfig
from transformers import PegasusForConditionalGeneration

import datasets
from datasets import load_dataset

import torch
from torch.utils.data import DataLoader

import wandb
from packaging import version
from tqdm.auto import tqdm


%load_ext autoreload
%autoreload 2

sys.path.append('../')

from transformer_mt import utils

In [3]:
# Setup logging
import logging

logger = logging.getLogger("Summarization")
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)

datasets.utils.logging.set_verbosity_warning()
transformers.utils.logging.set_verbosity_warning()

In [4]:
# ROUGE Metric
rouge = datasets.load_metric("rouge")

In [5]:
cpu_only = True

dataset_name = 'cnn_dailymail'
dataset_version = '3.0.0'
wandb_project = "PegasusSummarization"
output_dir = "output_dir/"
device = 'cuda' if (torch.cuda.is_available() and not cpu_only) else 'cpu'

if torch.cuda.is_available:
    torch.cuda.empty_cache()

model_name = 'google/pegasus-xsum' 
tokenizer_name = 'google/pegasus-cnn_dailymail'
seq_len = 512
batch_size = 8
learning_rate = 5e-5
weight_decay = 0.0
num_train_epochs = 1
lr_scheduler_type = "linear"
num_warmup_steps = 0
eval_every_steps = 5

# Flag to make 
debug = True

In [6]:
def main():
    logger.info(f"Starting tokenizer training")

    logger.info(f"Loading dataset")

    wandb.init(project=wandb_project) #Skipping config for now - will add back later

    os.makedirs(output_dir, exist_ok=True)

    raw_datasets = load_dataset(dataset_name, dataset_version)

    # Make a small dataset for proof of concept
    if debug:
        raw_datasets = utils.sample_small_debug_dataset(raw_datasets)

    ## TOKENIZER
    tokenizer = PegasusTokenizer.from_pretrained(tokenizer_name)

    ## PRETRAINED MODEL
    #The pegasus model is too large to test on a laptop, so load a small config for now
    #model = PegasusForConditionalGeneration.from_pretrained(model_name).to(device)
    config = PegasusConfig(
            encoder_layers=2, 
            decoder_layers=2, 
            encoder_attention_heads=8, 
            decoder_attention_heads=8, 
            decoder_ffn_dim=1024, 
            encoder_ffn_dim=1024,
            max_position_embeddings=seq_len,
            vocab_size=tokenizer.vocab_size
            )
    model = PegasusForConditionalGeneration(config).to(device)

    column_names = raw_datasets["train"].column_names

    def tokenize_function(examples):
        inputs = [ex for ex in examples['article']]
        targets = [ex for ex in examples['highlights']]
        model_inputs = tokenizer(inputs, max_length=seq_len, truncation=True)
        model_inputs['labels'] = tokenizer(targets, max_length=seq_len, truncation=True)['input_ids']
        return model_inputs

    tokenized_datasets = raw_datasets.map(
        tokenize_function,
        batched=True,
        num_proc=8,
        remove_columns=column_names,
        load_from_cache_file=True,
        desc="Tokenizing the dataset",
    )

    train_dataset = tokenized_datasets["train"]
    eval_dataset = tokenized_datasets["validation"] if "validaion" in tokenized_datasets else tokenized_datasets["test"]

    for index in random.sample(range(len(train_dataset)), 2):
        logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
        logger.info(f"Decoded input_ids: {tokenizer.decode(train_dataset[index]['input_ids'])}")
        logger.info(f"Decoded labels: {tokenizer.decode(train_dataset[index]['labels'])}")
        logger.info("\n")

    #collator = transformers.DataCollatorWithPadding(tokenizer=tokenizer, max_length=seq_len, padding='max_length')
    collator = transformers.DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model, max_length=seq_len, padding='max_length')

    train_dataloader = DataLoader(
        train_dataset, 
        shuffle=True, 
        collate_fn=collator, 
        batch_size=batch_size
    )
    
    eval_dataloader = DataLoader(
        eval_dataset, 
        collate_fn=collator, 
        batch_size=batch_size
    )

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=learning_rate,
        weight_decay=weight_decay,
    )
    
    # Scheduler and math around the number of training steps.
    num_update_steps_per_epoch = len(train_dataloader)
    max_train_steps = num_train_epochs * num_update_steps_per_epoch

    lr_scheduler = transformers.get_scheduler(
        name=lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=max_train_steps,
    )

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {num_train_epochs}")
    logger.info(f"  Total optimization steps = {max_train_steps}")
    progress_bar = tqdm(range(max_train_steps))

    # Log a pre-processed training example to make sure the pre-processing does not have bugs in it
    # and we do not input garbage to our model.
    batch = next(iter(train_dataloader))

    #logger.info("Look at the data that we input into the model, check that it looks like what we expect.")
    #for index in random.sample(range(len(batch)), 2):
    #    logger.info(f"Decoded input_ids size: {len(batch['input_ids'][index])}")
    #    logger.info(f"Decoded input_ids: {tokenizer.decode(batch['input_ids'][index])}")
    #    logger.info(f"Decoded labels size: {len(batch['labels'][index])}")
    #    logger.info(f"Decoded labels: {tokenizer.decode(batch['labels'][index])}")
    #    logger.info("\n")

    global_step = 0
    for epoch in range(num_train_epochs):
        model.train()
        for batch in train_dataloader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)
            out = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )

            loss = out["loss"]
            logits = out["logits"]
            #TODO: Top K sampling of logits

            #print(loss.item())
            #print("\n\n")
            #print(logits)
            #print("\n\n")
            #print(labels)

            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

            progress_bar.update(1)
            global_step += 1

            wandb.log(
                {
                    "train_loss": loss,
                    "learning_rate": optimizer.param_groups[0]["lr"],
                    "epoch": epoch,
                },
                step=global_step,
            )

            if (global_step % eval_every_steps == 0) or (global_step >= max_train_steps):
                model.eval()

                #TODO: USING SAME VALUE FOR PREDICTION AND REFERENCE!!!!
                rouge_score = rouge.compute(predictions=labels, references=labels)

                print(rouge_score)
                print("\n\n") 


                metric = {}
                for rouge_type in rouge_score:
                    print(rouge_score[rouge_type][0])
                    print("\n\n")
                    metric['eval/' + rouge_type + "/precision"] = rouge_score[rouge_type][0][0]
                    metric['eval/' + rouge_type + "/recall"] = rouge_score[rouge_type][0][1]
                    metric['eval/' + rouge_type + "/f1-score"] = rouge_score[rouge_type][0][2]


                print(metric)

                wandb.log(metric, step=global_step)

                logger.info("Saving model checkpoint to %s", output_dir)
                model.save_pretrained(output_dir)

                model.train()

            if global_step >= max_train_steps:
                break

In [7]:
if __name__ == "__main__" :
    if version.parse(datasets.__version__) < version.parse("1.18.0"):
        raise RuntimeError("This script requires Datasets 1.18.0 or higher. Please update via pip install -U datasets.")
    main()

04/19/2022 17:35:29 - INFO - Summarization - Starting tokenizer training
04/19/2022 17:35:29 - INFO - Summarization - Loading dataset
04/19/2022 17:35:29 - ERROR - wandb.jupyter - Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mglemasurier[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.12.14 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


100%|██████████| 3/3 [00:00<00:00, 24.21it/s]


[A[A


Tokenizing the dataset #0:   0%|          | 0/1 [00:00<?, ?ba/s]





[A[A[A[A[A[A
Tokenizing the dataset #3: 100%|██████████| 1/1 [00:00<00:00, 10.54ba/s]
Tokenizing the dataset #1: 100%|██████████| 1/1 [00:00<00:00,  8.80ba/s]




[A[A[A[A




[A[A[A[A[A


Tokenizing the dataset #4: 100%|██████████| 1/1 [00:00<00:00,  8.44ba/s]
Tokenizing the dataset #7: 100%|██████████| 1/1 [00:00<00:00, 10.19ba/s]
Tokenizing the dataset #0: 100%|██████████| 1/1 [00:00<00:00,  8.48ba/s]

Tokenizing the dataset #2: 100%|██████████| 1/1 [00:00<00:00,  9.31ba/s]
Tokenizing the dataset #5: 100%|██████████| 1/1 [00:00<00:00, 11.45ba/s]
Tokenizing the dataset #6: 100%|██████████| 1/1 [00:00<00:00, 12.01ba/s]
04/19/2022 17:35:42 - INFO - Summarization - Sample 79 of the training set: {'input_ids': [143, 40155, 158, 1315, 3064, 131, 116, 6981, 554, 24315, 13184, 1191, 17212, 40722, 252, 1327, 122, 28189, 111, 50765, 116, 113, 14282, 111

{'rouge1': AggregateScore(low=Score(precision=1.0, recall=1.0, fmeasure=1.0), mid=Score(precision=1.0, recall=1.0, fmeasure=1.0), high=Score(precision=1.0, recall=1.0, fmeasure=1.0)), 'rouge2': AggregateScore(low=Score(precision=1.0, recall=1.0, fmeasure=1.0), mid=Score(precision=1.0, recall=1.0, fmeasure=1.0), high=Score(precision=1.0, recall=1.0, fmeasure=1.0)), 'rougeL': AggregateScore(low=Score(precision=1.0, recall=1.0, fmeasure=1.0), mid=Score(precision=1.0, recall=1.0, fmeasure=1.0), high=Score(precision=1.0, recall=1.0, fmeasure=1.0)), 'rougeLsum': AggregateScore(low=Score(precision=1.0, recall=1.0, fmeasure=1.0), mid=Score(precision=1.0, recall=1.0, fmeasure=1.0), high=Score(precision=1.0, recall=1.0, fmeasure=1.0))}



Score(precision=1.0, recall=1.0, fmeasure=1.0)





TypeError: tuple indices must be integers or slices, not str