In [None]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

import os
# Print the full path of every file found under /kaggle/input so the
# dataset paths used further down can be checked by eye.
for input_file in (
    os.path.join(root_dir, file_name)
    for root_dir, _subdirs, file_names in os.walk('/kaggle/input')
    for file_name in file_names
):
    print(input_file)

In [None]:
# Install/upgrade the packages this notebook needs.
# NOTE(review): `-U` with no version pins installs whatever is newest at run
# time, so keyword arguments used in later cells (Trainer, tokenizer) may be
# deprecated or removed between runs — consider pinning versions.
!pip install -qU transformers datasets wandb accelerate huggingface_hub

In [None]:
# Import libraries
import numpy as np
import pandas as pd
import torch
from transformers import (
    AutoModelForSeq2SeqLM, 
    AutoTokenizer, 
    Seq2SeqTrainer, 
    Seq2SeqTrainingArguments, 
    DataCollatorForSeq2Seq
)
from datasets import DatasetDict, Dataset
# from huggingface_hub import notebook_login
import wandb

In [None]:
# # Login to Hugging Face Hub
# notebook_login()

# Authenticate with Weights & Biases using a key taken from the environment.
# Never hard-code the key in the notebook: anyone who can read the notebook
# (or its git history / Kaggle version) can use it. The key that was previously
# embedded here should be revoked and rotated.
# Note: the training arguments in this notebook use report_to="none", so W&B
# login is optional for training itself.
_wandb_api_key = os.environ.get("WANDB_API_KEY")
if _wandb_api_key:
    wandb.login(key=_wandb_api_key)
else:
    print("⚠️ WANDB_API_KEY is not set; skipping W&B login.")

In [None]:
# Load the train/validation splits from JSON.
def load_data(train_path, val_path):
    """Read the training and validation JSON files into DataFrames.

    In both frames the ``paragraph`` column (when present) is renamed to
    ``context``. The first row of each frame is printed, transposed, as a
    quick sanity check.

    Args:
        train_path: Path to the training-split JSON file.
        val_path: Path to the validation-split JSON file.

    Returns:
        Tuple ``(train_data, val_data)`` of pandas DataFrames.
    """
    print("🔄 Loading datasets...")
    frames = [
        pd.read_json(path).rename(columns={"paragraph": "context"})
        for path in (train_path, val_path)
    ]

    for title, frame in zip(("Training", "Validation"), frames):
        print(f"\n📊 {title} Data Sample:")
        print(frame.head(1).T)

    train_frame, val_frame = frames
    return train_frame, val_frame

# Locations of the two JSON splits consumed by load_data.
input_dir = "/kaggle/input"
train_path = f"{input_dir}/train.json"
val_path = f"{input_dir}/val.json"

train_data, val_data = load_data(train_path=train_path, val_path=val_path)

In [None]:
# Format dataset
def format_for_qa(dataset, dataset_name="train", num_preview=3):
    """Build the seq2seq input/target frame for question answering.

    Each model input is ``"<question> <sep> <context>"`` and the target is the
    answer string.

    Args:
        dataset: DataFrame with ``question``, ``context`` and ``answer`` columns.
        dataset_name: Label used only in the progress printout.
        num_preview: How many formatted inputs to print. The count is clamped
            to the number of rows, so small or empty frames do not raise
            IndexError.

    Returns:
        DataFrame with columns ``text`` (formatted input) and ``answer``.
    """
    print(f"\n🔍 Formatting {dataset_name} dataset...")
    formatted_text = dataset["question"] + " <sep> " + dataset["context"]

    print("\n📝 Sample Formatted Input (first 200 chars):")
    # Clamp so frames with fewer than `num_preview` rows don't raise IndexError.
    # str() keeps the preview from crashing on a missing (NaN) question/context.
    for i in range(min(num_preview, len(formatted_text))):
        print(f"Sample {i+1}: {str(formatted_text.iloc[i])[:200]}...")

    return pd.DataFrame({
        "text": formatted_text,
        "answer": dataset["answer"]
    })

# Turn each split into a `text` ("question <sep> context") / `answer` frame.
train_qa = format_for_qa(train_data, dataset_name="train")
val_qa = format_for_qa(val_data, dataset_name="validation")

In [None]:
# Wrap the pandas frames as Hugging Face datasets, keyed by split name.
datasets = DatasetDict({
    split_name: Dataset.from_pandas(frame)
    for split_name, frame in (("train", train_qa), ("validation", val_qa))
})

In [None]:
# Load the google/mt5-base checkpoint and its tokenizer.

print("\n🔄 Loading  model...")
model_name = "google/mt5-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# Register <sep> as a special token so it is never split into sub-pieces,
# then resize the embeddings to the tokenizer's new length.
special_tokens = {'additional_special_tokens': ['<sep>']}
tokenizer.add_special_tokens(special_tokens)
model.resize_token_embeddings(len(tokenizer))

sep_token_id = tokenizer.convert_tokens_to_ids("<sep>")
# convert_tokens_to_ids falls back to the unknown-token id for missing tokens.
assert sep_token_id != tokenizer.unk_token_id, "<sep> was not added to the tokenizer vocabulary"

# Verify tokenization on an input laid out exactly like the training inputs:
# "<question> <sep> <context>" (the answer is the target, never part of the input).
sample_text = "السؤال: متى تأسست السعودية؟ <sep> السياق: تأسست المملكة..."
print("\n🔍 Tokenization Test:")
print("Tokens:", tokenizer.tokenize(sample_text)[:20])
print("<sep> token ID:", sep_token_id)

