In [50]:
from datasets import load_dataset
dataset = load_dataset("51la5/keyword-extraction")

In [51]:
dataset

DatasetDict({
    train: Dataset({
        features: ['dataset', 'file_id', 'text', 'summary', 'type'],
        num_rows: 22033
    })
    test: Dataset({
        features: ['dataset', 'file_id', 'text', 'summary', 'type'],
        num_rows: 5513
    })
})

In [52]:
set(dataset['train']["dataset"])

{'Inspec',
 'Krapivin2009',
 'Nguyen2007',
 None,
 'PubMed',
 'QMSum',
 'Schutz2008',
 'SemEval2010',
 'SemEval2017',
 'citeulike180',
 'fao30',
 'fao780',
 'kdd',
 'theses100',
 'wiki20',
 'www'}

In [53]:
dataset = dataset.filter(lambda example: example["type"] == "KEYWORD")

In [54]:
dataset

DatasetDict({
    train: Dataset({
        features: ['dataset', 'file_id', 'text', 'summary', 'type'],
        num_rows: 8139
    })
    test: Dataset({
        features: ['dataset', 'file_id', 'text', 'summary', 'type'],
        num_rows: 2033
    })
})

In [56]:
dataset = dataset.filter(lambda example: len(example["text"]) <= 1000)

Filter: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8139/8139 [00:00<00:00, 45757.15 examples/s]
Filter: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2033/2033 [00:00<00:00, 51768.01 examples/s]


In [57]:
dataset

DatasetDict({
    train: Dataset({
        features: ['dataset', 'file_id', 'text', 'summary', 'type'],
        num_rows: 2522
    })
    test: Dataset({
        features: ['dataset', 'file_id', 'text', 'summary', 'type'],
        num_rows: 630
    })
})

In [58]:
dataset = dataset.filter(lambda example: len(example["text"]) / len(example["summary"]) <= 10)

Filter: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2522/2522 [00:00<00:00, 105500.77 examples/s]
Filter: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 630/630 [00:00<00:00, 96572.31 examples/s]


In [59]:
dataset

DatasetDict({
    train: Dataset({
        features: ['dataset', 'file_id', 'text', 'summary', 'type'],
        num_rows: 2319
    })
    test: Dataset({
        features: ['dataset', 'file_id', 'text', 'summary', 'type'],
        num_rows: 571
    })
})

In [60]:
import torch
import numpy as np
import datasets

from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    DataCollatorForSeq2Seq,
)

from tabulate import tabulate
import nltk
from datetime import datetime

In [61]:
WANDB_INTEGRATION = True
if WANDB_INTEGRATION:
    import wandb

    wandb.login()

In [62]:
model_name = "facebook/bart-base"

model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [63]:
encoder_max_length = 512  # demo
decoder_max_length = 512

In [70]:
dataset['train'][0]

{'dataset': 'kdd',
 'file_id': '1370349',
 'text': 'Mining a stream of transactions for customer patterns No contact information provided yet.',
 'summary': 'approximate queries,customer profiles,dynamic database,histograms,incremental updates,massive data,signatures',
 'type': 'KEYWORD'}

In [71]:
def batch_tokenize_preprocess(batch, tokenizer, max_source_length, max_target_length):
    source_tokenized = tokenizer(
        batch["summary"], padding="max_length", truncation=True, max_length=max_source_length
    )
    target_tokenized = tokenizer(
        batch["text"], padding="max_length", truncation=True, max_length=max_target_length
    )

    batch = {k: v for k, v in source_tokenized.items()}
    # Ignore padding in the loss
    batch["labels"] = [
        [-100 if token == tokenizer.pad_token_id else token for token in l]
        for l in target_tokenized["input_ids"]
    ]
    return batch

tokenized_dataset = dataset.map(
    lambda batch: batch_tokenize_preprocess(
        batch, tokenizer, encoder_max_length, decoder_max_length
    ),
    batched=True,
    remove_columns=dataset["train"].column_names
)

Map: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2319/2319 [00:00<00:00, 2817.54 examples/s]
Map: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 571/571 [00:00<00:00, 2918.48 examples/s]


In [72]:
dataset["train"][0]

{'dataset': 'kdd',
 'file_id': '1370349',
 'text': 'Mining a stream of transactions for customer patterns No contact information provided yet.',
 'summary': 'approximate queries,customer profiles,dynamic database,histograms,incremental updates,massive data,signatures',
 'type': 'KEYWORD'}

In [73]:
# !pip install rouge-score
metric = datasets.load_metric("rouge")

  metric = datasets.load_metric("rouge")


In [74]:
nltk.download("punkt", quiet=True)

metric = datasets.load_metric("rouge")


def postprocess_text(preds, labels):
    preds = [pred.strip() for pred in preds]
    labels = [label.strip() for label in labels]

    # rougeLSum expects newline after each sentence
    preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
    labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]

    return preds, labels


def compute_metrics(eval_preds):
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Some simple post-processing
    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    result = metric.compute(
        predictions=decoded_preds, references=decoded_labels, use_stemmer=True
    )
    # Extract a few results from ROUGE
    result = {key: value.mid.fmeasure * 100 for key, value in result.items()}

    prediction_lens = [
        np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds
    ]
    result["gen_len"] = np.mean(prediction_lens)
    result = {k: round(v, 4) for k, v in result.items()}
    return result

In [75]:
training_args = Seq2SeqTrainingArguments(
    output_dir="results",
    num_train_epochs=20,  # demo
    do_train=True,
    do_eval=True,
    per_device_train_batch_size=16, # demo
    per_device_eval_batch_size=16,
    # learning_rate=3e-05,
    warmup_steps=500,
    weight_decay=0.1,
    label_smoothing_factor=0.1,
    predict_with_generate=True,
    logging_dir="logs",
    logging_steps=50,
    save_total_limit=3,
)

