# In [None]:
import json
import random
import nltk
import torch
import evaluate
import pandas as pd
from transformers import EarlyStoppingCallback
from sklearn.model_selection import train_test_split
from datasets import Dataset, DatasetDict
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    DataCollatorForSeq2Seq,
)
import numpy as np
from collections import defaultdict
from tqdm import tqdm

# NLTK data fetched up front (presumably for the METEOR metric loaded below)
for _nltk_resource in ("punkt", "wordnet", "omw-1.4"):
    nltk.download(_nltk_resource)

# Text-generation metrics shared by every evaluation helper in this file
rouge, meteor, bleu = (
    evaluate.load(metric_id) for metric_id in ("rouge", "meteor", "bleu")
)


# Base checkpoint and its tokenizer (one tokenizer is reused for all runs)
model_name = "google/flan-t5-large"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# One entry per prompting setting; each JSON file is named after its setting
dataset_configs = [
    {"name": setting, "path": f"{setting}_qa_dataset.json"}
    for setting in ("Zero-Shot", "One-Shot", "Few-Shot")
]

# Preprocessing function
def preprocess(example):
    """Tokenize one example's prompt and target answer for seq2seq training.

    Args:
        example: Mapping with an "input" entry (the prompt text) and an
            "output" entry (the target answer text).

    Returns:
        The tokenizer encoding of the prompt, padded/truncated to 256 tokens,
        with an extra "labels" entry holding the target token ids
        (padded/truncated to 128) where every padding position is -100.
    """
    model_input = tokenizer(
        example["input"], max_length=256, padding="max_length", truncation=True
    )
    labels = tokenizer(
        example["output"], max_length=128, padding="max_length", truncation=True
    )

    # Padding positions must be -100 so the cross-entropy loss ignores them;
    # leaving the pad token id in place trains the model on padding.
    pad_id = tokenizer.pad_token_id
    label_ids = labels["input_ids"]
    if label_ids and isinstance(label_ids[0], list):
        # Batched call: one list of ids per example.
        masked = [
            [tok if tok != pad_id else -100 for tok in seq] for seq in label_ids
        ]
    else:
        masked = [tok if tok != pad_id else -100 for tok in label_ids]

    model_input["labels"] = masked
    return model_input

# Compute metrics function
def compute_metrics(eval_pred):
    """Decode generated ids and labels, then score them with ROUGE/METEOR/BLEU.

    Args:
        eval_pred: Pair ``(predictions, labels)`` of token-id arrays. Positions
            equal to -100 are treated as padding in both.

    Returns:
        Dict with float scores under "rouge1", "rouge2", "rougeL", "meteor"
        and "bleu". BLEU is reported as 0.0 when it cannot be computed because
        the decoded predictions or references contain no tokens.
    """
    predictions, labels = eval_pred
    if isinstance(predictions, tuple):
        # Some models return extra outputs next to the token ids.
        predictions = predictions[0]

    pad_id = tokenizer.pad_token_id
    # Map the -100 filler to the pad token explicitly instead of relying on
    # clipping to 0 happening to equal the pad id.
    predictions = np.asarray(predictions)
    predictions = np.where(predictions != -100, predictions, pad_id)
    # Keep the original safety net: ids outside the tokenizer's range are
    # clamped (presumably because decoding them fails).
    predictions = np.clip(predictions, 0, tokenizer.vocab_size - 1)
    labels = np.where(labels != -100, labels, pad_id)

    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    decoded_preds = [p.strip() for p in decoded_preds]
    decoded_labels = [l.strip() for l in decoded_labels]

    rouge_scores = rouge.compute(predictions=decoded_preds, references=decoded_labels)
    meteor_score = meteor.compute(predictions=decoded_preds, references=decoded_labels)
    try:
        bleu_value = bleu.compute(
            predictions=decoded_preds, references=[[l] for l in decoded_labels]
        )["bleu"]
    except ZeroDivisionError:
        # BLEU divides by the total prediction/reference length; if every
        # string is empty, report 0 rather than aborting the training run.
        bleu_value = 0.0

    return {
        "rouge1": rouge_scores["rouge1"],
        "rouge2": rouge_scores["rouge2"],
        "rougeL": rouge_scores["rougeL"],
        "meteor": meteor_score["meteor"],
        "bleu": bleu_value,
    }

# Compute category-wise metrics
def compute_category_metrics(predictions, references, categories):
    """Score predictions separately for each category.

    The three sequences are walked in lockstep (extra items in a longer
    sequence are ignored by ``zip``). Each category's predictions and
    references are scored with ROUGE, METEOR and BLEU.

    Args:
        predictions: Predicted answer strings.
        references: Reference answer strings.
        categories: Category key for each prediction/reference pair.

    Returns:
        Dict mapping each category to a dict with "count", "rouge1", "rouge2",
        "rougeL", "meteor" and "bleu". A category whose scoring raises is
        reported on stdout and left out of the result.
    """
    # category -> (predictions, references), in first-seen order
    grouped = {}
    for hyp, gold, label in zip(predictions, references, categories):
        hyps, golds = grouped.setdefault(label, ([], []))
        hyps.append(hyp)
        golds.append(gold)

    per_category = {}
    for label, (hyps, golds) in grouped.items():
        try:
            rouge_out = rouge.compute(predictions=hyps, references=golds)
            meteor_out = meteor.compute(predictions=hyps, references=golds)
            bleu_out = bleu.compute(
                predictions=hyps, references=[[g] for g in golds]
            )
            per_category[label] = {
                'count': len(hyps),
                'rouge1': rouge_out["rouge1"],
                'rouge2': rouge_out["rouge2"],
                'rougeL': rouge_out["rougeL"],
                'meteor': meteor_out["meteor"],
                'bleu': bleu_out["bleu"],
            }
        except Exception as e:
            # Best effort: one failing category must not discard the others.
            print(f"Error in category {label}: {e}")
    return per_category

