# %% [markdown]
# # Model Evaluation and Benchmarking
# 
# Evaluates the PEFT/LoRA fine-tuned model against its base model:
# - Perplexity measurement
# - Generation quality metrics (ROUGE-L, BLEU)
# - Side-by-side response comparison
# - Persona consistency scoring
# - Human evaluation templates
# - A/B testing framework

In [None]:
# %%
# Import libraries and configuration
import os
import torch
import numpy as np
import pandas as pd
from datasets import load_from_disk
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

# Configuration: model identifiers and filesystem locations (paths are relative
# to the notebook's working directory).
BASE_MODEL = "deepseek-ai/DeepSeek-V3-Base"  # base checkpoint id passed to from_pretrained
FINETUNED_MODEL = "../models/sft_lora"  # directory with the PEFT adapter and its tokenizer
TEST_DATASET = "../data/processed/sft_dataset"  # dataset directory read by load_from_disk
RESULTS_DIR = "../evaluation_results"  # JSON results and figures are written here

# Create the output directory up front; an existing directory is not an error.
os.makedirs(name=RESULTS_DIR, exist_ok=True)

# %% [markdown]
# ## Load Models


In [None]:
# %%
# Load the base model and its tokenizer.
print("Loading base model...")
base_tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL, device_map="auto", torch_dtype=torch.float16
)

# Load the fine-tuned model: a second copy of the base weights with the adapter
# applied on top. NOTE: PEFT injects adapter layers into the model it wraps in
# place, so wrapping `base_model` itself would also change the "base" model --
# hence the separate (memory-expensive) copy.
print("Loading fine-tuned model...")
ft_tokenizer = AutoTokenizer.from_pretrained(FINETUNED_MODEL)
ft_base = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL, device_map="auto", torch_dtype=torch.float16
)
ft_model = PeftModel.from_pretrained(ft_base, FINETUNED_MODEL)

# Batched tokenization (padding=True) and `generate` need a pad token; many
# causal-LM tokenizers ship without one, so fall back to the EOS token.
for tok in (base_tokenizer, ft_tokenizer):
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

print("Models loaded successfully")

# %% [markdown]
# ## Perplexity Evaluation


In [None]:
# %%
def calculate_perplexity(model, tokenizer, texts, batch_size=8, max_length=2048):
    """
    Calculate corpus-level perplexity of ``model`` on a list of texts.

    Perplexity is ``exp(total NLL / number of scored tokens)``. Only real
    (non-padding) tokens that have at least one real token before them are
    scored, so the result does not depend on batch size or padding side.

    Args:
        model: causal LM that returns an object with ``.loss`` when called with
            ``labels``. Assumed to follow the Hugging Face convention: labels
            are shifted by one internally and the loss is the mean over label
            positions != -100 -- TODO confirm for custom model code.
        tokenizer: tokenizer matching ``model``. If it has no pad token, its
            EOS token is assigned as pad token (this mutates the tokenizer).
        texts: list of strings to score.
        batch_size: number of texts per forward pass.
        max_length: texts are truncated to this many tokens.

    Returns:
        Perplexity as a Python float (``inf`` if the mean loss overflows exp).

    Raises:
        ValueError: if ``texts`` is empty or contains no scorable tokens.
    """
    if len(texts) == 0:
        raise ValueError("calculate_perplexity needs at least one text")

    # padding=True below requires a pad token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model.eval()
    total_loss = 0.0
    total_tokens = 0

    with torch.no_grad():
        for i in tqdm(range(0, len(texts), batch_size)):
            batch = texts[i : i + batch_size]

            encodings = tokenizer(
                batch,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=max_length,
            ).to(model.device)
            attention_mask = encodings["attention_mask"]

            # Padding must not be scored: -100 is the ignore index of the loss.
            labels = encodings["input_ids"].clone()
            labels[attention_mask == 0] = -100
            # The first real token of each row has no real context to predict it
            # from (with left padding it would be "predicted" from a pad slot),
            # so never score it. With right padding this is position 0, which the
            # label shift already excludes, so this is a no-op there.
            first_real = attention_mask.long().argmax(dim=1)
            row_idx = torch.arange(labels.size(0), device=labels.device)
            labels[row_idx, first_real] = -100

            # Count exactly the (shifted) targets the mean loss is taken over.
            n_targets = int((labels[:, 1:] != -100).sum().item())
            if n_targets == 0:
                continue  # e.g. every text in this batch is a single token

            outputs = model(**encodings, labels=labels)
            total_loss += outputs.loss.item() * n_targets
            total_tokens += n_targets

    if total_tokens == 0:
        raise ValueError("no text contained more than one token; perplexity is undefined")

    avg_loss = total_loss / total_tokens
    perplexity = torch.exp(torch.tensor(avg_loss, dtype=torch.float64)).item()

    return perplexity


