#1. Import the library


In [None]:
!pip install transformers
!pip install datasets
!pip install rouge_score
!pip install wandb

In [None]:
from datasets import load_dataset, load_metric, Dataset
from transformers import AutoModelForSeq2SeqLM
from transformers import AutoTokenizer
import nltk
import datasets
import numpy as np
nltk.download("punkt", quiet=True)
from transformers import (
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    DataCollatorForSeq2Seq,
)


#2. Import the datasets

In [None]:
# Load the training and validation sets by repository id. Both are read from
# the split named "train", including the validation repository.
train_dataset = load_dataset("LA1512/train_pubmed_ORC_4096_20k")["train"]
val_dataset = load_dataset("LA1512/val_pubmed_ORC_4096_1592")["train"]

In [None]:

# Checkpoint id used for the tokenizer here and for the model loaded later in
# this file, so both always come from the same checkpoint.
model_name = "pszemraj/led-base-book-summary"
tokenizer = AutoTokenizer.from_pretrained(model_name)





#3. Tokenize the text

In [None]:
# Articles are padded/truncated to this many tokens when tokenized.
max_input_length = 4096 # demo
# Abstracts (the labels) are padded/truncated to this many tokens.
max_output_length = 512
# Rows handed to the tokenization function per call by Dataset.map; the
# training arguments set their own per-device batch sizes separately.
batch_size = 2

In [None]:
def process_data_to_model_inputs(batch):
    """Tokenize a batch of articles and abstracts into model features.

    Args:
        batch: Mapping with an "article" and an "abstract" column, each a list
            of strings (the batched form this function is mapped with).

    Returns:
        The same ``batch`` object with "input_ids", "attention_mask",
        "global_attention_mask" and "labels" added.
    """
    # Inputs and targets are padded/truncated to their own fixed lengths.
    inputs = tokenizer(
        batch["article"],
        padding="max_length",
        truncation=True,
        max_length=max_input_length,
    )
    outputs = tokenizer(
        batch["abstract"],
        padding="max_length",
        truncation=True,
        max_length=max_output_length,
    )

    batch["input_ids"] = inputs.input_ids
    batch["attention_mask"] = inputs.attention_mask

    # Global attention on the first token of every sample and nowhere else.
    # Each sample gets its own list: the mask no longer depends on all rows
    # aliasing a single shared list, and an empty batch yields an empty mask
    # instead of raising IndexError.
    batch["global_attention_mask"] = [
        [1 if position == 0 else 0 for position in range(len(ids))]
        for ids in inputs.input_ids
    ]

    # Replace padding in the labels with -100 so the PAD token is ignored.
    pad_token_id = tokenizer.pad_token_id
    batch["labels"] = [
        [-100 if token == pad_token_id else token for token in label_ids]
        for label_ids in outputs.input_ids
    ]

    return batch

In [None]:
# Tokenize the training split batch by batch and drop the raw columns that
# are listed here once the features have been produced.
raw_columns = ["article", "abstract", "section_names", "article_CS", "ext_target"]
train_dataset = train_dataset.map(
    process_data_to_model_inputs,
    batched=True,
    batch_size=batch_size,
    remove_columns=raw_columns,
)

In [None]:
# Tokenize the validation split the same way as the training split and drop
# the raw columns that are listed here.
raw_columns = ["article", "abstract", "section_names", "article_CS", "ext_target"]
val_dataset = val_dataset.map(
    process_data_to_model_inputs,
    batched=True,
    batch_size=batch_size,
    remove_columns=raw_columns,
)

In [None]:
# Expose the four feature columns of both splits as torch tensors.
torch_columns = ["input_ids", "attention_mask", "global_attention_mask", "labels"]
for split in (train_dataset, val_dataset):
    split.set_format(type="torch", columns=torch_columns)

#4. Define the model

In [None]:
# Load the seq2seq model from the same checkpoint id as the tokenizer.
# NOTE(review): use_cache=False presumably turns off the decoder key/value
# cache, which is not needed while training — confirm against the model docs.
led = AutoModelForSeq2SeqLM.from_pretrained(model_name, use_cache=False)

In [None]:
# Generation hyperparameters, written onto the model config one by one.
generation_defaults = {
    "num_beams": 2,
    "max_length": 512,
    "min_length": 100,
    "length_penalty": 2.0,
    "early_stopping": True,
    "no_repeat_ngram_size": 3,
}
for option, value in generation_defaults.items():
    setattr(led.config, option, value)

#5. Set up fine-tuning

In [None]:
import wandb

# Do not paste the API key into the notebook: a key typed here ends up in the
# saved file and its history. Called without a key, wandb.login() resolves
# credentials itself (the WANDB_API_KEY environment variable or a previous
# login) and otherwise prompts for the key interactively.
wandb.login()
wandb.init(project="NLP-project", name="name of model")



