Final Project - Gregory LeMasurier and Mojtaba Talaei Khoei

Making the training file a jupyter notebook for the time being so I can easily debug it.

In [1]:
# Install Dependencies
import sys
!{sys.executable} -m pip install rouge-score nltk sentencepiece



In [2]:
# Common Imports
import os
import random

import transformers
from transformers import PegasusTokenizer, PegasusConfig
from transformers import PegasusForConditionalGeneration

import datasets
from datasets import load_dataset

import torch
from torch.utils.data import DataLoader

import wandb
from packaging import version
from tqdm.auto import tqdm


%load_ext autoreload
%autoreload 2

sys.path.append('../')

from transformer_mt import utils

In [3]:
# Setup logging
import logging

logger = logging.getLogger("Summarization")
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)

datasets.utils.logging.set_verbosity_warning()
transformers.utils.logging.set_verbosity_warning()

In [4]:
# ROUGE Metric
rouge = datasets.load_metric("rouge")

In [5]:
cpu_only = True

dataset_name = 'cnn_dailymail'
dataset_version = '3.0.0'
wandb_project = "PegasusSummarization"
output_dir = "output_dir/"
device = 'cuda' if (torch.cuda.is_available() and not cpu_only) else 'cpu'

if torch.cuda.is_available:
    torch.cuda.empty_cache()

model_name = 'google/pegasus-xsum' 
tokenizer_name = 'google/pegasus-cnn_dailymail'
seq_len = 512
batch_size = 8
learning_rate = 5e-5
weight_decay = 0.0
num_train_epochs = 1
lr_scheduler_type = "linear"
num_warmup_steps = 0
eval_every_steps = 5
k = int(512 * 0.3)

# Flag to make 
debug = True

