In [None]:
from datasets import load_dataset
# Load the dataset found at the local path ./COCO; no `split` is given, so the
# result is indexed by split name further down (e.g. dataset["train"]).
dataset = load_dataset("./COCO")

In [None]:
# Re-load only the first 10% of the train split and use it in place of the
# full train split; the other splits held in `dataset` are left as they are.
train_10pct = load_dataset("./COCO", split="train[:10%]")
dataset["train"] = train_10pct

In [None]:
def get_raw_data(data):
    """Expose an example's raw sentence text as a top-level ``prompt`` field.

    Args:
        data: A single example (mapping) whose ``sentences`` entry has a
            ``raw`` entry.

    Returns:
        The same ``data`` object, updated in place with ``data["prompt"]``.
    """
    raw_sentence = data["sentences"]["raw"]
    data["prompt"] = raw_sentence
    return data

# Add the `prompt` column to every example, then drop the original metadata
# columns (including `sentences`, which `prompt` was read from).
dataset = dataset.map(
    get_raw_data,
    remove_columns=[
        "filepath",
        "sentids",
        "filename",
        "imgid",
        "split",
        "cocoid",
        "sentences",
    ],
)

In [None]:
from keybert import KeyBERT

# Keyword extractor used by extract_keywords; "all-mpnet-base-v2" is the model
# name handed to KeyBERT.
kw_model = KeyBERT("all-mpnet-base-v2")

In [None]:
def extract_keywords(batch, min_score=0.4, max_keywords=2):
    """Derive a short, comma-separated keyword string for every prompt in a batch.

    Each prompt is passed through the module-level ``kw_model`` (KeyBERT) using
    unigrams and bigrams as candidates. Candidates scoring below ``min_score``
    are dropped and at most ``max_keywords`` of the remaining ones are kept, in
    the order KeyBERT returned them.

    Args:
        batch: Mapping with a ``"prompt"`` list of strings (a batched
            ``datasets`` example).
        min_score: Minimum KeyBERT score a candidate must reach to be kept.
        max_keywords: Maximum number of keywords kept per prompt.

    Returns:
        The same ``batch`` with an added ``"keywords"`` list holding one
        ``", "``-joined string per prompt (empty string if nothing qualified).
    """
    prompts = batch["prompt"]
    per_doc_candidates = kw_model.extract_keywords(
        prompts, keyphrase_ngram_range=(1, 2), stop_words=None
    )

    # KeyBERT collapses the result for a single document into a flat list of
    # (keyword, score) tuples instead of a list of lists. Re-wrap it so a
    # one-row batch (e.g. the last batch of a split) is handled like any other
    # and "keywords" keeps one entry per prompt.
    if len(prompts) == 1 and (
        not per_doc_candidates or isinstance(per_doc_candidates[0], tuple)
    ):
        per_doc_candidates = [per_doc_candidates]

    batch["keywords"] = [
        ", ".join(
            [phrase for phrase, score in candidates if score >= min_score][:max_keywords]
        )
        for candidates in per_doc_candidates
    ]

    return batch

In [None]:
# Add the `keywords` column by running extract_keywords over the dataset in
# batches of 128 examples.
dataset = dataset.map(extract_keywords, batched=True, batch_size=128)

In [None]:
import torch
import numpy as np
import datasets

from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    DataCollatorForSeq2Seq,
)

from tabulate import tabulate
import nltk
from datetime import datetime

In [None]:
# Maximum token lengths: the encoder limit is applied to the `keywords` inputs
# and the decoder limit to the `prompt` targets when the dataset is tokenized.
encoder_max_length = 512
decoder_max_length = 512

In [None]:
# Checkpoint to fine-tune; the model and its tokenizer are loaded from the
# same name so they stay in sync.
model_name = "facebook/bart-base"

