In [None]:
%%capture
# Quiet install of the notebook's dependencies (%%capture hides pip's output).
# NOTE(review): no versions are pinned, so re-running later may install newer
# releases than the ones this notebook was executed with; pin versions for
# reproducibility.
!pip install datasets transformers[torch] evaluate rouge_score bert_score accelerate

In [None]:
import warnings
# Silence every warning category for the rest of the kernel session to keep the
# output readable. NOTE(review): this also hides library deprecation warnings
# (e.g. renamed Trainer arguments), which makes upgrade breakages harder to diagnose.
warnings.filterwarnings("ignore")

In [None]:
import torch
from datasets import load_dataset

# Load the Multi-LexSum dataset (config "v20230518"); the recorded output shows
# train/validation/test splits of 3177/454/908 examples.
multi_lexsum = load_dataset("allenai/multi_lexsum", name="v20230518")

Downloading data:   0%|          | 0.00/378M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/219M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/94.6M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/145M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/3177 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/454 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/908 [00:00<?, ? examples/s]

In [None]:
def filtered_func(data, max_words=4096):
    """Decide whether a dataset row is short enough to keep.

    Args:
        data: one dataset row with a ``'sources'`` list of document strings.
        max_words: maximum number of whitespace-separated words allowed in the
            longest (by character count) source document.

    Returns:
        True if the longest source has at most ``max_words`` words; False if it
        is longer or if the row has no source documents at all.
    """
    sources = data['sources']
    if not sources:
        # Nothing to summarize; max() on an empty list would raise ValueError.
        return False
    longest = max(sources, key=len)
    # split() with no argument splits on any whitespace run (spaces, newlines,
    # tabs) and drops empty strings. split(" ") missed newline-separated words
    # and counted empty strings between repeated spaces.
    return len(longest.split()) <= max_words

# Keep only the cases accepted by `filtered_func`, applied to every split.
multi_lexsum_filtered = multi_lexsum.filter(filtered_func)

Filter:   0%|          | 0/3177 [00:00<?, ? examples/s]

Filter:   0%|          | 0/454 [00:00<?, ? examples/s]

Filter:   0%|          | 0/908 [00:00<?, ? examples/s]

In [None]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch

# Pretrained seq2seq summarization checkpoint and its matching tokenizer. The
# name suggests LED-large tuned on book summaries (confirm on the model card).
# The weights are ~1.84 GB per the download log; the Trainer below wraps this model.
model_ckpt = "pszemraj/led-large-book-summary"
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
model = AutoModelForSeq2SeqLM.from_pretrained(model_ckpt)

tokenizer_config.json:   0%|          | 0.00/1.32k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/798k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.11M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/772 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.44k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.84G [00:00<?, ?B/s]

In [None]:
def preprocess_func(examples, max_source_length=4096, max_target_length=700, text_tokenizer=None):
    """Tokenize one Multi-LexSum row into seq2seq training features.

    The longest document in ``examples['sources']`` (by character count) is the
    model input, prefixed with ``"Summarize: "``. ``examples['summary/long']``
    is the target.

    Args:
        examples: a single dataset row (``map`` is called without
            ``batched=True``) with ``'sources'`` (list of str) and
            ``'summary/long'`` (str).
        max_source_length: pad/truncate length for the encoder input.
        max_target_length: pad/truncate length for the labels.
        text_tokenizer: tokenizer to use; defaults to the notebook-level
            ``tokenizer``.

    Returns:
        dict with ``input_ids``, ``attention_mask`` and ``labels`` (lists of int).
        Padding positions in ``labels`` are -100 so the loss ignores them.
    """
    tok = tokenizer if text_tokenizer is None else text_tokenizer

    # default="" avoids a ValueError on rows without any source document.
    source = "Summarize: " + max(examples['sources'], key=len, default="")
    target = examples['summary/long']

    tokenized_inputs = tok(source, max_length=max_source_length, padding='max_length', truncation=True)
    tokenized_targets = tok(target, max_length=max_target_length, padding='max_length', truncation=True)

    # Targets are padded to a fixed length with pad_token_id. Replace those
    # positions with -100 (the ignore index of the cross-entropy loss) so the
    # model is not trained to generate padding. DataCollatorForSeq2Seq only
    # writes -100 for padding it adds itself, and nothing is left to pad here.
    pad_id = tok.pad_token_id
    labels = [-100 if token_id == pad_id else token_id for token_id in tokenized_targets['input_ids']]

    return {
        'input_ids': tokenized_inputs['input_ids'],
        'attention_mask': tokenized_inputs['attention_mask'],
        'labels': labels,
    }