In [None]:
def postprocess_text(preds, labels):
    """Normalize decoded predictions and references for ROUGE scoring.

    Each text is stripped of surrounding whitespace and re-joined with one
    sentence per line, since rougeLSum expects a newline after each sentence.

    Args:
        preds: Decoded prediction strings.
        labels: Decoded reference strings.

    Returns:
        A ``(preds, labels)`` tuple of lists holding the normalized strings.
    """

    def _one_sentence_per_line(text):
        # Strip first, then split into sentences and join them with newlines.
        return "\n".join(nltk.sent_tokenize(text.strip()))

    normalized_preds = [_one_sentence_per_line(pred) for pred in preds]
    normalized_labels = [_one_sentence_per_line(label) for label in labels]
    return normalized_preds, normalized_labels


# ROUGE metric object, loaded on the first compute_metrics call and reused by
# every later evaluation instead of being re-loaded each time.
_rouge_metric = None


def compute_metrics(eval_preds):
    """Compute ROUGE F-measures and mean generated length for an evaluation.

    Args:
        eval_preds: A ``(predictions, labels)`` pair of token-id arrays. If
            predictions is a tuple, only its first element is used. Positions
            equal to -100 are treated as padding.

    Returns:
        Dict mapping each ROUGE key to its mid F-measure scaled to 0-100,
        plus "gen_len" (mean count of non-pad prediction tokens). All values
        are rounded to 4 decimals.
    """
    global _rouge_metric
    if _rouge_metric is None:
        _rouge_metric = datasets.load_metric("rouge")

    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]

    pad_token_id = tokenizer.pad_token_id

    # -100 cannot be decoded, so swap it for the pad token before decoding.
    preds = np.array(preds)
    preds = np.where(preds != -100, preds, pad_token_id)
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    labels = np.array(labels)
    labels = np.where(labels != -100, labels, pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Some simple post-processing (strip + one sentence per line).
    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    result = _rouge_metric.compute(
        predictions=decoded_preds, references=decoded_labels, use_stemmer=True
    )
    # Keep only the mid F-measure of each ROUGE variant, as a percentage.
    result = {key: value.mid.fmeasure * 100 for key, value in result.items()}

    prediction_lens = [np.count_nonzero(pred != pad_token_id) for pred in preds]
    result["gen_len"] = np.mean(prediction_lens)
    return {k: round(v, 4) for k, v in result.items()}


In [None]:
# Hyperparameters for the fine-tuning run, grouped by concern.
training_args = Seq2SeqTrainingArguments(
    # Output locations and how often logs/checkpoints are written.
    output_dir="results",
    logging_dir="logs",
    logging_steps=5,
    save_steps=400,
    save_total_limit=8,
    # Optimization schedule.
    do_train=True,
    num_train_epochs=2,  # demo
    per_device_train_batch_size=1,  # demo
    gradient_accumulation_steps=16,  # 1 sample x 16 accumulated steps per update
    learning_rate=5e-5,
    warmup_steps=700,
    weight_decay=0.1,
    label_smoothing_factor=0.1,
    # Evaluation runs on the same 400-step cadence as checkpoint saving.
    do_eval=True,
    evaluation_strategy="steps",
    eval_steps=400,
    per_device_eval_batch_size=1,
    load_best_model_at_end=True,
    # NOTE(review): prediction_loss_only=True is documented as returning only
    # the loss during evaluation, so predict_with_generate, generation_num_beams
    # and the compute_metrics callback below likely have no effect — confirm
    # whether loss-only evaluation is the intent.
    prediction_loss_only=True,
    predict_with_generate=True,
    generation_num_beams=6,
)

data_collator = DataCollatorForSeq2Seq(tokenizer, model=led, max_length=512)

trainer = Seq2SeqTrainer(
    model=led,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

In [None]:
import gc
import torch

# Collect unreachable Python objects first so any GPU memory they still hold
# is handed back to PyTorch's caching allocator, then release the cached
# blocks. In the opposite order, memory freed by the collection stays cached.
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Start fine-tuning with the trainer built earlier in this file.
trainer.train()

In [None]:
# import gc
# torch.cuda.empty_cache()
# gc.collect()

In [None]:
# trainer.evaluate()

#6. Push model into the hub

In [None]:
!pip install huggingface_hub --q


In [None]:
!huggingface-cli login --token "Your_Hugginngface_Acssing_token"

In [None]:
# The first positional parameter of Trainer.push_to_hub is the commit message,
# not the repository name, so passing "hub_name" there only labels the commit.
# The target repository is taken from the training arguments' hub_model_id,
# which is set here before the first push creates the repository.
trainer.args.hub_model_id = "hub_name"
trainer.push_to_hub()