In [None]:
%%capture

!pip install datasets transformers[torch] bitsandbytes evaluate rouge_score accelerate peft

In [None]:
from huggingface_hub import notebook_login

# Prompt for a Hugging Face access token in the notebook UI; the Hub push at the
# end of the notebook relies on this login.
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [None]:
import warnings
# Silence every Python warning to keep notebook output readable.
# NOTE(review): this also hides library deprecation notices and generation-config
# warnings; consider narrowing it with `category=` while developing.
warnings.filterwarnings("ignore")

In [None]:
import torch
from datasets import load_dataset

# Multi-LexSum legal-case summarization corpus, pinned to the "v20230518" config
# so the splits are reproducible; returns a DatasetDict (train/validation/test).
multi_lexsum = load_dataset("allenai/multi_lexsum", name="v20230518")

Downloading data:   0%|          | 0.00/378M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/219M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/94.6M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/145M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/3177 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/454 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/908 [00:00<?, ? examples/s]

In [None]:
def filtered_func(data, max_words=4096):
    """Keep an example only if its longest source document is short enough.

    The longest entry of ``data['sources']`` (by character count) is the one the
    preprocessing step feeds to the model, so that is the one measured here.

    Args:
        data: a single (non-batched) dataset row with a ``'sources'`` list of strings.
        max_words: inclusive upper bound on the whitespace-delimited word count.

    Returns:
        True to keep the example, False to drop it. Rows whose sources are empty
        (no documents, or only whitespace) are dropped: there is nothing to summarize.
    """
    # default="" avoids ValueError from max() when the sources list is empty.
    longest = max(data['sources'], key=len, default="")
    # str.split() with no argument splits on any whitespace run (spaces, newlines,
    # tabs). split(" ") would treat "a\nb" as one word and count empty strings
    # between repeated spaces, so long multi-line documents slipped through.
    words = longest.split()
    return 0 < len(words) <= max_words

# Apply filtered_func to every split, dropping cases whose longest source document
# exceeds the word cap; reads the raw `multi_lexsum`, so re-running is safe.
multi_lexsum_filtered = multi_lexsum.filter(filtered_func)

Filter:   0%|          | 0/3177 [00:00<?, ? examples/s]

Filter:   0%|          | 0/454 [00:00<?, ? examples/s]

Filter:   0%|          | 0/908 [00:00<?, ? examples/s]

In [None]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch

# A single checkpoint name drives both the tokenizer and the seq2seq model weights,
# so the vocabulary and the embedding table always agree.
model_ckpt = "pszemraj/led-large-book-summary"
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_ckpt),
    AutoModelForSeq2SeqLM.from_pretrained(model_ckpt),
)

tokenizer_config.json:   0%|          | 0.00/1.32k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/798k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.11M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/772 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.44k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.84G [00:00<?, ?B/s]

In [None]:
from peft import LoraConfig, get_peft_model

# LoRA adapters for LED. LED names its attention projections differently in its
# two stacks: encoder (Longformer-style) self-attention uses `query`/`key`/`value`
# plus an `output` projection, while decoder self- and cross-attention use
# `q_proj`/`k_proj`/`v_proj`/`out_proj`. Listing only the decoder names leaves the
# encoder, which reads the long legal source documents, with no adapters at all.
config = LoraConfig(
    r=16,
    lora_alpha=8,  # effective LoRA scaling is lora_alpha / r = 0.5
    target_modules=[
        # decoder self-/cross-attention projections
        "k_proj", "v_proj", "q_proj", "out_proj",
        # encoder local self-attention projections and attention output
        "query", "key", "value", "output",
    ],
    lora_dropout=0.05,
    bias="none",
    task_type="SEQ_2_SEQ_LM",
)

peft_model = get_peft_model(model, config)

In [None]:
# Report how many parameters the LoRA adapters make trainable versus the full model.
peft_model.print_trainable_parameters()

