In [None]:
from datasets import load_dataset, load_metric

In [None]:
# Load the "train" split of the summarization dataset.
train_dataset = load_dataset(
    "zedfum/long-summarization-persian",
    split="train",
)

In [None]:
# Load only the first 10% of the "validation" split for evaluation.
val_dataset = load_dataset(
    "zedfum/long-summarization-persian",
    split="validation[:10%]",
)

In [None]:
from transformers import AutoTokenizer

In [None]:
# Tokenizer that ships with the checkpoint being fine-tuned.
tokenizer = AutoTokenizer.from_pretrained(
    "zedfum/arman-longformer-8k",
)

In [None]:
# Token-length caps applied when tokenizing articles (input) and summaries (output).
max_input_length, max_output_length = 8192, 512

# Batch sizes; reused both for dataset.map() batching and as the per-device
# train / eval batch sizes in the training arguments.
batch_size, batch_size_eval = 1, 2

In [None]:
def process_data_to_model_inputs(batch):
    """Tokenize a batch of article/summary pairs into model features.

    Args:
        batch: Mapping with an ``"article"`` and a ``"summary"`` entry, each
            a list of strings (the form a batched ``map`` call passes in).

    Returns:
        The same ``batch`` mapping, with these entries added:

        * ``input_ids`` / ``attention_mask``: tokenized articles, padded or
          truncated to ``max_input_length``.
        * ``global_attention_mask``: one row per sample, ``1`` at position 0
          and ``0`` everywhere else.
        * ``labels``: tokenized summaries, padded or truncated to
          ``max_output_length``, with every pad token id replaced by ``-100``.
    """
    encoded_articles = tokenizer(
        batch["article"],
        padding="max_length",
        truncation=True,
        max_length=max_input_length,
    )
    encoded_summaries = tokenizer(
        batch["summary"],
        padding="max_length",
        truncation=True,
        max_length=max_output_length,
    )

    batch["input_ids"] = encoded_articles.input_ids
    batch["attention_mask"] = encoded_articles.attention_mask

    # Global attention on the first token of every sample. Each row is built
    # as its own list (sized from its own input_ids), so rows never alias one
    # another and an empty batch simply yields an empty mask list.
    batch["global_attention_mask"] = [
        [1 if position == 0 else 0 for position in range(len(ids))]
        for ids in batch["input_ids"]
    ]

    # Replace padding in the labels with -100 so that the PAD token is ignored.
    pad_id = tokenizer.pad_token_id
    batch["labels"] = [
        [-100 if token == pad_id else token for token in ids]
        for ids in encoded_summaries.input_ids
    ]

    return batch

In [None]:
# Tokenize the training split in batches of `batch_size`; the listed source
# columns are removed so only the features added by the mapper remain.
train_dataset = train_dataset.map(
    function=process_data_to_model_inputs,
    remove_columns=["article", "summary", "Unnamed: 0", "id"],
    batched=True,
    batch_size=batch_size,
)

In [None]:
# Tokenize the validation split in batches of `batch_size_eval`; the listed
# source columns are removed so only the features added by the mapper remain.
val_dataset = val_dataset.map(
    function=process_data_to_model_inputs,
    remove_columns=["article", "summary", "Unnamed: 0", "id"],
    batched=True,
    batch_size=batch_size_eval,
)

In [None]:
# Switch both splits to torch output, exposing only the four model features.
for _split in (train_dataset, val_dataset):
    _split.set_format(
        type="torch",
        columns=["input_ids", "attention_mask", "global_attention_mask", "labels"],
    )

In [None]:
from transformers import AutoModelForSeq2SeqLM

In [None]:
# Load the pretrained seq2seq checkpoint with `gradient_checkpointing`
# switched on and `use_cache` switched off.
led = AutoModelForSeq2SeqLM.from_pretrained(
    "zedfum/arman-longformer-8k",
    gradient_checkpointing=True,
    use_cache=False,
)

