In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

In [2]:
import pandas as pd
import numpy as np
import nltk

import torch
from torchmetrics.text import ROUGEScore

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer 
from datasets import Dataset, DatasetDict, load_from_disk

import wandb

In [3]:
nltk.download('punkt')

[nltk_data] Downloading package punkt to
[nltk_data]     /Users/saichandrapandraju/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [4]:
# wandb.login()

In [5]:
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"


In [11]:
data_path = "../../../data/labelled/metadata/splits/"

In [12]:
train_df = pd.read_csv(os.path.join(data_path, "train.csv"))
dev_df = pd.read_csv(os.path.join(data_path, "dev.csv"))
test_df = pd.read_csv(os.path.join(data_path, "test.csv"))

In [13]:
train_df.shape, dev_df.shape, test_df.shape

((200684, 2), (25085, 2), (25086, 2))

In [14]:
train_df = pd.concat([train_df, dev_df.head(15000)], axis=0).reset_index(drop=True)

In [15]:
dev_df = dev_df[15000:].reset_index(drop=True)

In [16]:
train_df.shape, dev_df.shape, test_df.shape

((215684, 2), (10085, 2), (25086, 2))

In [22]:
train_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 215684 entries, 0 to 215683
Data columns (total 2 columns):
 #   Column        Non-Null Count   Dtype 
---  ------        --------------   ----- 
 0   product_info  215684 non-null  object
 1   summary       215684 non-null  object
dtypes: object(2)
memory usage: 3.3+ MB


In [23]:
dataset = DatasetDict()
dataset['train'] = Dataset.from_pandas(train_df, split="train")
dataset['dev'] = Dataset.from_pandas(dev_df, split="dev")
dataset['test'] = Dataset.from_pandas(test_df, split="test")

In [24]:
dataset

DatasetDict({
    train: Dataset({
        features: ['product_info', 'summary'],
        num_rows: 215684
    })
    dev: Dataset({
        features: ['product_info', 'summary'],
        num_rows: 10085
    })
    test: Dataset({
        features: ['product_info', 'summary'],
        num_rows: 25086
    })
})

In [17]:
model_name = "facebook/bart-large-cnn"

save_dir = "bart-product-info-summarization"
os.environ["WANDB_PROJECT"]=f"bart_large_product_info"
os.environ["WANDB_LOG_MODEL"]="false"
os.environ["WANDB_WATCH"]="false"
os.environ["WANDB_NOTEBOOK_NAME"]="bart-large.ipynb"

In [18]:
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

In [19]:
%%capture
model.to(device)

In [20]:
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [21]:
rouge = ROUGEScore()

In [22]:
max_input_length = 512
max_target_length = 256

def preprocess_function(examples):
    inputs = [doc for doc in examples["product_info"]]
    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True, padding=True)

    # Setup the tokenizer for targets
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(examples["summary"], max_length=max_target_length, truncation=True, padding=True)

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

In [23]:
tokenized_datasets = dataset.map(preprocess_function, batched=True)

In [24]:
tokenized_datasets.save_to_disk("bart_tokenized")

In [25]:
tokenized_datasets = load_from_disk("bart_tokenized")

In [26]:
batch_size = 16
args = Seq2SeqTrainingArguments(
    save_dir,
    evaluation_strategy = "epoch",
    save_strategy = "epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    gradient_accumulation_steps=2,
    weight_decay=0.01,
    save_total_limit=2,
    num_train_epochs=5,
    load_best_model_at_end=True,
    metric_for_best_model="rougeLsum_fmeasure",
    predict_with_generate=True,
    # fp16=True,
    report_to="wandb",  # enable logging to W&B
    logging_steps=100,
)

In [29]:
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

In [30]:
extract = ['rouge1', 'rouge2', 'rougeL', 'rougeLsum']
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    
    # Rouge expects a newline after each sentence
    decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
    decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]
    # print(decoded_preds)
    # print(decoded_labels)
    
    result_ = rouge(decoded_preds, decoded_labels)
    # Extract a few results
    result = {}
    for i in extract:
        result[f'{i}_fmeasure'] = result_[f'{i}_fmeasure'].item()
    
    # Add mean generated length
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)
    
    return {k: round(v, 4) for k, v in result.items()}

In [None]:
trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["dev"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

In [None]:
trainer.train()

In [33]:
finetuned_model_path = "bart_product_info1/"

In [34]:
trainer.save_model(finetuned_model_path)

Non-default generation parameters: {'max_length': 142, 'min_length': 56, 'early_stopping': True, 'num_beams': 4, 'length_penalty': 2.0, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


In [35]:
tokenizer.save_pretrained(finetuned_model_path)

('bart_product_info1/tokenizer_config.json',
 'bart_product_info1/special_tokens_map.json',
 'bart_product_info1/vocab.json',
 'bart_product_info1/merges.txt',
 'bart_product_info1/added_tokens.json',
 'bart_product_info1/tokenizer.json')

In [36]:
trainer.evaluate()

  0%|          | 0/631 [00:00<?, ?it/s]

{'eval_loss': 0.1599269062280655,
 'eval_rouge1_fmeasure': 0.7063,
 'eval_rouge2_fmeasure': 0.5353,
 'eval_rougeL_fmeasure': 0.6225,
 'eval_rougeLsum_fmeasure': 0.6619,
 'eval_gen_len': 82.9191,
 'eval_runtime': 5640.8785,
 'eval_samples_per_second': 1.788,
 'eval_steps_per_second': 0.112,
 'epoch': 5.0}

In [37]:
wandb.finish()

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
eval/gen_len,▁█▂▃▁▁
eval/loss,█▅▂▁▁▁
eval/rouge1_fmeasure,▁▄▆███
eval/rouge2_fmeasure,▁▄▆███
eval/rougeL_fmeasure,▁▄▆▇██
eval/rougeLsum_fmeasure,▁▄▆███
eval/runtime,▁▁▁▁█▄
eval/samples_per_second,█▇██▁▃
eval/steps_per_second,█▇██▁▃
train/epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇████

0,1
eval/gen_len,82.9191
eval/loss,0.15993
eval/rouge1_fmeasure,0.7063
eval/rouge2_fmeasure,0.5353
eval/rougeL_fmeasure,0.6225
eval/rougeLsum_fmeasure,0.6619
eval/runtime,5640.8785
eval/samples_per_second,1.788
eval/steps_per_second,0.112
total_flos,1.1684507909483397e+18