trainable params: 3,145,728 || all params: 462,947,328 || trainable%: 0.679500195754451


In [None]:
def preprocess_func(examples, tok=None, max_source_length=4096, max_target_length=700):
    """Tokenize one Multi-LexSum example into LED encoder inputs and training labels.

    Args:
        examples: a single (non-batched) dataset row with a ``'sources'`` list of
            strings and a ``'summary/long'`` string.
        tok: tokenizer to use; defaults to the notebook-level ``tokenizer``.
        max_source_length: pad/truncate length for the encoder input.
        max_target_length: pad/truncate length for the label sequence.

    Returns:
        dict with ``input_ids``, ``attention_mask``, ``global_attention_mask`` and
        ``labels``. Padding positions in ``labels`` are -100 so the loss ignores them.
    """
    if tok is None:
        tok = tokenizer  # notebook-level tokenizer
    # Only the longest source document (by characters) of each case is summarized;
    # default="" keeps an empty sources list from raising in max().
    source = "Summarize: " + max(examples['sources'], key=len, default="")
    target = examples['summary/long']

    tokenized_inputs = tok(source, max_length=max_source_length, padding='max_length', truncation=True)
    tokenized_targets = tok(target, max_length=max_target_length, padding='max_length', truncation=True)

    input_ids = tokenized_inputs['input_ids']
    # LED's documented recommendation for summarization: global attention on the
    # first (<s>) token only; all other positions keep local windowed attention.
    global_attention_mask = [1] + [0] * (len(input_ids) - 1)
    # Target padding must become -100 so cross-entropy skips it; with the raw pad
    # id the model would be trained to emit hundreds of <pad> tokens per summary.
    labels = [
        token_id if is_real else -100
        for token_id, is_real in zip(tokenized_targets['input_ids'], tokenized_targets['attention_mask'])
    ]

    return {
        'input_ids': input_ids,
        'attention_mask': tokenized_inputs['attention_mask'],
        'global_attention_mask': global_attention_mask,
        'labels': labels,
    }

def process_data(data):
    """Drop metadata and short-summary columns, then tokenize each row with preprocess_func.

    Args:
        data: a ``datasets`` Dataset/DatasetDict carrying the Multi-LexSum columns.

    Returns:
        The mapped dataset. Columns not listed below (e.g. ``sources`` and
        ``summary/long``) remain next to the new token columns.
    """
    unused_columns = ['id', 'sources_metadata', 'summary/short', 'summary/tiny', 'case_metadata']
    slimmed = data.remove_columns(unused_columns)
    return slimmed.map(preprocess_func)

# Tokenize every split in place of the filtered text dataset.
# NOTE(review): this rebinds the same name, so re-running this cell by itself
# fails (process_data's remove_columns targets columns that are already gone).
# Re-run the filter cell first, or move to a new variable name across cells.
multi_lexsum_filtered = process_data(multi_lexsum_filtered)

Map:   0%|          | 0/850 [00:00<?, ? examples/s]

Map:   0%|          | 0/125 [00:00<?, ? examples/s]

Map:   0%|          | 0/254 [00:00<?, ? examples/s]

In [None]:
import evaluate

# ROUGE metric from the `evaluate` library; compute_metrics uses it to score generated summaries.
rouge = evaluate.load("rouge")

Downloading builder script:   0%|          | 0.00/6.27k [00:00<?, ?B/s]

In [None]:
from transformers import DataCollatorForSeq2Seq
# Batches the pre-tokenized rows. Any label padding the collator itself adds uses
# -100 (spelled out here; it is the collator's default) so the loss skips it.
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=peft_model,
    label_pad_token_id=-100,
)

In [None]:
import numpy as np