data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=tokenized_dataset['train'],
    eval_dataset=tokenized_dataset['test'],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [76]:
wandb_run = wandb.init(
    project="bart prompt generator",
    config={
        "per_device_train_batch_size": training_args.per_device_train_batch_size,
        "learning_rate": training_args.learning_rate,
        "dataset": "51la5/keyword-extraction",
    },
)

now = datetime.now()
current_time = now.strftime("%H%M%S")
wandb_run.name = "run_" + current_time

In [77]:
trainer.evaluate()

You're using a BartTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


{'eval_loss': 7.8832197189331055,
 'eval_rouge1': 20.3832,
 'eval_rouge2': 8.3047,
 'eval_rougeL': 16.7813,
 'eval_rougeLsum': 17.2561,
 'eval_gen_len': 18.6865,
 'eval_runtime': 18.6698,
 'eval_samples_per_second': 30.584,
 'eval_steps_per_second': 1.928}

In [78]:
trainer.train()



Step,Training Loss
50,6.3416
100,5.329
150,5.058
200,4.9276
250,4.8411
300,4.7648
350,4.657
400,4.6044
450,4.5668
500,4.4182


TrainOutput(global_step=2900, training_loss=3.7804891020676186, metrics={'train_runtime': 1620.3423, 'train_samples_per_second': 28.624, 'train_steps_per_second': 1.79, 'total_flos': 1.41397884665856e+16, 'train_loss': 3.7804891020676186, 'epoch': 20.0})

In [79]:
trainer.evaluate()

{'eval_loss': 4.468584060668945,
 'eval_rouge1': 33.2493,
 'eval_rouge2': 21.6726,
 'eval_rougeL': 30.3571,
 'eval_rougeLsum': 31.135,
 'eval_gen_len': 18.4221,
 'eval_runtime': 18.5853,
 'eval_samples_per_second': 30.723,
 'eval_steps_per_second': 1.937,
 'epoch': 20.0}

In [88]:
trainer.evaluate(tokenized_dataset["train"])

{'eval_loss': 2.7313904762268066,
 'eval_rouge1': 34.7148,
 'eval_rouge2': 24.2124,
 'eval_rougeL': 32.4001,
 'eval_rougeLsum': 33.0395,
 'eval_gen_len': 18.5054,
 'eval_runtime': 73.9681,
 'eval_samples_per_second': 31.351,
 'eval_steps_per_second': 1.96,
 'epoch': 20.0}

In [87]:
keywords = "cat, clean, dog, dirty"
inputs = tokenizer(
    keywords,
    padding="max_length",
    truncation=True,
    max_length=encoder_max_length,
    return_tensors="pt",
)
input_ids = inputs.input_ids.to(model.device)
attention_mask = inputs.attention_mask.to(model.device)
outputs = model.generate(input_ids, attention_mask=attention_mask)
output_str = tokenizer.batch_decode(outputs, skip_special_tokens=True)
print(output_str)

['Cat and mouse: getting the cat and the dog dirty No contact information provided yet.']


In [82]:
def generate_summary(test_samples, model):
    inputs = tokenizer(
        test_samples["summary"],
        padding="max_length",
        truncation=True,
        max_length=encoder_max_length,
        return_tensors="pt",
    )
    input_ids = inputs.input_ids.to(model.device)
    attention_mask = inputs.attention_mask.to(model.device)
    outputs = model.generate(input_ids, attention_mask=attention_mask)
    output_str = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return outputs, output_str


model_before_tuning = AutoModelForSeq2SeqLM.from_pretrained(model_name)

test_samples = dataset["test"].select(range(16))

summaries_before_tuning = generate_summary(test_samples, model_before_tuning)[1]
summaries_after_tuning = generate_summary(test_samples, model)[1]

In [85]:
test_samples[0]

{'dataset': 'Inspec',
 'file_id': '109',
 'text': 'An entanglement measure based on the capacity of dense codingAn asymptotic entanglement measure for any bipartite states is derived in thelight of the dense coding capacity optimized with respect to localquantum operations and classical communications. General properties andsome examples with explicit forms of this entanglement measure areinvestigated',
 'summary': 'entanglement measure,dense coding capacity,asymptotic entanglement measure,bipartite states,local quantum operations,classical communications,optimization,encoding,optimisation,quantum communication',
 'type': 'KEYWORD'}

In [83]:
print(
    tabulate(
        zip(
            range(len(summaries_after_tuning)),
            summaries_after_tuning,
            summaries_before_tuning,
        ),
        headers=["Id", "Generated Prompt", "Keywords"],
    )
)
print("\nTarget text:\n")
print(
    tabulate(list(enumerate(test_samples["text"])), headers=["Id", "Target text"])
)
print("\nSource documents:\n")
print(tabulate(list(enumerate(test_samples["summary"])), headers=["Id", "summary"]))

  Id  Generated Prompt                                                                                                        Keywords
----  ----------------------------------------------------------------------------------------------------------------------  ---------------------------------------------------------------------------------------------------
   0  An asymptotic entanglement measure for quantum operations with a dense codingcapacity                                   entanglement measure,dense coding capacity,asymptotic entanglement
   1  Geographic location of servers in africa: a digital-divide approach No contact                                          africa,cctld,digital-divide,geographic location of servers
   2  A voltage-vector selection algorithm for direct torque control of induction motordrivesA new                            voltage-vector selection algorithm,direct torque control,induction motor iphone
   3  An agent communication languages for the semantic