def process_data(data):
    """Drop metadata and the shorter summary variants, then tokenize every row.

    ``sources`` and ``summary/long`` are intentionally kept: later cells read
    them back from the processed test split.
    """
    unused_columns = ['id', 'sources_metadata', 'summary/short', 'summary/tiny', 'case_metadata']
    trimmed = data.remove_columns(unused_columns)
    tokenized = trimmed.map(preprocess_func)
    return tokenized

multi_lexsum_filtered = process_data(multi_lexsum_filtered)

Map:   0%|          | 0/850 [00:00<?, ? examples/s]

Map:   0%|          | 0/125 [00:00<?, ? examples/s]

Map:   0%|          | 0/254 [00:00<?, ? examples/s]

In [None]:
import evaluate
import numpy as np

# Metrics consumed by compute_metrics below: ROUGE (n-gram overlap) and
# BERTScore (contextual-embedding similarity).
rouge = evaluate.load("rouge")
bertscore = evaluate.load("bertscore")

def compute_metrics(eval_pred):
    """Compute ROUGE, BERTScore and mean generated length for a Seq2SeqTrainer eval.

    Args:
        eval_pred: ``(predictions, labels)`` pair of token-id arrays produced by
            ``Seq2SeqTrainer`` with ``predict_with_generate=True``.

    Returns:
        dict of floats rounded to 4 decimals: the ROUGE keys returned by
        ``rouge.compute``; ``bert_precision``, ``bert_recall`` and ``bert_f1``
        (means over all evaluated examples); and ``gen_len`` (mean number of
        non-pad tokens per generated sequence).
    """
    predictions, labels = eval_pred
    # Some models return a tuple of outputs; the generated ids come first.
    if isinstance(predictions, tuple):
        predictions = predictions[0]

    pad_id = tokenizer.pad_token_id
    # The Trainer may pad predictions/labels with -100 when concatenating
    # batches of different lengths. -100 is not a valid token id (batch_decode
    # fails on it), so map it back to the pad token before decoding or counting.
    predictions = np.where(predictions != -100, predictions, pad_id)
    labels = np.where(labels != -100, labels, pad_id)

    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    rouge_score = rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    bert_score = bertscore.compute(predictions=decoded_preds, references=decoded_labels, lang='en')

    score = {k: round(float(v), 4) for k, v in rouge_score.items()}

    # BERTScore returns one value per example; average over the whole eval set.
    # The previous version reported only the first example's score.
    for k in ("precision", "recall", "f1"):
        score[f"bert_{k}"] = round(float(np.mean(bert_score[k])), 4)

    prediction_lens = [np.count_nonzero(pred != pad_id) for pred in predictions]
    score["gen_len"] = float(np.mean(prediction_lens))

    return score

Downloading builder script:   0%|          | 0.00/7.95k [00:00<?, ?B/s]

In [None]:
from transformers import DataCollatorForSeq2Seq
# Batches the already-tokenized examples. `model` is passed so the collator can
# derive decoder inputs from `labels` when the model supports it (see the
# DataCollatorForSeq2Seq docs). Labels are already padded to a fixed length
# during preprocessing, so the collator adds no label padding of its own here.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model = model)

In [None]:
import torch
# Return cached-but-unused CUDA memory held by PyTorch's allocator before the
# trainer is built. NOTE(review): memory still referenced by live tensors
# (e.g. `model`) is not freed by this call.
torch.cuda.empty_cache()

In [None]:
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer, GenerationConfig

# Decoding settings used for evaluation (predict_with_generate=True) and by
# summarize_text.
# NOTE(review): do_sample=True together with num_beams=9 is beam-search
# *sampling*, so generated summaries (and therefore ROUGE/BERTScore) vary
# between runs. Use do_sample=False (and drop temperature) for deterministic,
# comparable evaluations.
generation_config = GenerationConfig(
    max_length=700,   # same as the label length used in preprocess_func
    min_length=400,
    num_beams=9,
    temperature=0.8,
    do_sample=True,
    length_penalty=1.0,
    use_cache=True,
    early_stopping=True,
    no_repeat_ngram_size=3,
    repetition_penalty=3.5,
    # Special-token ids come from the loaded tokenizer rather than literals, so
    # they always match the ids used when tokenizing the data.
    bos_token_id=tokenizer.bos_token_id,
    # Decoding starts from token id 2 (== eos here), as in the original config;
    # confirm against model.config.decoder_start_token_id if the checkpoint changes.
    decoder_start_token_id=2,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
)