# Full training loop
def _build_examples(data):
    """Turn the raw ``data["queries"]`` records into prompt/answer examples."""
    examples = []
    for item in data["queries"]:
        question = item["question"]
        answer = item["answer"]
        context = item.get("source_context", "")
        question_type = item.get("question_type", "unknown")
        # `or` fallbacks also cover keys that are present but null in the JSON.
        triples = (item.get("ground_truth") or {}).get("source_triples") or []
        kg_text = "; ".join(
            f"{t['subject']} - {t['predicate']} - {t['object']}" for t in triples
        )
        input_text = (
            f"Answer the following {question_type} question:\n"
            f"Question: {question}\n"
            f"Context: {context}\n"
            f"Knowledge: {kg_text}"
        )
        examples.append({
            "input": input_text,
            "output": answer,
            "question": question,
            "category": question_type.lower(),
        })
    return examples


def _split_examples(examples):
    """Split examples 70/20/10 into train, validation and test lists."""
    train_data, holdout = train_test_split(examples, test_size=0.3, random_state=42)
    # One third of the 30% hold-out becomes the test set (10% overall).
    val_data, test_data = train_test_split(holdout, test_size=1/3, random_state=42)
    return train_data, val_data, test_data


def _generate_test_predictions(model, test_examples):
    """Generate an answer per test example; return (predictions, references, categories)."""
    model.eval()
    device = model.device
    predictions, references, categories = [], [], []

    # NOTE(review): beam search (num_beams=4) is used here, while
    # trainer.evaluate relies on its own generation settings; confirm the two
    # sets of scores are meant to be compared.
    for example in tqdm(test_examples):
        inputs = tokenizer(
            example["input"], return_tensors="pt", truncation=True, padding=True, max_length=256
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}
        outputs = model.generate(**inputs, max_length=128, num_beams=4, early_stopping=True)
        prediction = tokenizer.decode(outputs[0], skip_special_tokens=True)
        predictions.append(prediction.strip())
        references.append(example["output"].strip())
        categories.append(example["category"])
    return predictions, references, categories


def run_finetuning(dataset_name, dataset_path):
    """Fine-tune, evaluate and save a model for one QA dataset.

    Steps: load the JSON file, build prompts, split 70/20/10, fine-tune with
    early stopping on validation ROUGE-1, print test-set metrics, write
    per-category metrics to ``category_metrics_<name>.csv`` and save the model
    and tokenizer under ``./flan-t5-large-finetuned-<name>``.

    Args:
        dataset_name: Label used in log output and, lower-cased, in output
            file and directory names.
        dataset_path: Path of a JSON file whose top-level "queries" list holds
            records with at least "question" and "answer".

    Returns:
        None. Results are printed and written to disk.
    """
    print(f"\n==============================")
    print(f"\U0001F680 Processing {dataset_name}")
    print(f"==============================")

    with open(dataset_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    examples = _build_examples(data)
    train_data, val_data, test_data = _split_examples(examples)
    dataset = DatasetDict({
        "train": Dataset.from_list(train_data),
        "validation": Dataset.from_list(val_data),
        "test": Dataset.from_list(test_data),
    })

    tokenized_dataset = dataset.map(preprocess, batched=False)
    slug = dataset_name.lower()
    training_args = Seq2SeqTrainingArguments(
        # NOTE(review): checkpoint dir is "flant5-..." while the final save
        # below uses "flan-t5-..."; confirm the different names are intended.
        output_dir=f"./flant5-large-finetuned-{slug}",
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        num_train_epochs=25,
        learning_rate=3e-4,
        logging_dir="./logs",
        logging_steps=20,
        # Save and eval share the same step interval so the best checkpoint
        # can be restored at the end of training.
        save_steps=200,
        eval_strategy="steps",
        eval_steps=200,
        save_total_limit=2,
        load_best_model_at_end=True,
        metric_for_best_model="rouge1",
        greater_is_better=True,
        predict_with_generate=True,
        generation_max_length=128,
        report_to="none",
    )

    # A fresh copy of the base checkpoint is loaded for every dataset.
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset["train"],
        eval_dataset=tokenized_dataset["validation"],
        tokenizer=tokenizer,
        data_collator=DataCollatorForSeq2Seq(tokenizer, model=model),
        compute_metrics=compute_metrics,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
    )

    trainer.train()

    print(f"\nEvaluating {dataset_name} test set...")
    test_results = trainer.evaluate(tokenized_dataset["test"])
    for key, value in test_results.items():
        if isinstance(value, float):
            print(f"  {key.upper()}: {value:.4f}")

    print(f"\nCategory-wise Results for {dataset_name}...")
    all_predictions, all_references, all_categories = _generate_test_predictions(
        model, test_data
    )

    category_metrics = compute_category_metrics(all_predictions, all_references, all_categories)
    df = pd.DataFrame.from_dict(category_metrics, orient="index")
    df.to_csv(f"category_metrics_{slug}.csv")
    print(df.to_string())

    trainer.save_model(f"./flan-t5-large-finetuned-{slug}")
    tokenizer.save_pretrained(f"./flan-t5-large-finetuned-{slug}")
    print(f"\n {dataset_name} model and tokenizer saved successfully.\n")


# Run all datasets
# Each configured dataset gets its own full fine-tune/evaluate/save run, in list order.
for config in dataset_configs:
    run_finetuning(dataset_name=config["name"], dataset_path=config["path"])