In [1]:
# !pip install transformers
# !pip install datasets
# !pip install sentencepiece
# !pip install rouge_score
# !pip install wandb

In [2]:
import torch
import numpy as np
import datasets

from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    DataCollatorForSeq2Seq,
)

from tabulate import tabulate
import nltk
from datetime import datetime

2023-07-11 21:27:58.377026: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [3]:
# Toggle for Weights & Biases experiment tracking; every wandb call in this
# notebook is guarded by this flag.
WANDB_INTEGRATION = True
if WANDB_INTEGRATION:
    # Imported here so wandb is only required when tracking is switched on.
    import wandb
    wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mrmk[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [4]:
# Language used for the dataset configuration and for choosing the checkpoint;
# the commented line is the alternative setting.
language = 'english'
# language = 'french'

In [5]:
# Choose the pretrained summarization checkpoint that matches `language`.
model_name = (
    "moussaKam/barthez-orangesum-abstract"
    if language == 'french'
    else "sshleifer/distilbart-xsum-12-3"
)

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# print(model.config)

# Maximum token lengths applied to source documents / target summaries
# when they are tokenized.
encoder_max_length, decoder_max_length = 256, 64

In [6]:
# Load the first 2000 training rows of the wiki_lingua configuration for `language`.
data = datasets.load_dataset("wiki_lingua", name=language, split="train[:2000]")

# Preview every field of the first article: field name, then its value.
first_article = data['article'][0]
for field_name in first_article:
    print(field_name)
    print(first_article[field_name])

Found cached dataset wiki_lingua (/Users/rajithamuthukrishnan/.cache/huggingface/datasets/wiki_lingua/english/1.1.1/6fdaa844abe35a3a2a79e5a1cf9e546f32ad234d59756bcf9cfeadff6c89240e)


section_name
['Finding Other Transportation', 'Designating a Driver', 'Staying Safe']
document
['make sure that the area is a safe place, especially if you plan on walking home at night.  It’s always a good idea to practice the buddy system.  Have a friend meet up and walk with you. Research the bus, train, or streetcar routes available in your area to find safe and affordable travel to your destination.  Make sure you check the schedule for your outgoing and return travel.  Some public transportation will cease to run late at night.  Be sure if you take public transportation to the venue that you will also be able to get home late at night. Check the routes.  Even if some public transit is still running late at night, the routing may change.  Some may run express past many of the stops, or not travel all the way to the ends.  Be sure that your stop will still be available when you need it for your return trip. If you are taking public transit in a vulnerable state after drinking, it i

In [7]:
def flatten(example):
    """Lift the nested article fields to the top level of an example.

    Args:
        example: Mapping with an 'article' entry that itself holds
            'document' and 'summary' entries.

    Returns:
        A dict with exactly the keys 'document' and 'summary', whose values
        are taken unchanged from ``example['article']``.
    """
    article = example['article']
    return {field: article[field] for field in ('document', 'summary')}

def list2samples(example):
    """Explode a batch of per-article lists into one flat batch of samples.

    Args:
        example: Batch mapping whose 'document' and 'summary' entries are
            parallel sequences; each element is itself a sequence of texts
            belonging to one article.

    Returns:
        A dict with flat 'document' and 'summary' lists. Articles whose
        document list is empty contribute nothing to either list.
    """
    flat_documents, flat_summaries = [], []
    for article_documents, article_summaries in zip(example['document'], example['summary']):
        # Skip articles without any document section (their summaries are dropped too).
        if len(article_documents) == 0:
            continue
        flat_documents.extend(article_documents)
        flat_summaries.extend(article_summaries)
    return {'document': flat_documents, 'summary': flat_summaries}

# Seed for the train/validation split so that Restart & Run All reproduces the
# same partition (and therefore comparable metrics and sample summaries).
SPLIT_SEED = 42

# Step 1: keep only the nested 'document' / 'summary' lists of each article.
article_dataset = data.map(flatten, remove_columns=['article', 'url'])
# Step 2: turn every (section document, section summary) pair into its own row.
# A separate name is used for step 1 so re-running this cell is idempotent.
dataset = article_dataset.map(list2samples, batched=True)

# Hold out 10% of the rows for validation. The splits are picked by key
# instead of relying on the ordering of DatasetDict.values().
text_splits = dataset.train_test_split(test_size=0.1, seed=SPLIT_SEED)
train_data_txt, validation_data_txt = text_splits['train'], text_splits['test']

Loading cached processed dataset at /Users/rajithamuthukrishnan/.cache/huggingface/datasets/wiki_lingua/english/1.1.1/6fdaa844abe35a3a2a79e5a1cf9e546f32ad234d59756bcf9cfeadff6c89240e/cache-f8911c97f741fef9.arrow
Loading cached processed dataset at /Users/rajithamuthukrishnan/.cache/huggingface/datasets/wiki_lingua/english/1.1.1/6fdaa844abe35a3a2a79e5a1cf9e546f32ad234d59756bcf9cfeadff6c89240e/cache-6b52a63406610f48.arrow


In [8]:
def batch_tokenize_preprocess(batch, tokenizer, max_source_length, max_target_length):
    """Tokenize a batch of document/summary pairs into model features.

    Args:
        batch: Mapping with 'document' and 'summary' entries (parallel
            sequences of texts).
        tokenizer: Callable tokenizer; it is invoked with the texts plus
            ``padding``, ``truncation`` and ``max_length`` keyword arguments
            and must expose ``pad_token_id``.
        max_source_length: ``max_length`` passed when tokenizing documents.
        max_target_length: ``max_length`` passed when tokenizing summaries.

    Returns:
        A dict holding every entry of the tokenized documents plus a
        'labels' entry: the tokenized summaries' ``input_ids`` with each
        padding token replaced by -100.
    """
    documents, summaries = batch['document'], batch['summary']
    pad_and_truncate = {'padding': 'max_length', 'truncation': True}

    document_encoding = tokenizer(documents, **pad_and_truncate, max_length=max_source_length)
    summary_encoding = tokenizer(summaries, **pad_and_truncate, max_length=max_target_length)

    features = dict(document_encoding.items())

    # Padding positions become -100 so that they are ignored in the loss.
    labels = []
    for token_ids in summary_encoding['input_ids']:
        labels.append(
            [token if token != tokenizer.pad_token_id else -100 for token in token_ids]
        )
    features['labels'] = labels
    return features

def _tokenize_batch(batch):
    """Tokenize one batch with the notebook's tokenizer and length limits."""
    return batch_tokenize_preprocess(
        batch, tokenizer, encoder_max_length, decoder_max_length
    )


# Replace the raw text columns of each split with model-ready features.
train_data = train_data_txt.map(
    _tokenize_batch,
    batched=True,
    remove_columns=train_data_txt.column_names,
)

validation_data = validation_data_txt.map(
    _tokenize_batch,
    batched=True,
    remove_columns=validation_data_txt.column_names,
)

Map:   0%|          | 0/4351 [00:00<?, ? examples/s]

Map:   0%|          | 0/484 [00:00<?, ? examples/s]

In [15]:
# Download the NLTK 'punkt' resource quietly; presumably this is what
# nltk.sent_tokenize (called in postprocess_text) relies on.
nltk.download('punkt', quiet=True)

# ROUGE scorer used by compute_metrics.
# NOTE(review): datasets.load_metric is deprecated in recent `datasets`
# releases in favour of the separate `evaluate` package — confirm the
# installed version still provides it before re-running.
metric = datasets.load_metric("rouge")


def postprocess_text(preds, labels):
    """Normalize decoded texts into the one-sentence-per-line layout.

    Args:
        preds: Iterable of decoded prediction strings.
        labels: Iterable of decoded reference strings.

    Returns:
        A ``(preds, labels)`` tuple of lists in which every text has been
        stripped of surrounding whitespace and re-joined with a newline
        between the sentences found by ``nltk.sent_tokenize``.
    """

    def one_sentence_per_line(text):
        # rougeLSum expects a newline after each sentence
        return "\n".join(nltk.sent_tokenize(text.strip()))

    return (
        [one_sentence_per_line(pred) for pred in preds],
        [one_sentence_per_line(label) for label in labels],
    )


def compute_metrics(eval_preds):
    """Compute ROUGE scores and the mean generated length for an evaluation run.

    Args:
        eval_preds: Pair ``(predictions, labels)`` of token-id arrays. If
            ``predictions`` is a tuple, only its first element is used.

    Returns:
        Dict mapping each ROUGE key returned by ``metric.compute`` to its
        mid F-measure scaled to 0-100, plus 'gen_len' (mean number of
        non-padding tokens per prediction). All values are rounded to 4 decimals.
    """
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]

    pad_id = tokenizer.pad_token_id
    # -100 is the "ignore" marker and cannot be decoded. Labels always carry it
    # (see batch_tokenize_preprocess); predictions can carry it too when the
    # generated sequences are padded with -100, so both are mapped back to the
    # pad token before decoding.
    preds = np.where(preds != -100, preds, pad_id)
    labels = np.where(labels != -100, labels, pad_id)

    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    result = metric.compute(
        predictions=decoded_preds, references=decoded_labels, use_stemmer=True
    )
    # Keep only the mid F-measure of every ROUGE variant, as a percentage.
    result = {key: value.mid.fmeasure * 100 for key, value in result.items()}

    # Count real tokens only; former -100 positions are now pad ids and are
    # therefore excluded from the generated length.
    prediction_lens = [np.count_nonzero(pred != pad_id) for pred in preds]
    result["gen_len"] = np.mean(prediction_lens)
    return {k: round(v, 4) for k, v in result.items()}

In [16]:
# Fine-tuning configuration: checkpoints go to "results", logs to "logs".
# The values marked "demo" are kept small for a quick run.
training_args = Seq2SeqTrainingArguments(
    output_dir="results",
    num_train_epochs=1,  # demo
    do_train=True,
    do_eval=True,
    per_device_train_batch_size=4,  # demo
    per_device_eval_batch_size=4,
    # learning_rate=3e-05,
    warmup_steps=500,
    weight_decay=0.1,
    label_smoothing_factor=0.1,
    predict_with_generate=True,
    logging_dir="logs",
    logging_steps=50,
    save_total_limit=3,
)

# Collator that batches the tokenized features for the seq2seq model.
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

# Trainer wiring together the model, the tokenized splits and compute_metrics.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_data,
    eval_dataset=validation_data,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

In [17]:
if WANDB_INTEGRATION:
    # Hyperparameters and dataset description recorded with the run.
    run_config = {
        "per_device_train_batch_size": training_args.per_device_train_batch_size,
        "learning_rate": training_args.learning_rate,
        "dataset": "wiki_lingua " + language,
    }
    wandb_run = wandb.init(project="text_summarizer_bart", config=run_config)

    # Name the run after the language and the current wall-clock time (HHMMSS).
    time_stamp = datetime.now().strftime("%H%M%S")
    wandb_run.name = "run_" + language + "_" + time_stamp

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016674224633334234, max=1.0…

In [22]:
# Baseline: score the checkpoint on the validation split; this cell precedes
# the trainer.train() call further down.
trainer.evaluate()

{'eval_loss': 6.732614994049072,
 'eval_rouge1': 20.2427,
 'eval_rouge2': 4.8186,
 'eval_rougeL': 15.2817,
 'eval_rougeLsum': 18.0839,
 'eval_gen_len': 23.7045}

In [25]:
def generate_summary(test_samples, model):
    """Generate summaries for the documents in ``test_samples`` with ``model``.

    Uses the notebook-level ``tokenizer`` and ``encoder_max_length``.

    Args:
        test_samples: Mapping/dataset whose 'document' entry holds the
            source texts.
        model: Model exposing ``device`` and ``generate``.

    Returns:
        A ``(outputs, output_str)`` pair: the raw result of
        ``model.generate`` and the corresponding decoded strings with
        special tokens skipped.
    """
    encoded = tokenizer(
        test_samples['document'],
        padding='max_length',
        truncation=True,
        max_length=encoder_max_length,
        return_tensors='pt',
    )
    # Move both tensors to wherever the model lives before generating.
    generated_ids = model.generate(
        encoded.input_ids.to(model.device),
        attention_mask=encoded.attention_mask.to(model.device),
    )
    decoded = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    return generated_ids, decoded

In [26]:
# Load a separate copy of the pretrained checkpoint so that its summaries can
# be compared with those of the model handed to the trainer.
model_before_tuning = AutoModelForSeq2SeqLM.from_pretrained(model_name)
# The first 16 validation rows serve as the sample for qualitative inspection.
test_samples = validation_data_txt.select(range(16))
# generate_summary returns (token ids, decoded strings); keep the strings.
summaries_before_tuning = generate_summary(test_samples, model_before_tuning)[1]
# The summaries of the fine-tuned `model` are generated in a later cell, after training.



In [27]:
# Table of the summaries produced by the untouched checkpoint.
print(
    tabulate(
        list(enumerate(summaries_before_tuning)),
        headers=["Id", "Summary before"],
    )
)

# Reference tables: the target summaries, then the source documents.
for heading, column, column_header in (
    ("\nTarget summaries:\n", "summary", "Target summary"),
    ("\nSource documents:\n", "document", "Document"),
):
    print(heading)
    print(tabulate(list(enumerate(test_samples[column])), headers=["Id", column_header]))

  Id  Summary before
----  ---------------------------------------------------------------------------------------------------------------------------------------------------
   0  Here's a guide to the commands shown in the Windows Start menu.
   1  If you want to buy a parcel of land for your home, you will need to have the land surveyed.
   2  In our series of letters from African journalists, film-maker and columnist Farai Sevenzo looks at some of the country ham slices.
   3  Transposing chords to a song to make it easier to play on another instrument, such as guitar.
   4  If your alpaca is not broken into a halter, you will have to do it yourself.
   5  If you're looking for a friend to post on Facebook, what do you do?
   6  If you are writing an event report for an agency, you might want to do it pretty quickly.
   7  In our series of letters from African journalists, film-maker and columnist Farai Sevenzo offers his tips for belly dancing.
   8  If you want to play on Xbox LI

In [28]:
# %%wandb
# Fine-tune `model` on train_data with the settings held in training_args.
trainer.train()

Step,Training Loss
50,6.3856
100,5.5107
150,5.137
200,4.9096
250,4.7959
300,4.8004
350,4.7057
400,4.7021
450,4.5789
500,4.6215


TrainOutput(global_step=1088, training_loss=4.712767306496115, metrics={'train_runtime': 12412.3614, 'train_samples_per_second': 0.351, 'train_steps_per_second': 0.088, 'total_flos': 1354407936000000.0, 'train_loss': 4.712767306496115, 'epoch': 1.0})

In [29]:
# Score the validation split again after training, to compare with the
# evaluation that was run before trainer.train().
trainer.evaluate()

{'eval_loss': 4.219376564025879,
 'eval_rouge1': 31.0247,
 'eval_rouge2': 11.9114,
 'eval_rougeL': 25.3217,
 'eval_rougeLsum': 29.7556,
 'eval_gen_len': 22.6488,
 'eval_runtime': 1566.0746,
 'eval_samples_per_second': 0.309,
 'eval_steps_per_second': 0.077,
 'epoch': 1.0}

In [30]:
# Mark the W&B run created earlier as finished (only when tracking is enabled).
if WANDB_INTEGRATION:
    wandb_run.finish()

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
eval/gen_len,██▁
eval/loss,██▁
eval/rouge1,▁▁█
eval/rouge2,▁▁█
eval/rougeL,▁▁█
eval/rougeLsum,▁▁█
eval/runtime,▁▃█
eval/samples_per_second,█▆▁
eval/steps_per_second,█▆▁
train/epoch,▁▁▂▂▂▃▃▃▄▄▄▅▅▅▅▆▆▆▇▇▇███

0,1
eval/gen_len,22.6488
eval/loss,4.21938
eval/rouge1,31.0247
eval/rouge2,11.9114
eval/rougeL,25.3217
eval/rougeLsum,29.7556
eval/runtime,1566.0746
eval/samples_per_second,0.309
eval/steps_per_second,0.077
train/epoch,1.0


In [31]:
# Summarize the same test_samples with the fine-tuned model, reusing
# generate_summary from the earlier cell; only the decoded strings are kept.
summaries_after_tuning = generate_summary(test_samples, model)[1]



In [33]:
# Display the raw list of summaries generated by the fine-tuned model.
summaries_after_tuning

['Open Windows. Click the Command Prompt menu. Press the drive letter. Press "D:" and press Enter. Press Enter.',
 'Have the land surveyed to determine the dimensions of the plot of land.',
 'Cut the ham slice from the edges of the ham. Heat the pan on medium-high.',
 'Transpose chords up or down. Use a transposition table. Use the chromatic circle. Play all of the chords in the song.',
 'Get an alpaca halter. Feed in a feeder. Feed your alpacas.',
 "Open the Facebook app. Tap the search bar. Tap Facebook. Tap your friend's profile. Tap Post. Tap a post window. Tap Instagram.",
 'Set a deadline for the report. Write a report. Be thorough and professional.',
 'Bend your knees and legs under your hip bones. Do the “shimmy” move.',
 'Press the "On" button on your Xbox 360 or One controller. Check the status of your Xbox One\'s Internet connection.',
 'Think about your purpose. Use a name for God. Practice singing devotional songs.',
 'Open the Start button. Click Command Prompt. Click the

In [32]:
# Side-by-side table: fine-tuned summary next to the pretrained one, per sample.
comparison_rows = [
    (sample_id, after, before)
    for sample_id, (after, before) in enumerate(
        zip(summaries_after_tuning, summaries_before_tuning)
    )
]
print(tabulate(comparison_rows, headers=["Id", "Summary after", "Summary before"]))

# Reference tables: the target summaries, then the source documents.
for heading, column, column_header in (
    ("\nTarget summaries:\n", "summary", "Target summary"),
    ("\nSource documents:\n", "document", "Document"),
):
    print(heading)
    print(tabulate(list(enumerate(test_samples[column])), headers=["Id", column_header]))

  Id  Summary after                                                                                                                    Summary before
----  -------------------------------------------------------------------------------------------------------------------------------  ---------------------------------------------------------------------------------------------------------------------------------------------------
   0  Open Windows. Click the Command Prompt menu. Press the drive letter. Press "D:" and press Enter. Press Enter.                    Here's a guide to the commands shown in the Windows Start menu.
   1  Have the land surveyed to determine the dimensions of the plot of land.                                                          If you want to buy a parcel of land for your home, you will need to have the land surveyed.
   2  Cut the ham slice from the edges of the ham. Heat the pan on medium-high.                                                        In our