# Load the held-out split and score a capped sample on both models; scoring the
# full split means full forward passes of two large models.
N_PPL_SAMPLES = 100  # set to None to score the whole validation split

dataset = load_from_disk(TEST_DATASET)
validation_split = dataset["validation"]
if N_PPL_SAMPLES is not None:
    # `select` takes the first rows without materialising the whole text column first.
    validation_split = validation_split.select(range(min(N_PPL_SAMPLES, len(validation_split))))
# A plain list slices the same way regardless of the `datasets` column type.
test_texts = list(validation_split["text"])

print("Calculating perplexity...\n")

# Base model perplexity
base_ppl = calculate_perplexity(base_model, base_tokenizer, test_texts)
print(f"Base Model Perplexity: {base_ppl:.2f}")

# Fine-tuned model perplexity
ft_ppl = calculate_perplexity(ft_model, ft_tokenizer, test_texts)
print(f"Fine-tuned Model Perplexity: {ft_ppl:.2f}")

# Relative reduction in perplexity (positive = fine-tuned model is better).
improvement = ((base_ppl - ft_ppl) / base_ppl) * 100
print(f"\nImprovement: {improvement:.2f}%")

# %% [markdown]
# ## Generation Quality Metrics

In [None]:
# %%
from rouge import Rouge
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
import nltk

# Tokenizer data for `nltk.word_tokenize`. Recent NLTK releases look up
# "punkt_tab" instead of the older pickled "punkt"; fetch both so either works.
for _nltk_resource in ("punkt", "punkt_tab"):
    nltk.download(_nltk_resource, quiet=True)


def evaluate_generation_quality(
    model,
    tokenizer,
    test_pairs,
    num_samples=50,
    max_new_tokens=200,
    temperature=0.7,
):
    """
    Evaluate generation quality using ROUGE-L (F1) and sentence-level BLEU.

    For each of the first ``num_samples`` (prompt, reference) pairs, sample a
    continuation from ``model`` and score only the newly generated text
    against ``reference``.

    Args:
        model: causal LM exposing ``eval``, ``generate`` and ``device``.
        tokenizer: tokenizer matching ``model``.
        test_pairs: sequence of ``(prompt, reference)`` string pairs.
        num_samples: maximum number of pairs to score.
        max_new_tokens: generation budget per prompt.
        temperature: sampling temperature (sampling is always enabled).

    Returns:
        dict ``{"rouge_l": float, "bleu": float}`` holding the mean scores;
        a value is ``nan`` when no sample could be scored for that metric.
    """
    import warnings

    rouge = Rouge()
    smoothing = SmoothingFunction().method1
    # Give `generate` an explicit pad id even for tokenizers without a pad token.
    pad_token_id = (
        tokenizer.pad_token_id
        if tokenizer.pad_token_id is not None
        else tokenizer.eos_token_id
    )

    rouge_scores = []
    bleu_scores = []

    model.eval()

    for prompt, reference in tqdm(test_pairs[:num_samples]):
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                do_sample=True,
                pad_token_id=pad_token_id,
            )

        # Drop the prompt by token count: slicing the decoded string by
        # len(prompt) characters breaks when decoding doesn't reproduce the
        # prompt text exactly (added BOS, whitespace normalisation, ...).
        prompt_len = inputs["input_ids"].shape[1]
        generated = tokenizer.decode(
            outputs[0][prompt_len:], skip_special_tokens=True
        ).strip()

        # ROUGE
        if not generated:
            # An empty continuation shares nothing with the reference. Score it 0
            # instead of dropping it, which would inflate the mean (BLEU below
            # already counts it).
            rouge_scores.append(0.0)
        else:
            try:
                rouge_score = rouge.get_scores(generated, reference)[0]
                rouge_scores.append(rouge_score["rouge-l"]["f"])
            except Exception as exc:  # best effort, but never silently
                warnings.warn(f"ROUGE scoring skipped for prompt {prompt!r}: {exc}")

        # BLEU
        reference_tokens = nltk.word_tokenize(reference)
        generated_tokens = nltk.word_tokenize(generated)
        bleu = sentence_bleu(
            [reference_tokens], generated_tokens, smoothing_function=smoothing
        )
        bleu_scores.append(bleu)

    def _mean(values):
        # Explicit nan instead of np.mean([])'s RuntimeWarning.
        return float(np.mean(values)) if values else float("nan")

    return {"rouge_l": _mean(rouge_scores), "bleu": _mean(bleu_scores)}


