In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

In [1]:
import pandas as pd
import numpy as np
import nltk

import torch
from torchmetrics.text import ROUGEScore

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer 
from datasets import Dataset, DatasetDict, load_from_disk

import wandb

In [10]:
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"


In [4]:
data_path = "../../../data/labelled/reviews/splits/"

In [5]:
train_df = pd.read_csv(os.path.join(data_path, "train.csv"))
dev_df = pd.read_csv(os.path.join(data_path, "dev.csv"))
test_df = pd.read_csv(os.path.join(data_path, "test.csv"))

In [6]:
train_df.shape, dev_df.shape, test_df.shape

((177323, 2), (10000, 2), (20000, 2))

In [7]:
dataset = DatasetDict()
dataset['train'] = Dataset.from_pandas(train_df, split="train")
dataset['dev'] = Dataset.from_pandas(dev_df, split="dev")
dataset['test'] = Dataset.from_pandas(test_df, split="test")

In [8]:
dataset

DatasetDict({
    train: Dataset({
        features: ['input_reviews', 'label_reviews'],
        num_rows: 177323
    })
    dev: Dataset({
        features: ['input_reviews', 'label_reviews'],
        num_rows: 10000
    })
    test: Dataset({
        features: ['input_reviews', 'label_reviews'],
        num_rows: 20000
    })
})

In [7]:
model_name = "facebook/bart-large-cnn"

save_dir = "bart-reviews-summarization"
os.environ["WANDB_PROJECT"]=f"bart_large_product_reviews"
os.environ["WANDB_LOG_MODEL"]="false"
os.environ["WANDB_WATCH"]="false"
os.environ["WANDB_NOTEBOOK_NAME"]="bart-large.ipynb"

In [8]:
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

In [9]:
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [6]:
rouge = ROUGEScore()

In [10]:
max_input_length = 1024
max_target_length = 512

def preprocess_function(examples):
    inputs = [doc for doc in examples["input_reviews"]]
    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True, padding=True)

    # Setup the tokenizer for targets
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(examples["label_reviews"], max_length=max_target_length, truncation=True, padding=True)

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

In [15]:
tokenized_datasets = dataset.map(preprocess_function, batched=True)

Map:   0%|          | 0/177323 [00:00<?, ? examples/s]



Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

Map:   0%|          | 0/20000 [00:00<?, ? examples/s]

In [16]:
tokenized_datasets.save_to_disk("bart_tokenized")

Saving the dataset (0/5 shards):   0%|          | 0/177323 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/10000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/20000 [00:00<?, ? examples/s]

In [11]:
tokenized_datasets = load_from_disk("bart_tokenized")

In [12]:
batch_size = 8
args = Seq2SeqTrainingArguments(
    save_dir,
    evaluation_strategy = "epoch",
    save_strategy = "epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    gradient_accumulation_steps=2,
    weight_decay=0.01,
    save_total_limit=2,
    num_train_epochs=5,
    load_best_model_at_end=True,
    metric_for_best_model="rougeLsum_fmeasure",
    predict_with_generate=True,
    # fp16=True,
    report_to="wandb",  # enable logging to W&B
    logging_steps=100,
)

In [13]:
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

In [14]:
extract = ['rouge1', 'rouge2', 'rougeL', 'rougeLsum']
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    
    # Rouge expects a newline after each sentence
    decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
    decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]
    # print(decoded_preds)
    # print(decoded_labels)
    
    result_ = rouge(decoded_preds, decoded_labels)
    # Extract a few results
    result = {}
    for i in extract:
        result[f'{i}_fmeasure'] = result_[f'{i}_fmeasure'].item()
    
    # Add mean generated length
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)
    
    return {k: round(v, 4) for k, v in result.items()}

