In [None]:
# Install the third-party packages for this notebook (no versions are pinned).
# NOTE(review): farasapy, pyarabic, wandb and the cloned arabert repo are not
# imported by any cell visible in this notebook — confirm they are still needed.
!pip install rouge
!pip install farasapy
!git clone https://github.com/aub-mind/arabert
!pip install pyarabic
!pip install datasets
!pip install transformers
!pip install wandb

In [None]:
# Download the raw articles CSV into the current working directory.
# NOTE(review): the dataset cell loads "ArabicNewsSummary.py"; presumably that
# script reads this CSV — confirm.
!wget "https://raw.githubusercontent.com/SalehShmali/Arabic_News_Summarization/main/articles.csv"

In [None]:
from rouge import Rouge
from datasets import load_dataset
from transformers import BertTokenizerFast,GPT2TokenizerFast, EncoderDecoderModel
from transformers import Seq2SeqTrainingArguments,Seq2SeqTrainer

In [None]:
# Shared configuration read by the cells below.
rouge = Rouge()  # ROUGE scorer used by compute_metrics
batch_size = 4  # per-device train batch size; also the chunk size passed to dataset.map
encoder_max_length=512  # articles are padded/truncated to this many tokens
decoder_max_length=128  # titles are padded/truncated to this many tokens
arabert="aubmindlab/bert-base-arabert"  # checkpoint name for the encoder and its tokenizer
aragpt2 = "aubmindlab/aragpt2-base"  # checkpoint name for the decoder and its tokenizer

In [None]:
# Load the encoder tokenizer and alias its cls/sep tokens as bos/eos so that
# bos_token / eos_token are defined on it.
bert_tokenizer = BertTokenizerFast.from_pretrained(arabert)
bert_tokenizer.bos_token = bert_tokenizer.cls_token
bert_tokenizer.eos_token = bert_tokenizer.sep_token

In [None]:
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
    """Wrap a single id sequence in the tokenizer's BOS and EOS ids.

    Args:
        self: object exposing ``bos_token_id`` and ``eos_token_id``.
        token_ids_0: list of token ids to wrap.
        token_ids_1: accepted for signature compatibility; it is ignored.

    Returns:
        A new list: ``[bos] + token_ids_0 + [eos]``.
    """
    prefix = [self.bos_token_id]
    suffix = [self.eos_token_id]
    return prefix + token_ids_0 + suffix

In [None]:
# Patch GPT2TokenizerFast so build_inputs_with_special_tokens wraps sequences in
# BOS/EOS, then load the decoder tokenizer and reuse its EOS token for padding.
# NOTE(review): fast tokenizers may add special tokens in their backend rather
# than through this Python hook — verify that encoded titles really start and
# end with the BOS/EOS ids.
GPT2TokenizerFast.build_inputs_with_special_tokens = build_inputs_with_special_tokens
gpt2_tokenizer = GPT2TokenizerFast.from_pretrained(aragpt2)
gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token


In [None]:
def process_data_to_model_inputs(batch):
    """Tokenize a batch of articles and titles into encoder-decoder features.

    Reads ``batch["content"]`` (encoder side) and ``batch["title"]`` (decoder
    side), pads/truncates them to ``encoder_max_length`` and
    ``decoder_max_length`` respectively, and writes ``input_ids``,
    ``attention_mask``, ``decoder_input_ids``, ``labels`` and
    ``decoder_attention_mask`` back into ``batch``.

    Returns:
        The same ``batch`` object with the feature columns added.
    """
    encoded_articles = bert_tokenizer(
        batch["content"],
        padding="max_length",
        truncation=True,
        max_length=encoder_max_length,
    )
    encoded_titles = gpt2_tokenizer(
        batch["title"],
        padding="max_length",
        truncation=True,
        max_length=decoder_max_length,
    )
    title_ids = encoded_titles.input_ids
    title_masks = encoded_titles.attention_mask

    # Mask the loss on padding: wherever the decoder attention mask is 0 the
    # label becomes -100, otherwise it is the title token id.
    masked_labels = [
        [token if mask != 0 else -100 for mask, token in zip(row_mask, row_ids)]
        for row_mask, row_ids in zip(title_masks, title_ids)
    ]

    batch["input_ids"] = encoded_articles.input_ids
    batch["attention_mask"] = encoded_articles.attention_mask
    batch["decoder_input_ids"] = title_ids
    batch["labels"] = masked_labels
    batch["decoder_attention_mask"] = title_masks

    # Every row must have been padded/truncated to exactly the configured length.
    assert all(len(row) == encoder_max_length for row in encoded_articles.input_ids)
    assert all(len(row) == decoder_max_length for row in title_ids)
    return batch