model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [None]:
def batch_tokenize_preprocess(batch, tokenizer, max_source_length, max_target_length):
    """Turn a batch of keyword/prompt text pairs into model features.

    Args:
        batch: Mapping with ``"keywords"`` (source texts) and ``"prompt"``
            (target texts), each a list of strings.
        tokenizer: Callable tokenizer exposing ``pad_token_id``.
        max_source_length: Length the source sequences are padded/truncated to.
        max_target_length: Length the target sequences are padded/truncated to.

    Returns:
        A new dict holding every field produced by tokenizing the sources plus
        ``"labels"``: the target token ids with padding replaced by -100.
    """
    shared_options = {"padding": "max_length", "truncation": True}
    encoded_source = tokenizer(
        batch["keywords"], max_length=max_source_length, **shared_options
    )
    encoded_target = tokenizer(
        batch["prompt"], max_length=max_target_length, **shared_options
    )

    features = dict(encoded_source.items())

    # Padding positions become -100 so that they are ignored in the loss.
    pad_id = tokenizer.pad_token_id
    features["labels"] = [
        [-100 if token_id == pad_id else token_id for token_id in sequence]
        for sequence in encoded_target["input_ids"]
    ]
    return features

# Tokenize the dataset in batches. The columns present in the train split
# before tokenization are removed, so only the features returned by
# batch_tokenize_preprocess remain.
tokenized_dataset = dataset.map(
    batch_tokenize_preprocess,
    batched=True,
    fn_kwargs={
        "tokenizer": tokenizer,
        "max_source_length": encoder_max_length,
        "max_target_length": decoder_max_length,
    },
    remove_columns=dataset["train"].column_names,
)

In [None]:
# ROUGE metric object used by compute_metrics.
# NOTE(review): the identical statement appears again in the next cell, so one
# of the two loads is redundant. `load_metric` is also reported as deprecated
# in favor of the separate `evaluate` library — confirm against the installed
# `datasets` version.
metric = datasets.load_metric("rouge")

In [None]:
# Download the NLTK "punkt" resource without console output; presumably this is
# what nltk.sent_tokenize (used in postprocess_text) relies on.
nltk.download("punkt", quiet=True)

# NOTE(review): `metric` was already assigned by the same statement in the
# previous cell; this re-load is redundant when the cells run in order.
metric = datasets.load_metric("rouge")


def postprocess_text(preds, labels):
    """Normalize decoded texts into the layout expected by rougeLSum.

    Every text is stripped of surrounding whitespace, split into sentences
    with ``nltk.sent_tokenize`` and re-joined with one sentence per line.

    Args:
        preds: Iterable of decoded prediction strings.
        labels: Iterable of decoded reference strings.

    Returns:
        A ``(preds, labels)`` tuple of lists with the reformatted strings.
    """

    def _one_sentence_per_line(text):
        # rougeLSum expects a newline after each sentence.
        return "\n".join(nltk.sent_tokenize(text.strip()))

    formatted_preds = [_one_sentence_per_line(pred) for pred in preds]
    formatted_labels = [_one_sentence_per_line(label) for label in labels]
    return formatted_preds, formatted_labels

def compute_metrics(eval_preds):
    """Compute ROUGE scores and mean generated length for an evaluation run.

    Args:
        eval_preds: ``(predictions, labels)`` pair of token-id arrays.
            ``predictions`` may itself be a tuple, in which case its first
            element is used.

    Returns:
        Dict mapping each ROUGE key to its mid F-measure scaled to 0-100, plus
        ``"gen_len"`` (mean number of non-padding prediction tokens). All
        values are rounded to 4 decimals.
    """
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]

    pad_id = tokenizer.pad_token_id
    # -100 is a masking value, not a vocabulary id, so it can't be decoded.
    # Labels carry it for padding; predictions can carry it too when generated
    # batches of different lengths are padded together, so both are cleaned
    # before decoding.
    preds = np.where(preds != -100, preds, pad_id)
    labels = np.where(labels != -100, labels, pad_id)
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Some simple post-processing
    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    result = metric.compute(
        predictions=decoded_preds, references=decoded_labels, use_stemmer=True
    )
    # Keep only the mid F-measure of each ROUGE variant, as a percentage.
    result = {key: value.mid.fmeasure * 100 for key, value in result.items()}

    # Length is counted on the cleaned predictions so that -100 filler is not
    # mistaken for generated tokens.
    prediction_lens = [np.count_nonzero(pred != pad_id) for pred in preds]
    result["gen_len"] = np.mean(prediction_lens)
    return {k: round(v, 4) for k, v in result.items()}