In [15]:
trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["dev"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


In [16]:
trainer.train()

  0%|          | 0/55415 [00:00<?, ?it/s]

  attn_output = torch.nn.functional.scaled_dot_product_attention(


{'loss': 1.3111, 'grad_norm': 1.3529396057128906, 'learning_rate': 1.996390868898313e-05, 'epoch': 0.01}
{'loss': 0.3844, 'grad_norm': 0.8860419392585754, 'learning_rate': 1.9927817377966256e-05, 'epoch': 0.02}
{'loss': 0.3483, 'grad_norm': 0.7416389584541321, 'learning_rate': 1.9891726066949385e-05, 'epoch': 0.03}
{'loss': 0.3442, 'grad_norm': 0.866566002368927, 'learning_rate': 1.985563475593251e-05, 'epoch': 0.04}
{'loss': 0.3238, 'grad_norm': 0.7506502270698547, 'learning_rate': 1.981954344491564e-05, 'epoch': 0.05}
{'loss': 0.3224, 'grad_norm': 0.7335077524185181, 'learning_rate': 1.9783452133898765e-05, 'epoch': 0.05}
{'loss': 0.3236, 'grad_norm': 0.7261941432952881, 'learning_rate': 1.9747360822881894e-05, 'epoch': 0.06}
{'loss': 0.3176, 'grad_norm': 0.6611402630805969, 'learning_rate': 1.971126951186502e-05, 'epoch': 0.07}
{'loss': 0.313, 'grad_norm': 0.6848235726356506, 'learning_rate': 1.9675178200848148e-05, 'epoch': 0.08}
{'loss': 0.3067, 'grad_norm': 0.9253354072570801, 'l

  0%|          | 0/1250 [00:00<?, ?it/s]

Non-default generation parameters: {'max_length': 142, 'min_length': 56, 'early_stopping': True, 'num_beams': 4, 'length_penalty': 2.0, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


{'eval_loss': 0.25507110357284546, 'eval_rouge1_fmeasure': 0.4365, 'eval_rouge2_fmeasure': 0.2008, 'eval_rougeL_fmeasure': 0.3077, 'eval_rougeLsum_fmeasure': 0.3485, 'eval_gen_len': 118.2198, 'eval_runtime': 4623.0277, 'eval_samples_per_second': 2.163, 'eval_steps_per_second': 0.27, 'epoch': 1.0}
{'loss': 0.2474, 'grad_norm': 0.5918956995010376, 'learning_rate': 1.599386447712713e-05, 'epoch': 1.0}
{'loss': 0.2315, 'grad_norm': 0.5368605256080627, 'learning_rate': 1.595777316611026e-05, 'epoch': 1.01}
{'loss': 0.2382, 'grad_norm': 0.5935273766517639, 'learning_rate': 1.5921681855093386e-05, 'epoch': 1.02}
{'loss': 0.2429, 'grad_norm': 0.5816434025764465, 'learning_rate': 1.5885590544076515e-05, 'epoch': 1.03}
{'loss': 0.235, 'grad_norm': 0.5232434868812561, 'learning_rate': 1.5849499233059644e-05, 'epoch': 1.04}
{'loss': 0.244, 'grad_norm': 0.6418626308441162, 'learning_rate': 1.581340792204277e-05, 'epoch': 1.05}
{'loss': 0.2425, 'grad_norm': 0.5534032583236694, 'learning_rate': 1.577

  0%|          | 0/1250 [00:00<?, ?it/s]

Non-default generation parameters: {'max_length': 142, 'min_length': 56, 'early_stopping': True, 'num_beams': 4, 'length_penalty': 2.0, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


{'eval_loss': 0.23630346357822418, 'eval_rouge1_fmeasure': 0.4238, 'eval_rouge2_fmeasure': 0.1898, 'eval_rougeL_fmeasure': 0.2966, 'eval_rougeLsum_fmeasure': 0.338, 'eval_gen_len': 122.1304, 'eval_runtime': 4677.2426, 'eval_samples_per_second': 2.138, 'eval_steps_per_second': 0.267, 'epoch': 2.0}
{'loss': 0.2234, 'grad_norm': 0.49439623951911926, 'learning_rate': 1.1987728954254266e-05, 'epoch': 2.0}
{'loss': 0.2217, 'grad_norm': 0.5656270384788513, 'learning_rate': 1.1951637643237393e-05, 'epoch': 2.01}
{'loss': 0.2284, 'grad_norm': 0.6190727949142456, 'learning_rate': 1.1915546332220517e-05, 'epoch': 2.02}
{'loss': 0.2092, 'grad_norm': 0.5607917904853821, 'learning_rate': 1.1879455021203646e-05, 'epoch': 2.03}
{'loss': 0.2201, 'grad_norm': 0.5195985436439514, 'learning_rate': 1.1843363710186773e-05, 'epoch': 2.04}
{'loss': 0.2091, 'grad_norm': 0.48761075735092163, 'learning_rate': 1.18072723991699e-05, 'epoch': 2.05}
{'loss': 0.2163, 'grad_norm': 0.6306504011154175, 'learning_rate': 

  0%|          | 0/1250 [00:00<?, ?it/s]

Non-default generation parameters: {'max_length': 142, 'min_length': 56, 'early_stopping': True, 'num_beams': 4, 'length_penalty': 2.0, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


{'eval_loss': 0.2317153811454773, 'eval_rouge1_fmeasure': 0.4445, 'eval_rouge2_fmeasure': 0.2057, 'eval_rougeL_fmeasure': 0.317, 'eval_rougeLsum_fmeasure': 0.3583, 'eval_gen_len': 117.0914, 'eval_runtime': 4661.7998, 'eval_samples_per_second': 2.145, 'eval_steps_per_second': 0.268, 'epoch': 3.0}
{'loss': 0.2101, 'grad_norm': 0.5034884810447693, 'learning_rate': 7.981593431381396e-06, 'epoch': 3.0}
{'loss': 0.1991, 'grad_norm': 0.5024533271789551, 'learning_rate': 7.945502120364523e-06, 'epoch': 3.01}
{'loss': 0.2008, 'grad_norm': 0.4429325461387634, 'learning_rate': 7.90941080934765e-06, 'epoch': 3.02}
{'loss': 0.2071, 'grad_norm': 0.5512158870697021, 'learning_rate': 7.873319498330777e-06, 'epoch': 3.03}
{'loss': 0.2016, 'grad_norm': 0.5958404541015625, 'learning_rate': 7.837228187313904e-06, 'epoch': 3.04}
{'loss': 0.2031, 'grad_norm': 0.5401113033294678, 'learning_rate': 7.801136876297032e-06, 'epoch': 3.05}
{'loss': 0.2057, 'grad_norm': 0.5634111166000366, 'learning_rate': 7.765045

  0%|          | 0/1250 [00:00<?, ?it/s]

Non-default generation parameters: {'max_length': 142, 'min_length': 56, 'early_stopping': True, 'num_beams': 4, 'length_penalty': 2.0, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


{'eval_loss': 0.22941601276397705, 'eval_rouge1_fmeasure': 0.4362, 'eval_rouge2_fmeasure': 0.1982, 'eval_rougeL_fmeasure': 0.3074, 'eval_rougeLsum_fmeasure': 0.3488, 'eval_gen_len': 120.1903, 'eval_runtime': 4719.6744, 'eval_samples_per_second': 2.119, 'eval_steps_per_second': 0.265, 'epoch': 4.0}
{'loss': 0.193, 'grad_norm': 0.43650463223457336, 'learning_rate': 3.975457908508527e-06, 'epoch': 4.01}
{'loss': 0.1854, 'grad_norm': 0.5162353515625, 'learning_rate': 3.939366597491654e-06, 'epoch': 4.02}
{'loss': 0.2026, 'grad_norm': 0.5320640802383423, 'learning_rate': 3.903275286474781e-06, 'epoch': 4.02}
{'loss': 0.1912, 'grad_norm': 0.6471855044364929, 'learning_rate': 3.8671839754579086e-06, 'epoch': 4.03}
{'loss': 0.1976, 'grad_norm': 0.5872098207473755, 'learning_rate': 3.831092664441037e-06, 'epoch': 4.04}
{'loss': 0.1941, 'grad_norm': 0.5220502018928528, 'learning_rate': 3.795001353424164e-06, 'epoch': 4.05}
{'loss': 0.1924, 'grad_norm': 0.60429447889328, 'learning_rate': 3.758910

  0%|          | 0/1250 [00:00<?, ?it/s]

Non-default generation parameters: {'max_length': 142, 'min_length': 56, 'early_stopping': True, 'num_beams': 4, 'length_penalty': 2.0, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


{'eval_loss': 0.22959406673908234, 'eval_rouge1_fmeasure': 0.4357, 'eval_rouge2_fmeasure': 0.1989, 'eval_rougeL_fmeasure': 0.309, 'eval_rougeLsum_fmeasure': 0.3497, 'eval_gen_len': 118.7078, 'eval_runtime': 4848.1862, 'eval_samples_per_second': 2.063, 'eval_steps_per_second': 0.258, 'epoch': 5.0}


There were missing keys in the checkpoint model loaded: ['model.encoder.embed_tokens.weight', 'model.decoder.embed_tokens.weight', 'lm_head.weight'].


{'train_runtime': 82519.5697, 'train_samples_per_second': 10.744, 'train_steps_per_second': 0.672, 'train_loss': 0.22894958015771633, 'epoch': 5.0}


TrainOutput(global_step=55415, training_loss=0.22894958015771633, metrics={'train_runtime': 82519.5697, 'train_samples_per_second': 10.744, 'train_steps_per_second': 0.672, 'train_loss': 0.22894958015771633, 'epoch': 5.0})

In [17]:
finetuned_model_path = "bart_product_reviews/"

In [18]:
trainer.save_model(finetuned_model_path)

Non-default generation parameters: {'max_length': 142, 'min_length': 56, 'early_stopping': True, 'num_beams': 4, 'length_penalty': 2.0, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


In [19]:
tokenizer.save_pretrained(finetuned_model_path)

('bart_product_reviews/tokenizer_config.json',
 'bart_product_reviews/special_tokens_map.json',
 'bart_product_reviews/vocab.json',
 'bart_product_reviews/merges.txt',
 'bart_product_reviews/added_tokens.json',
 'bart_product_reviews/tokenizer.json')

In [20]:
trainer.evaluate()

  0%|          | 0/1250 [00:00<?, ?it/s]

{'eval_loss': 0.2317153811454773,
 'eval_rouge1_fmeasure': 0.4445,
 'eval_rouge2_fmeasure': 0.2057,
 'eval_rougeL_fmeasure': 0.317,
 'eval_rougeLsum_fmeasure': 0.3583,
 'eval_gen_len': 117.0914,
 'eval_runtime': 4890.5137,
 'eval_samples_per_second': 2.045,
 'eval_steps_per_second': 0.256,
 'epoch': 5.0}

In [21]:
wandb.finish()

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
eval/gen_len,▃█▁▅▃▁
eval/loss,█▃▂▁▁▂
eval/rouge1_fmeasure,▅▁█▅▅█
eval/rouge2_fmeasure,▆▁█▅▅█
eval/rougeL_fmeasure,▅▁█▅▅█
eval/rougeLsum_fmeasure,▅▁█▅▅█
eval/runtime,▁▂▂▄▇█
eval/samples_per_second,█▇▇▅▂▁
eval/steps_per_second,█▆▇▅▂▁
train/epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███

0,1
eval/gen_len,117.0914
eval/loss,0.23172
eval/rouge1_fmeasure,0.4445
eval/rouge2_fmeasure,0.2057
eval/rougeL_fmeasure,0.317
eval/rougeLsum_fmeasure,0.3583
eval/runtime,4890.5137
eval/samples_per_second,2.045
eval/steps_per_second,0.256
total_flos,1.921387446801531e+18