In [None]:
def compute_metrics(pred):
    """Compute averaged ROUGE-2 scores between generated and reference titles.

    Args:
        pred: prediction object exposing ``predictions`` (generated token ids)
            and ``label_ids`` (reference token ids, with -100 at masked
            positions). ``label_ids`` is assumed to be a NumPy array.

    Returns:
        dict with ``rouge2_precision``, ``rouge2_recall`` and
        ``rouge2_fmeasure``, each rounded to 4 decimal places.
    """
    pred_ids = pred.predictions
    # Work on a copy: assigning into pred.label_ids directly would silently
    # rewrite the array owned by the caller.
    labels_ids = pred.label_ids.copy()
    # -100 is only a loss-masking marker, not a vocabulary id; replace it with
    # EOS so the labels can be decoded (skip_special_tokens then drops it).
    labels_ids[labels_ids == -100] = gpt2_tokenizer.eos_token_id

    # all unnecessary tokens are removed
    pred_str = gpt2_tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = gpt2_tokenizer.batch_decode(labels_ids, skip_special_tokens=True)

    rouge2 = rouge.get_scores(pred_str, label_str, avg=True)["rouge-2"]

    return {
        "rouge2_precision": round(rouge2["p"], 4),
        "rouge2_recall": round(rouge2["r"], 4),
        "rouge2_fmeasure": round(rouge2["f"], 4),
    }

In [None]:
# Load the dataset through the local loading script.
# NOTE(review): "ArabicNewsSummary.py" is not created or downloaded by any cell
# visible in this notebook — confirm it is present in the working directory.
all_data = load_dataset("ArabicNewsSummary.py")

# Split once with a fixed seed and reuse both halves, instead of running the
# identical 90/10 split a second time just to fetch the other half.
split_data = all_data['train'].train_test_split(test_size=0.1,seed=42)
train_data = split_data['train']
val_data = split_data['test']

In [None]:
# Report how many examples ended up in each split.
split_sizes = {
    "Length of train data": len(train_data),
    "Length of val data": len(val_data),
}
for label, size in split_sizes.items():
    print(label, size)

In [None]:
def _tokenize_split(split):
    """Tokenize one split and set its output format to torch for the model columns."""
    tokenized = split.map(
        process_data_to_model_inputs,
        batched=True,
        batch_size=batch_size,
        remove_columns=["content", "title"],
    )
    tokenized.set_format(
        type="torch",
        columns=[
            "input_ids",
            "attention_mask",
            "decoder_input_ids",
            "decoder_attention_mask",
            "labels",
        ],
    )
    return tokenized


# Train and validation splits go through exactly the same preparation.
train_data = _tokenize_split(train_data)
val_data = _tokenize_split(val_data)

In [None]:
# Build the encoder-decoder model from the two pretrained checkpoints: the
# checkpoint named by `arabert` is passed first, the one named by `aragpt2` second.
model = EncoderDecoderModel.from_encoder_decoder_pretrained(arabert, aragpt2)

In [None]:
# Special-token ids for the encoder-decoder wrapper config. The pad id reuses
# EOS because the GPT-2 tokenizer's pad token was set to its EOS token.
model.config.decoder_start_token_id = gpt2_tokenizer.bos_token_id
# Set EOS explicitly on the wrapper config so generation has a stop token.
model.config.eos_token_id = gpt2_tokenizer.eos_token_id
model.config.pad_token_id = gpt2_tokenizer.eos_token_id

# Generation defaults. They must live on model.config: attributes assigned
# directly on the model object (model.num_beams, model.length_penalty,
# model.early_stopping) are never read by generate(), so beam search settings
# written that way are silently ignored.
model.config.max_length = 128
model.config.min_length = 64
model.config.no_repeat_ngram_size = 3
model.config.early_stopping = True
model.config.length_penalty = 2.0
model.config.num_beams = 4

In [None]:
# NOTE(review): 3963 looks like a steps-per-epoch figure (log/save three times
# per epoch) — confirm against the actual dataset size and batch size.
training_args = Seq2SeqTrainingArguments(
    # --- checkpoints ---
    output_dir="./model",
    overwrite_output_dir=True,
    save_steps=3963 // 3,
    save_total_limit=10,
    # --- training ---
    do_train=True,
    num_train_epochs=3,
    warmup_steps=1000,
    per_device_train_batch_size=batch_size,
    fp16=True,
    logging_steps=3963 // 3,
    # --- evaluation ---
    do_eval=True,
    evaluation_strategy="epoch",
    # NOTE(review): eval_steps is presumably only consulted for a step-based
    # evaluation strategy — confirm whether it has any effect with "epoch".
    eval_steps=10,
    per_device_eval_batch_size=batch_size // 2,
    predict_with_generate=True,
)
    

In [None]:
# Wire the model, the training arguments, both tokenized splits and the
# ROUGE metric function into a single trainer.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_data,
    eval_dataset=val_data,
    compute_metrics=compute_metrics,
)

In [None]:
# Run training with the trainer configured above.
trainer.train()

In [None]:
# Run a final evaluation with the trainer and keep the returned result.
eval_output = trainer.evaluate()