In [None]:
# Hyper-parameters and bookkeeping options for the fine-tuning run.
training_args = Seq2SeqTrainingArguments(
    # Output locations and logging / checkpoint bookkeeping.
    output_dir="results",
    logging_dir="logs",
    logging_steps=500,
    save_total_limit=3,
    # Phases to run; evaluation predictions are produced with generate().
    do_train=True,
    do_eval=True,
    predict_with_generate=True,
    # Per-device batch sizes.
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    # Optimisation settings.
    num_train_epochs=5,
    learning_rate=3e-5,
    warmup_steps=500,
    weight_decay=0.1,
    label_smoothing_factor=0.1,
)

# Collates the tokenized features into batches for the seq2seq model.
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

# Fine-tune on the train split (cut down to its first 10% when the data was
# loaded) and evaluate on the test split. Training on the validation split
# instead would leave the prepared train split unused.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["test"],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

In [None]:
# Baseline: evaluate the checkpoint on the trainer's eval dataset before any
# fine-tuning has happened.
trainer.evaluate()

In [None]:
# Run fine-tuning with the arguments and datasets given to the trainer.
trainer.train()

In [None]:
# Evaluate again after training so the metrics can be compared with the
# pre-training baseline.
trainer.evaluate()

In [None]:
# Try the model on one hand-written keyword string.
keywords = "brige builder, game mobile, splash screen, stylized"

# Tokenize with the same padding/truncation settings used for the dataset.
inputs = tokenizer(
    keywords,
    max_length=encoder_max_length,
    padding="max_length",
    truncation=True,
    return_tensors="pt",
)

# Move both tensors to the device the model lives on before generating.
input_ids, attention_mask = (
    inputs[field].to(model.device) for field in ("input_ids", "attention_mask")
)
outputs = model.generate(input_ids, attention_mask=attention_mask)

output_str = tokenizer.batch_decode(outputs, skip_special_tokens=True)
print(output_str)

In [None]:
def generate_summary(test_samples, model):
    """Generate text with ``model`` from the ``keywords`` of ``test_samples``.

    Uses the module-level ``tokenizer`` and ``encoder_max_length``.

    Args:
        test_samples: Mapping-like object whose ``"keywords"`` entry holds the
            input strings.
        model: Model exposing ``device`` and ``generate``.

    Returns:
        A ``(outputs, output_str)`` tuple: the raw result of ``model.generate``
        and the decoded strings with special tokens skipped.
    """
    encoded = tokenizer(
        test_samples["keywords"],
        padding="max_length",
        truncation=True,
        max_length=encoder_max_length,
        return_tensors="pt",
    )

    # Inputs must sit on the same device as the model.
    generated_ids = model.generate(
        encoded.input_ids.to(model.device),
        attention_mask=encoded.attention_mask.to(model.device),
    )
    decoded_texts = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    return generated_ids, decoded_texts


# Fresh copy of the original checkpoint, used as the "before tuning" reference.
model_before_tuning = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# First 10 examples of the (untokenized) test split.
test_samples = dataset["test"].select(range(10))

# generate_summary returns (generated ids, decoded strings); index 1 keeps only
# the decoded strings.
summaries_before_tuning = generate_summary(test_samples, model_before_tuning)[1]
summaries_after_tuning = generate_summary(test_samples, model)[1]

In [None]:
# Side-by-side comparison of the text generated by the fine-tuned model and by
# the untouched checkpoint. The zipped columns are (id, after, before), so the
# headers are given in that same order.
print(
    tabulate(
        zip(
            range(len(summaries_after_tuning)),
            summaries_after_tuning,
            summaries_before_tuning,
        ),
        headers=["Id", "Generated text (after tuning)", "Generated text (before tuning)"],
    )
)

# Reference prompts for the same examples.
print("\nTarget text:\n")
print(
    tabulate(list(enumerate(test_samples["prompt"])), headers=["Id", "Target text"])
)

# Keyword strings that were fed to the models.
print("\nSource documents:\n")
print(tabulate(list(enumerate(test_samples["keywords"])), headers=["Id", "Original text"]))