In [6]:
def main():
    logger.info(f"Starting tokenizer training")

    logger.info(f"Loading dataset")

    wandb.init(project=wandb_project) #Skipping config for now - will add back later

    os.makedirs(output_dir, exist_ok=True)

    raw_datasets = load_dataset(dataset_name, dataset_version)

    # Make a small dataset for proof of concept
    if debug:
        raw_datasets = utils.sample_small_debug_dataset(raw_datasets)

    ## TOKENIZER
    tokenizer = PegasusTokenizer.from_pretrained(tokenizer_name)

    ## PRETRAINED MODEL
    #The pegasus model is too large to test on a laptop, so load a small config for now
    #model = PegasusForConditionalGeneration.from_pretrained(model_name).to(device)
    config = PegasusConfig(
            encoder_layers=2, 
            decoder_layers=2, 
            encoder_attention_heads=8, 
            decoder_attention_heads=8, 
            decoder_ffn_dim=1024, 
            encoder_ffn_dim=1024,
            max_position_embeddings=seq_len,
            vocab_size=tokenizer.vocab_size
            )
    model = PegasusForConditionalGeneration(config).to(device)

    column_names = raw_datasets["train"].column_names

    def tokenize_function(examples):
        inputs = [ex for ex in examples['article']]
        targets = [ex for ex in examples['highlights']]
        model_inputs = tokenizer(inputs, max_length=seq_len, truncation=True)
        model_inputs['labels'] = tokenizer(targets, max_length=seq_len, truncation=True)['input_ids']
        return model_inputs

    tokenized_datasets = raw_datasets.map(
        tokenize_function,
        batched=True,
        num_proc=8,
        remove_columns=column_names,
        load_from_cache_file=True,
        desc="Tokenizing the dataset",
    )

    train_dataset = tokenized_datasets["train"]
    eval_dataset = tokenized_datasets["validation"] if "validaion" in tokenized_datasets else tokenized_datasets["test"]

    for index in random.sample(range(len(train_dataset)), 2):
        logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
        logger.info(f"Sample {index} of the training set input ids: {train_dataset[index]['input_ids']}.")
        logger.info(f"Decoded input_ids: {tokenizer.decode(train_dataset[index]['input_ids'])}")
        logger.info(f"Decoded labels: {tokenizer.decode(train_dataset[index]['labels'])}")
        logger.info("\n")

    #collator = transformers.DataCollatorWithPadding(tokenizer=tokenizer, max_length=seq_len, padding='max_length')
    collator = transformers.DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model, max_length=seq_len, padding='max_length')

    train_dataloader = DataLoader(
        train_dataset, 
        shuffle=True, 
        collate_fn=collator, 
        batch_size=batch_size
    )
    
    eval_dataloader = DataLoader(
        eval_dataset, 
        collate_fn=collator, 
        batch_size=batch_size
    )

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=learning_rate,
        weight_decay=weight_decay,
    )
    
    # Scheduler and math around the number of training steps.
    num_update_steps_per_epoch = len(train_dataloader)
    max_train_steps = num_train_epochs * num_update_steps_per_epoch

    lr_scheduler = transformers.get_scheduler(
        name=lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=max_train_steps,
    )

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {num_train_epochs}")
    logger.info(f"  Total optimization steps = {max_train_steps}")
    progress_bar = tqdm(range(max_train_steps))

    # Log a pre-processed training example to make sure the pre-processing does not have bugs in it
    # and we do not input garbage to our model.
    batch = next(iter(train_dataloader))

    #logger.info("Look at the data that we input into the model, check that it looks like what we expect.")
    #for index in random.sample(range(len(batch)), 2):
    #    logger.info(f"Decoded input_ids size: {len(batch['input_ids'][index])}")
    #    logger.info(f"Decoded input_ids: {tokenizer.decode(batch['input_ids'][index])}")
    #    logger.info(f"Decoded labels size: {len(batch['labels'][index])}")
    #    logger.info(f"Decoded labels: {tokenizer.decode(batch['labels'][index])}")
    #    logger.info("\n")

    global_step = 0
    for epoch in range(num_train_epochs):
        model.train()
        for batch in train_dataloader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)
            out = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )

            loss = out["loss"]
            logits = out["logits"]
            res = torch.topk(logits, k=k)
            values = res[0]
            #print(values)
            #print("\n\n")
            #print(indices)

            #print(loss.item())
            #print("\n\n")
            #print(logits)
            #print("\n\n")
            #print(labels)

            #tokenizer.decode()

            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

            progress_bar.update(1)
            global_step += 1

            wandb.log(
                {
                    "train_loss": loss,
                    "learning_rate": optimizer.param_groups[0]["lr"],
                    "epoch": epoch,
                },
                step=global_step,
            )

            if (global_step % eval_every_steps == 0) or (global_step >= max_train_steps):
                model.eval()

                #TODO: USING SAME VALUE FOR PREDICTION AND REFERENCE!!!!
                for text in labels:
                    print(text)
                    print("\n")
                    print(text.item())
                    print("\n\n")
                #    print( "SUMMARY: " + str(tokenizer.decode(text)))

                rouge_score = rouge.compute(predictions=values, references=labels)

                metric = {}
                for rouge_type in rouge_score:
                    metric['eval/' + rouge_type + "/precision"] = rouge_score[rouge_type][0][0]
                    metric['eval/' + rouge_type + "/recall"] = rouge_score[rouge_type][0][1]
                    metric['eval/' + rouge_type + "/f1-score"] = rouge_score[rouge_type][0][2]

                wandb.log(metric, step=global_step)

                logger.info("Saving model checkpoint to %s", output_dir)
                model.save_pretrained(output_dir)

                model.train()

            if global_step >= max_train_steps:
                break

In [7]:
if __name__ == "__main__" :
    if version.parse(datasets.__version__) < version.parse("1.18.0"):
        raise RuntimeError("This script requires Datasets 1.18.0 or higher. Please update via pip install -U datasets.")
    main()