# (prompt, reference) pairs scored with ROUGE-L / BLEU.
# TODO: load these from the held-out dataset; the pair below is a demo placeholder.
test_pairs = [
    ("Tell me about Elio.", "Elio is a curious Earth kid who joined the Communiverse."),
    # Add more...
]

# Generation is sampled, so re-seed before each model. Both runs then start from
# the same RNG state, and re-running the notebook gives more repeatable scores.
GEN_SEED = 0

if test_pairs:
    print("\nEvaluating generation quality...\n")

    torch.manual_seed(GEN_SEED)
    base_metrics = evaluate_generation_quality(base_model, base_tokenizer, test_pairs)
    torch.manual_seed(GEN_SEED)
    ft_metrics = evaluate_generation_quality(ft_model, ft_tokenizer, test_pairs)

    for header, scores in (("Base Model:", base_metrics), ("\nFine-tuned Model:", ft_metrics)):
        print(header)
        print(f"  ROUGE-L: {scores['rouge_l']:.4f}")
        print(f"  BLEU: {scores['bleu']:.4f}")

# %% [markdown]
# ## Side-by-Side Comparison

In [None]:
# %%
def _generate_continuation(model, tokenizer, prompt, max_new_tokens=150, temperature=0.7):
    """Sample a continuation of ``prompt`` and return only the new text, stripped."""
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # Explicit pad id (EOS fallback) so `generate` works for tokenizers without one.
    pad_token_id = (
        tokenizer.pad_token_id
        if tokenizer.pad_token_id is not None
        else tokenizer.eos_token_id
    )
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            do_sample=True,
            pad_token_id=pad_token_id,
        )
    # Drop the prompt by token count, not by len(prompt) characters: the decoded
    # text need not reproduce the prompt exactly.
    prompt_len = inputs["input_ids"].shape[1]
    return tokenizer.decode(output_ids[0][prompt_len:], skip_special_tokens=True).strip()


def compare_responses(prompt, base_model, base_tokenizer, ft_model, ft_tokenizer):
    """
    Generate and print responses to ``prompt`` from the base and fine-tuned models.

    Each model samples up to 150 new tokens at temperature 0.7; only the
    generated continuation (not the echoed prompt) is printed.

    Args:
        prompt: text prompt given to both models.
        base_model, base_tokenizer: the base model and its tokenizer.
        ft_model, ft_tokenizer: the fine-tuned model and its tokenizer.
    """
    print(f"Prompt: {prompt}\n")
    print("=" * 60)

    # Base model
    base_response = _generate_continuation(base_model, base_tokenizer, prompt)
    print("\nBase Model Response:")
    print(base_response)
    print("\n" + "-" * 60)

    # Fine-tuned model
    ft_response = _generate_continuation(ft_model, ft_tokenizer, prompt)
    print("\nFine-tuned Model Response:")
    print(ft_response)
    print("\n" + "=" * 60)


