# In [None]:
import logging
from src.setup.hf_auth import hf_login
from src.utils.prompts import CONCISE_REWRITE_PROMPT_TEMPLATE, REFLECTION_PROMPT_TEMPLATE
from src.utils.helpers import mask_explanation, mask_explanation_fa
from src.evaluation.metrics import (
    load_model, 
    calc_perplexity_single, 
    compute_conciseness_metrics,
)
from src.evaluation import scorer_logprobs, final_prediction

def rewrite_explanation(explanation, rewrite_model):
    """Ask ``rewrite_model`` for a concise rewrite of ``explanation``.

    The explanation is substituted into ``CONCISE_REWRITE_PROMPT_TEMPLATE``
    (under the ``explanation`` placeholder) and the filled prompt is passed
    to ``rewrite_model.generate_text``. Whatever that call returns is handed
    back to the caller unchanged.

    ``rewrite_model`` may be any object (LLM API client or local model) that
    exposes a ``generate_text(prompt)`` method.
    """
    filled_prompt = CONCISE_REWRITE_PROMPT_TEMPLATE.format(explanation=explanation)
    generate = rewrite_model.generate_text
    return generate(filled_prompt)

def reflect_on_explanation(explanation, reflection_model):
    """Optionally query a model for a reflection on ``explanation``.

    ``REFLECTION_PROMPT_TEMPLATE`` is filled with the explanation (under the
    ``explanation`` placeholder) and sent to
    ``reflection_model.generate_text``; the model's response is returned
    as-is. Per the pipeline's intent, the reflection is about the
    explanation's sufficiency and logic.
    """
    filled_prompt = REFLECTION_PROMPT_TEMPLATE.format(explanation=explanation)
    generate = reflection_model.generate_text
    return generate(filled_prompt)

def process_mcqa_sample(sample, rewrite_model, scorer_model, scorer_tokenizer, logger, mask_fa=False, reflect=False):
    """
    Run the rewrite-and-evaluate pipeline on one MCQA sample.

    Args:
        sample: mapping with the keys 'question', 'choices', 'answer' and
            'verbose_explanation'; it must support ``.copy()``. Masking is
            done on that copy, so this function itself never assigns into
            ``sample``.
        rewrite_model: object exposing ``generate_text(prompt)``. It produces
            the concise rewrite and, when ``reflect`` is set, the reflection.
        scorer_model, scorer_tokenizer: forwarded to the sufficiency scorer,
            the final-prediction helper and the perplexity helper.
        logger: logging.Logger receiving an info line for each step.
        mask_fa: when true, mask with ``mask_explanation_fa`` instead of
            ``mask_explanation``.
        reflect: when true, additionally request a reflection on the concise
            explanation.

    Returns:
        dict with the keys 'concise_explanation', 'verbose_masked',
        'concise_masked', 'suff_delta_verbose', 'suff_delta_concise',
        'verbose_tokens', 'concise_tokens', 'token_reduction',
        'percent_reduction', 'final_pred', 'is_correct', 'reflection'
        (None unless ``reflect`` is set), 'verbose_ppl' and 'concise_ppl'.
    """
    original_text = sample["verbose_explanation"]

    # Step 1: have the rewrite model compress the verbose explanation.
    concise_text = rewrite_explanation(original_text, rewrite_model)
    logger.info(f"Concise explanation: {concise_text}")

    # Step 2: mask both explanations (to avoid label leakage). The masking
    # helpers receive a row whose "explanation" entry holds the text, so one
    # copy of the sample is reused and that entry is overwritten per call.
    masker = mask_explanation_fa if mask_fa else mask_explanation
    working_row = sample.copy()

    def _masked(text):
        working_row["explanation"] = text
        return masker(working_row)

    verbose_masked = _masked(original_text)
    concise_masked = _masked(concise_text)
    logger.info(f"Masked verbose: {verbose_masked}")
    logger.info(f"Masked concise: {concise_masked}")

    # Step 3: sufficiency (log-prob gap) for each masked explanation,
    # verbose first and concise second.
    question, choices = sample["question"], sample["choices"]
    suff_delta_verbose, suff_delta_concise = [
        scorer_logprobs(scorer_model, scorer_tokenizer, question, masked_text, choices)
        for masked_text in (verbose_masked, concise_masked)
    ]
    logger.info(f"Sufficiency delta (verbose): {suff_delta_verbose}")
    logger.info(f"Sufficiency delta (concise): {suff_delta_concise}")

    # Step 4: conciseness figures for the unmasked texts.
    n_verbose, n_concise, n_saved, pct_saved = compute_conciseness_metrics(
        original_text, concise_text
    )
    logger.info(f"Tokens: verbose={n_verbose}, concise={n_concise}, reduction={n_saved}, percent={pct_saved:.2f}")

    # Step 5: final prediction from the concise explanation.
    # NOTE(review): this passes the *unmasked* concise text, unlike Step 3 —
    # confirm that is intended.
    final_pred = final_prediction(
        scorer_model, scorer_tokenizer, question, concise_text, choices
    )
    is_correct = final_pred == sample["answer"]
    logger.info(f"Final prediction with concise: {final_pred} (Correct? {is_correct})")

    # Step 6 (optional): reflection, requested from the rewrite model.
    if reflect:
        reflection = reflect_on_explanation(concise_text, rewrite_model)
        logger.info(f"Reflection: {reflection}")
    else:
        reflection = None

    # Step 7: perplexity of both unmasked explanations (always computed).
    verbose_ppl, concise_ppl = [
        calc_perplexity_single(text, scorer_tokenizer, scorer_model, logger)
        for text in (original_text, concise_text)
    ]

    # Aggregate results; key order is kept stable for downstream tabulation.
    return dict(
        concise_explanation=concise_text,
        verbose_masked=verbose_masked,
        concise_masked=concise_masked,
        suff_delta_verbose=suff_delta_verbose,
        suff_delta_concise=suff_delta_concise,
        verbose_tokens=n_verbose,
        concise_tokens=n_concise,
        token_reduction=n_saved,
        percent_reduction=pct_saved,
        final_pred=final_pred,
        is_correct=is_correct,
        reflection=reflection,
        verbose_ppl=verbose_ppl,
        concise_ppl=concise_ppl,
    )

def process_batch(samples, rewrite_model, scorer_model, scorer_tokenizer, logger, mask_fa=False, reflect=False):
    """Run ``process_mcqa_sample`` over ``samples`` and tabulate the results.

    Each sample is processed in iteration order with the same models, logger
    and ``mask_fa`` / ``reflect`` flags. The per-sample result dicts are
    gathered into a list and returned as a ``pandas.DataFrame`` with one row
    per sample.
    """
    import pandas as pd

    per_sample_rows = [
        process_mcqa_sample(
            item,
            rewrite_model,
            scorer_model,
            scorer_tokenizer,
            logger,
            mask_fa=mask_fa,
            reflect=reflect,
        )
        for item in samples
    ]
    return pd.DataFrame(per_sample_rows)

# Usage (`samples` and `rewrite_model` must be supplied by the caller;
# `rewrite_model` needs a `.generate_text(prompt)` method):
# logger = logging.getLogger("pipeline")
# hf_login()  # Authenticate with Hugging Face
# scorer_tokenizer, scorer_model = load_model("YourScorerModelName", logger)
# batch_results = process_batch(samples, rewrite_model, scorer_model, scorer_tokenizer, logger)