training_args = Seq2SeqTrainingArguments(
    # NOTE(review): the directory name says "peft", but no PEFT adapter is
    # applied in this notebook; this is full fine-tuning.
    output_dir="LED_multi_lexsum_peft",
    generation_config=generation_config,
    save_strategy="epoch",
    # `evaluation_strategy` was renamed to `eval_strategy` (transformers 4.41)
    # and the old name was later removed. The unpinned install in the first cell
    # pulls a release where the old keyword raises TypeError.
    eval_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    save_total_limit=3,
    num_train_epochs=2,
    predict_with_generate=True,
    fp16=False,
    push_to_hub=False,
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=multi_lexsum_filtered["train"],
    eval_dataset=multi_lexsum_filtered["validation"],
    # `tokenizer=` is deprecated in favour of `processing_class=` (transformers >= 4.46).
    processing_class=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

In [None]:
# Baseline smoke test: score the checkpoint on a single test example before
# trainer.train() is called.
baseline_subset = multi_lexsum_filtered["test"].select(range(1))
trainer.evaluate(eval_dataset=baseline_subset)

{'eval_loss': 9.576332092285156,
 'eval_rouge1': 0.4608,
 'eval_rouge2': 0.1354,
 'eval_rougeL': 0.1761,
 'eval_rougeLsum': 0.2654,
 'eval_bert_precision': 0.8346,
 'eval_bert_recall': 0.8178,
 'eval_bert_f1': 0.8261,
 'eval_gen_len': 609.0,
 'eval_runtime': 38.1155,
 'eval_samples_per_second': 0.026,
 'eval_steps_per_second': 0.026}

In [None]:
def summarize_text(text, max_input_length=4096):
  """Generate a summary for one document with the current model.

  Args:
    text: raw source document.
    max_input_length: token budget for the encoder input; longer inputs are
      truncated. 4096 matches the source length used in preprocess_func.

  Returns:
    The decoded summary string with special tokens removed.
  """
  # Use exactly the training prompt prefix ("Summarize: " with a trailing
  # space). The old "Summarize:" glued the prefix onto the document's first word.
  prompt = "Summarize: " + text
  # A single sequence needs no padding; padding it to 4096 only adds masked work.
  encoded = tokenizer(prompt, max_length=max_input_length, truncation=True, return_tensors="pt").to(model.device)

  with torch.inference_mode():
    outputs = model.generate(**encoded, generation_config=generation_config)

  return tokenizer.decode(outputs[0], skip_special_tokens=True)

In [None]:
from pprint import pprint

# Pick one processed test case and summarize its longest source document
# (longest by character count, as during preprocessing).
index = 5
data = multi_lexsum_filtered["test"][index]

input_text = max(data["sources"], key=len)
original_summary = data["summary/long"]
prediction_summary = summarize_text(input_text)

In [None]:
# Reference long summary from the dataset ("summary/long") for the chosen test case.
pprint(original_summary)

In [None]:
# Model-generated summary. generation_config samples (do_sample=True), so the
# text varies between runs.
pprint(prediction_summary)

In [None]:
# The longest source document that was passed to summarize_text (may be very long).
pprint(input_text)

('On May 27, 2015, this lawsuit was brought in the United States for the '
 'Eastern District of Missouri by a person arrested by the City of St. Ann '
 '(the City), Missouri, who was jailed for a prolonged period after he was '
 'unable to pay the fee demanded for his release under the city’s “secured '
 'bail” policy. Under that policy, persons arrested for ordinance violations '
 'were required to post a bail from $150-350 or spend upwards of 3 days in '
 'jail, without any consideration of the person’s ability to pay. The '
 'plaintiff argued that the City’s policy violated the Equal Protection and '
 'Due Process Clauses of the Fourteenth Amendment of the U.S. Constitution. '
 'Represented by public interest organizations ArchCity Defenders and Equal '
 'Justice Under Law, the plaintiff brought suit in the U.S. District Court for '
 'the Eastern District of Missouri, under 42 U.S.C. § 1983. The plaintiff '
 'asked the court for class certification to represent other similarly '
 '