def compute_metrics(eval_pred, tok=None, metric=None):
    """Decode generated summaries and score them against the references with ROUGE.

    Args:
        eval_pred: ``(predictions, labels)`` pair of token-id arrays from the trainer.
        tok: tokenizer used for decoding; defaults to the notebook-level ``tokenizer``.
        metric: object with a ``compute(predictions, references, use_stemmer)``
            method; defaults to the notebook-level ``rouge``.

    Returns:
        dict of the metric's scores plus ``gen_len`` (mean count of non-pad tokens
        per prediction), every value rounded to 4 decimals.
    """
    if tok is None:
        tok = tokenizer
    if metric is None:
        metric = rouge
    predictions, labels = eval_pred
    # -100 marks ignored positions. It appears in labels, and the trainer can also
    # use it to pad generated sequences of different lengths when concatenating
    # eval batches. Decoding -100 fails, so map it to the pad id in both arrays.
    predictions = np.where(predictions != -100, predictions, tok.pad_token_id)
    labels = np.where(labels != -100, labels, tok.pad_token_id)

    decoded_preds = tok.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tok.batch_decode(labels, skip_special_tokens=True)

    result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)

    prediction_lens = [np.count_nonzero(pred != tok.pad_token_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)

    return {k: round(v, 4) for k, v in result.items()}

In [None]:
import torch
# Release cached-but-unused GPU memory held by PyTorch's allocator before training starts.
torch.cuda.empty_cache()

In [None]:
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer, GenerationConfig

# Decoding settings used when the trainer evaluates with predict_with_generate.
# Plain (deterministic) beam search: enabling sampling inside beam search makes
# the generated summaries, and therefore the ROUGE scores, differ on every
# evaluation run, so epochs cannot be compared.
generation_config = GenerationConfig(
    max_length=700,          # same length used to truncate labels in preprocess_func
    num_beams=9,
    do_sample=False,
    length_penalty=2.0,      # > 0 favours longer beams, suited to long-form summaries
    num_return_sequences=1,
    pad_token_id=tokenizer.pad_token_id,
    eos_token_id=tokenizer.eos_token_id,
    bos_token_id=tokenizer.bos_token_id,
    use_cache=True,
    early_stopping=True,
    no_repeat_ngram_size=3,  # never repeat any 3-gram within one summary
)

# Trainer configuration: evaluate and checkpoint once per epoch, keep at most three
# checkpoints under output_dir, and score validation with generated summaries.
# NOTE(review): newer transformers releases rename `evaluation_strategy` to
# `eval_strategy`; confirm against the installed version (the install cell is unpinned).
training_args = Seq2SeqTrainingArguments(
    output_dir="LED_multi_lexsum_peft",
    # optimisation schedule
    num_train_epochs=2,
    learning_rate=2e-5,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    # evaluation / checkpointing cadence
    evaluation_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=3,
    # generation-based evaluation, required for ROUGE in compute_metrics
    predict_with_generate=True,
    generation_config=generation_config,
    # precision / hub
    fp16=False,
    push_to_hub=False,
)

trainer = Seq2SeqTrainer(
    model=peft_model,
    args=training_args,
    data_collator=data_collator,
    tokenizer=tokenizer,
    train_dataset=multi_lexsum_filtered["train"],
    eval_dataset=multi_lexsum_filtered["validation"],
    compute_metrics=compute_metrics,
)

In [None]:
# Fine-tune the LoRA adapters; per training_args this evaluates and checkpoints once per epoch.
trainer.train()

Epoch,Training Loss,Validation Loss


In [None]:
# Trainer.push_to_hub takes a *commit message* as its first argument, not a repo
# id, so the repo has to be named explicitly. Push the trained LoRA adapter (the
# same object the trainer optimised) and the tokenizer to the intended repo.
peft_model.push_to_hub("Rudra/LED_multi_lexsum_peft")
tokenizer.push_to_hub("Rudra/LED_multi_lexsum_peft")

In [None]:
# Final held-out evaluation. metric_key_prefix="test" logs the results as test_*
# instead of the default eval_*, so they are not confused with the per-epoch
# validation metrics.
trainer.evaluate(eval_dataset=multi_lexsum_filtered["test"], metric_key_prefix="test")