In [None]:
# Generation hyperparameters, written onto the model config one by one
# (same attributes, same order as plain assignments would use).
for _option, _value in {
    "num_beams": 2,
    "max_length": 512,
    "min_length": 100,
    "length_penalty": 2.0,
    "early_stopping": True,
    "no_repeat_ngram_size": 3,
}.items():
    setattr(led.config, _option, _value)

In [None]:
import bert_score
from rouge import Rouge

def compute_metrics(pred):
    """Score generated summaries against the references with BERTScore.

    Args:
        pred: Evaluation output exposing ``predictions`` (generated token
            ids, or a tuple whose first element holds them) and
            ``label_ids`` (reference token ids in which ``-100`` marks
            ignored positions).

    Returns:
        Dict with ``bert_precision``, ``bert_recall`` and ``bert_fmeasure``:
        the mean score over the batch as a plain Python float rounded to
        4 decimals.
    """
    # Imported locally so this cell does not rely on a notebook-level numpy import.
    import numpy as np

    pad_id = tokenizer.pad_token_id

    pred_ids = pred.predictions
    if isinstance(pred_ids, tuple):
        pred_ids = pred_ids[0]

    # -100 is not a decodable token id, so swap it for the pad id (which
    # skip_special_tokens then drops). np.where builds new arrays, leaving the
    # arrays held by `pred` unmodified.
    pred_ids = np.asarray(pred_ids)
    labels_ids = np.asarray(pred.label_ids)
    pred_ids = np.where(pred_ids != -100, pred_ids, pad_id)
    labels_ids = np.where(labels_ids != -100, labels_ids, pad_id)

    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)

    precision, recall, f_measure = bert_score.score(pred_str, label_str, lang="fa")

    # .item() turns each 0-d mean into a Python float, so the metrics can be
    # rounded, logged and serialized like ordinary numbers.
    return {
        "bert_precision": round(precision.mean().item(), 4),
        "bert_recall": round(recall.mean().item(), 4),
        "bert_fmeasure": round(f_measure.mean().item(), 4),
    }

In [None]:
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments, TrainingArguments

In [None]:
from torch import nn



import os

# Training configuration: evaluation generates predictions, evaluation and
# checkpointing both happen every 4000 steps, and checkpoints are pushed to
# the Hub repo named in `hub_model_id`.
training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,
    evaluation_strategy="steps",
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size_eval,
    fp16=True,
    output_dir="./checks",
    logging_steps=5,
    eval_steps=4000,
    save_steps=4000,
    optim='adafactor',
    save_total_limit=2,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,

    # Never hard-code the access token in the notebook. It is read from the
    # HF_TOKEN environment variable; when that is unset this is None and the
    # library is documented to fall back to the token cached by a prior login.
    hub_token=os.environ.get("HF_TOKEN"),
    push_to_hub=True,
    hub_model_id="zedfum/arman-longformer-8k-finetuned-ensani",
    hub_strategy="checkpoint",
)



In [None]:
# Wire the model, tokenizer, datasets, metric function and training arguments
# into a single trainer object.
trainer = Seq2SeqTrainer(
    args=training_args,
    model=led,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
)

In [None]:
import os
# Decide whether to resume, based on what is actually on disk:
#   * a `last-checkpoint` folder inside the output dir -> resume from that path;
#   * otherwise any `checkpoint-*` folder -> pass True so the trainer picks
#     the most recent one itself;
#   * otherwise (including a missing output dir) -> start from scratch.
output_dir = training_args.output_dir
hub_checkpoint = f'{output_dir}/last-checkpoint'

resume_from_checkpoint = False
if os.path.isdir(hub_checkpoint):
    resume_from_checkpoint = hub_checkpoint
elif os.path.isdir(output_dir) and any(
    entry.startswith("checkpoint-") and os.path.isdir(os.path.join(output_dir, entry))
    for entry in os.listdir(output_dir)
):
    resume_from_checkpoint = True

# start training
trainer.train(resume_from_checkpoint=resume_from_checkpoint)