04/21/2022 16:58:50 - INFO - Summarization - Starting tokenizer training
04/21/2022 16:58:50 - INFO - Summarization - Loading dataset
04/21/2022 16:58:50 - ERROR - wandb.jupyter - Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mglemasurier[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.12.15 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


100%|██████████| 3/3 [00:00<00:00, 600.39it/s]

[A


Tokenizing the dataset #0:   0%|          | 0/1 [00:00<?, ?ba/s]

[A[A



Tokenizing the dataset #2: 100%|██████████| 1/1 [00:00<00:00, 10.46ba/s]
Tokenizing the dataset #4: 100%|██████████| 1/1 [00:00<00:00, 11.91ba/s]






[A[A[A[A[A[A




Tokenizing the dataset #0: 100%|██████████| 1/1 [00:00<00:00, 11.46ba/s]
Tokenizing the dataset #1: 100%|██████████| 1/1 [00:00<00:00, 12.34ba/s]
Tokenizing the dataset #3: 100%|██████████| 1/1 [00:00<00:00, 10.94ba/s]
Tokenizing the dataset #7: 100%|██████████| 1/1 [00:00<00:00, 12.63ba/s]




Tokenizing the dataset #5: 100%|██████████| 1/1 [00:00<00:00,  8.30ba/s]
Tokenizing the dataset #6: 100%|██████████| 1/1 [00:00<00:00, 10.35ba/s]
04/21/2022 16:59:04 - INFO - Summarization - Sample 68 of the training set: {'input_ids': [5300, 131, 116, 5469, 112, 1713, 109, 34372, 34294, 3164, 117, 124, 109, 6412, 204, 2084, 160, 109, 849, 107, 5300, 127, 6674, 464, 9464, 108, 9284, 108, 2579, 10

tensor([33116,  7716,   112,  2094, 13966,  7414,   124,  1408,  1428,   115,
         4600,  4354,   110,   107,  7716,  1155,   114,  1082,   124,  2277,
          113,   342,   458,   114,  4235,   110,   107,   240,   178, 10074,
         7414,   108,   178,   131,   267,  1762, 37362, 53729,   132, 20340,
        53132,   110,   107,     1,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100])
tensor([43086,  7784,   133,   174,  1749,  1537, 17518,  1573,   113,   177,
          578,   110,   107,  9493, 24106,  8990,  5908, 90937,

 77%|███████▋  | 10/13 [00:42<00:12,  4.32s/it]

tensor([ 7026,   121, 91462, 56816,   123,   116,  3669,  2120,   342,   164,
          110,   107,   112,   129,  1276,   113,  2481,   115,  1326,   233,
          173,   178,   117,   770,   112,   275,   693,   121,   497,   121,
         4801,   110,   107,   122, 42785, 74721,   110,   107, 56816,  2737,
          169,   571,   154,  6568,   110,   107,  8379, 42785, 11065,   661,
          141,   371, 79483,   446,  2978,   113,   114,  9485,  2974,   110,
          107,  4064, 23963,   549,   113,   109,   475,  6719,   829,   110,
          107,     1,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100])
tensor([ 6015,  7590,   117,  1392,  1937,   893, 11023,  2511,   482,   109,
          278,   110,   107,  8497,   131,   116, 16877,   563,   111,  5060,
          131,   116,  2178,  3285,   127,   228,  4964,  2724,   110,   107,
        45475, 67346,   116, 24753, 47535,   115,   114,   405,   111,

04/21/2022 16:59:47 - INFO - Summarization - Saving model checkpoint to output_dir/
100%|██████████| 13/13 [00:56<00:00,  4.22s/it]04/21/2022 17:00:00 - INFO - Summarization - Saving model checkpoint to output_dir/


tensor([ 5759,  1167,   304,  1846, 23999,   464,  8320,   819, 15225,   131,
         1323,  1372,   110,   107, 27330,   113,   335,  1034, 21365,   116,
         5222,   135,  1532,   406,   112,   813, 10255,   131,   222,  1185,
         5539,  2662,  2365,   211,  1034, 11182,  3959,   131,   118,  3880,
          110,   107,  2503,  3921,   112,   403,  1044,   199,   210, 33136,
          116,  1798,   464,  6949,   110,   107,     1,  -100,  -100,  -100,
         -100,  -100,  -100])
tensor([ 2471,   113,  5137, 36267,  6381,  1084,  1775,  1424,   112,   179,
          896,  9390,   110,   107,  1006,  1601, 21561,  1084,  1775,   635,
         1165,  4873,   111, 19440,   112, 42851,   116,   110,   107, 18119,
         1758,  2346, 79297,   243,  3148, 36624,   140,   146,  7360,   110,
          107,  6158,   374,   115,  1816,  1726,   113,  4603, 32053,   143,
          788, 57454,   158,  8740,   110,   107,     1,  -100,  -100,  -100,
         -100,  -100,  -100])
tens

100%|██████████| 13/13 [00:58<00:00,  4.49s/it]