In [None]:
# Tokenization function for the mT5 seq2seq model (inputs + answer targets).
def tokenize_function(examples, max_input_length=512, max_target_length=128):
    """Tokenize a batch of ``text``/``answer`` pairs for seq2seq training.

    Inputs and targets are truncated but deliberately NOT padded here.
    DataCollatorForSeq2Seq pads each batch dynamically and fills padded
    ``labels`` positions with ``label_pad_token_id`` (-100), which the loss
    ignores. Padding the targets to a fixed length here would instead fill
    ``labels`` with the real pad-token id, and every padded position would then
    be counted in the loss.

    Args:
        examples: Batched mapping with ``text`` and ``answer`` lists, as passed
            by ``Dataset.map(..., batched=True)``.
        max_input_length: Truncation length for the encoder input.
        max_target_length: Truncation length for the target answer.

    Returns:
        Mapping with ``input_ids``, ``attention_mask`` and ``labels`` (lists of
        token ids, one per example).
    """
    # No padding / return_tensors: ragged lists are fine for datasets.map and
    # the collator pads per batch.
    model_inputs = tokenizer(
        examples["text"],
        max_length=max_input_length,
        truncation=True,
    )

    # `text_target=` replaces the deprecated `as_target_tokenizer()` context
    # manager. It is a separate call so the targets get their own max length.
    labels = tokenizer(
        text_target=examples["answer"],
        max_length=max_target_length,
        truncation=True,
    )

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

tokenized_datasets = datasets.map(
    function=tokenize_function,
    batched=True,
    batch_size=8,
    # Drop the raw string columns once they have been tokenized.
    remove_columns=["text", "answer"],
)

In [None]:
# Data collator: pads each batch on the fly; label positions added by padding
# are filled with -100 so the loss ignores them.
data_collator = DataCollatorForSeq2Seq(
    tokenizer,
    label_pad_token_id=-100,
    padding=True,
)

In [None]:
# Training arguments for fine-tuning mT5-base.
# T5/mT5 checkpoints are known to overflow in fp16 (loss turns NaN or collapses
# to 0), so train in bf16 on GPUs that support it natively (compute capability
# >= 8, i.e. Ampere or newer) and in full fp32 otherwise.
_use_bf16 = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8

training_args = Seq2SeqTrainingArguments(
    output_dir="./arabart-qa",  # scratch dir; nothing is checkpointed during training
    eval_strategy="no",
    save_strategy="no",         # don't save during training
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    predict_with_generate=True,
    num_train_epochs=12,
    learning_rate=3e-5,
    weight_decay=0.01,
    generation_max_length=128,
    generation_num_beams=4,
    # Picking a "best" model needs periodic evaluation and saved checkpoints.
    # With eval_strategy/save_strategy both "no" it can never take effect, so
    # it is disabled rather than left on as a misleading setting.
    load_best_model_at_end=False,
    report_to="none",

    fp16=False,
    bf16=_use_bf16,
)


In [None]:
# Initialize trainer.
# `processing_class` is the current name of the former `tokenizer` argument
# (deprecated since transformers 4.46 and scheduled for removal in v5; this
# notebook installs the latest transformers with `pip -U`).
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    processing_class=tokenizer,
    data_collator=data_collator,
)

In [None]:
# Fine-tune the model with the Seq2SeqTrainer built above. trainer.train() is
# left as the cell's last expression so the notebook displays its return value.
print("\n🚀 Starting training...")
trainer.train()

In [None]:
# Save the fine-tuned weights and the tokenizer (with the added <sep> token)
# into the same directory so both can be reloaded together.
save_dir = "mt5_qa_sep"
trainer.save_model(save_dir)
tokenizer.save_pretrained(save_dir)

In [None]:
# Qualitative evaluation: generate answers for a few rows and print them next
# to the gold answers.
def evaluate_qa(model, tokenizer, dataset, num_examples=5,
                max_input_length=512, max_output_length=128):
    """Print model predictions for the first ``num_examples`` rows of ``dataset``.

    The model input is built with the same layout as the training inputs,
    ``"<question> <sep> <context>"``. The gold answer is only printed for
    comparison. It must never be part of the input, otherwise the model can
    simply copy it and the check is meaningless.

    Args:
        model: Seq2seq model exposing ``eval``, ``to`` and ``generate``.
        tokenizer: Tokenizer matching ``model``.
        dataset: Table-like object with ``question``, ``context`` and
            ``answer`` columns (e.g. a pandas DataFrame). Rows are taken by
            position.
        num_examples: Number of rows to show (clamped to the dataset size).
        max_input_length: Truncation length for the encoder input.
        max_output_length: Maximum length of the generated answer.
    """
    model.eval()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Positional access: works for DataFrames with any index, and for
    # column-oriented containers such as a dict of lists.
    questions = list(dataset["question"])
    contexts = list(dataset["context"])
    answers = list(dataset["answer"])

    for i in range(min(num_examples, len(questions))):
        # Same layout as training: question <sep> context (no answer!).
        input_text = f"{questions[i]} <sep> {contexts[i]}"

        inputs = tokenizer(
            input_text,
            return_tensors="pt",
            truncation=True,
            max_length=max_input_length,
        ).to(device)

        # Deterministic beam search. `temperature` was removed because it only
        # applies when sampling (do_sample=True) and was silently ignored here.
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_length=max_output_length,
            num_beams=4,
            early_stopping=True,
        )

        print(f"\n📄 Example {i+1}:")
        print("🔍 Question:", questions[i])
        print("📜 Context:", str(contexts[i])[:100] + "...")
        print("✅ Correct:", answers[i])
        print("🤖 Predicted:", tokenizer.decode(outputs[0], skip_special_tokens=True))
        print("-"*80)

# Spot-check a handful of validation examples.
print("\n🔍 Evaluating model...")
evaluate_qa(model=model, tokenizer=tokenizer, dataset=val_data, num_examples=5)