# Qualitative prompts shown to both models side by side.
comparison_prompts = [
    "What is the Communiverse?",
    "Tell me about Elio's journey.",
    "Who is Glordon and what role does he play?",
]

print("\nSide-by-Side Comparison:\n")
for prompt in comparison_prompts:
    compare_responses(
        prompt,
        base_model=base_model,
        base_tokenizer=base_tokenizer,
        ft_model=ft_model,
        ft_tokenizer=ft_tokenizer,
    )
    print("\n")  # blank separator between prompts

# %% [markdown]
# ## Save Evaluation Results

In [None]:
# %%
import json

# Gather every metric computed above into one JSON record. Generation metrics
# only exist when the ROUGE/BLEU cell actually evaluated something.
perplexity_summary = {
    "base": base_ppl,
    "finetuned": ft_ppl,
    "improvement_percent": improvement,
}
generation_summary = {
    "base": base_metrics if "base_metrics" in locals() else None,
    "finetuned": ft_metrics if "ft_metrics" in locals() else None,
}
evaluation_results = {
    "model_info": {"base_model": BASE_MODEL, "finetuned_model": FINETUNED_MODEL},
    "perplexity": perplexity_summary,
    "generation_quality": generation_summary,
}

# Save to JSON
results_path = f"{RESULTS_DIR}/evaluation_results.json"
with open(results_path, "w") as results_file:
    json.dump(evaluation_results, results_file, indent=2)

print(f"\nEvaluation results saved to: {results_path}")

# %% [markdown]
# ## Visualization

In [None]:
# %%
# Comparison chart: perplexity (left) and ROUGE-L / BLEU (right, when computed).
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
MODEL_COLORS = ["#ff7f0e", "#2ca02c"]  # base, fine-tuned

# Perplexity comparison
models = ["Base Model", "Fine-tuned"]
ppls = [base_ppl, ft_ppl]
axes[0].bar(models, ppls, color=MODEL_COLORS)
axes[0].set_ylabel("Perplexity (lower is better)")
axes[0].set_title("Perplexity Comparison")
axes[0].set_ylim(0, max(ppls) * 1.2)

# Add values on bars
for i, v in enumerate(ppls):
    axes[0].text(i, v + max(ppls) * 0.02, f"{v:.2f}", ha="center")

# Generation quality comparison (only if the ROUGE/BLEU cell produced both results)
if "base_metrics" in locals() and "ft_metrics" in locals():
    metrics = ["ROUGE-L", "BLEU"]
    base_vals = [base_metrics["rouge_l"], base_metrics["bleu"]]
    ft_vals = [ft_metrics["rouge_l"], ft_metrics["bleu"]]

    x = np.arange(len(metrics))
    width = 0.35

    axes[1].bar(x - width / 2, base_vals, width, label="Base Model", color=MODEL_COLORS[0])
    axes[1].bar(x + width / 2, ft_vals, width, label="Fine-tuned", color=MODEL_COLORS[1])

    axes[1].set_ylabel("Score (higher is better)")
    axes[1].set_title("Generation Quality Metrics")
    axes[1].set_xticks(x)
    axes[1].set_xticklabels(metrics)
    axes[1].legend()
    axes[1].set_ylim(0, 1)
else:
    # Without generation metrics the right panel would be blank axes; hide it
    # and say why, so the saved figure is self-explanatory.
    axes[1].axis("off")
    axes[1].text(
        0.5,
        0.5,
        "Generation metrics not computed",
        ha="center",
        va="center",
        transform=axes[1].transAxes,
    )

plt.tight_layout()
plt.savefig(f"{RESULTS_DIR}/evaluation_comparison.png", dpi=300, bbox_inches="tight")
plt.show()

print(f"\nVisualization saved to: {RESULTS_DIR}/evaluation_comparison.png")