# In [None]:
import gc
import os
import re
import traceback
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
import torch
from sklearn.metrics import (
    accuracy_score,
    f1_score,
    hamming_loss,
    jaccard_score,
    precision_score,
    recall_score,
)
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# --- GLOBAL CONSTANTS ---

# Where predictions and the combined metric summary are written.
OUTPUT_DIR = "./results_multilingual_test_mistral7b_prompt_only"
PRED_DIR = os.path.join(OUTPUT_DIR, "prompt_predictions")
METRICS_DIR = os.path.join(OUTPUT_DIR, "metrics")
ALL_METRICS_FILE = os.path.join(METRICS_DIR, "combined_metrics_summary.csv")

# Model under test, data root and the prompting strategies to evaluate.
MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.2"
BASE_PATH = "./XED/processed"
STRATEGIES = ["zero", "few", "instruction"]

# Language code -> test CSV. "ml" points at the combined multilingual split.
DATA_PATHS = {
    code: os.path.join(BASE_PATH, f"test_{stem}.csv")
    for code, stem in (
        ("en", "en"),
        ("fr", "fr"),
        ("es", "es"),
        ("fi", "fi"),
        ("ml", "multilingual"),
    )
}

# Eagerly load (at import time) only the splits whose CSV exists on disk.
DATASETS = {
    lang: pd.read_csv(path)
    for lang, path in DATA_PATHS.items()
    if os.path.exists(path)
}

# XED label ID -> emotion name; IDs 1-8 in order, followed by 0 (Neutral).
XED_EMOTION_MAPPING = dict(enumerate(
    ["Anger", "Anticipation", "Disgust", "Fear",
     "Joy", "Sadness", "Surprise", "Trust"],
    start=1,
))
XED_EMOTION_MAPPING[0] = "Neutral"

# --- Multilingual Prompt Components ---
#
# Each table maps a language code to a localized prompt fragment. The
# "default" entry (an alias of the English text) is used for language codes
# without a dedicated translation, such as the mixed "ml" split.

BASE_INSTRUCTIONS = {
    "en": (
        "Task: Predict the correct emotion label(s) for this text using XED "
        "label IDs (1–8, and 0 for Neutral).\n"
        "Respond ONLY with comma-separated label numbers (e.g. 1, 3)."
    ),
    "fr": (
        "Tâche: Prédire le(s) label(s) d'émotion correct(s) pour ce texte "
        "en utilisant les ID d'étiquette XED (1-8, et 0 pour Neutre).\n"
        "Répondez UNIQUEMENT avec des numéros d'étiquette séparés par "
        "des virgules (par exemple 1, 3)."
    ),
    "es": (
        "Tarea: Predecir la(s) etiqueta(s) de emoción correcta(s) para este "
        "texto usando los IDs de etiqueta XED (1-8, y 0 para Neutro).\n"
        "Responda SÓLO con números de etiqueta separados por comas "
        "(por ejemplo 1, 3)."
    ),
    "fi": (
        "Tehtävä: Ennusta oikea tunne-etiketti(t) tälle tekstille "
        "käyttämällä XED-etiketti-ID:itä (1-8 ja 0 Neutraalille).\n"
        "Vastaa AINOASTAAN pilkulla erotetuilla etikettinumeroilla "
        "(esim. 1, 3)."
    ),
}
BASE_INSTRUCTIONS["default"] = BASE_INSTRUCTIONS["en"]

# Three fixed demonstrations (Joy=5, Anger=1, Fear=4) for the "few" strategy.
FEW_SHOT_EXAMPLES = {
    "en": (
        "Examples:\n1. Text: I am so happy today! → 5\n"
        "2. Text: This makes me furious! → 1\n"
        "3. Text: I'm scared to go outside. → 4\n"
    ),
    "fr": (
        "Exemples:\n1. Texte: Je suis si heureux aujourd'hui ! → 5\n"
        "2. Texte: Cela me rend furieux ! → 1\n"
        "3. Texte: J'ai peur de sortir. → 4\n"
    ),
    "es": (
        "Ejemplos:\n1. Texto: ¡Estoy tan feliz hoy! → 5\n"
        "2. Texto: ¡Esto me pone furioso! → 1\n"
        "3. Texto: Tengo miedo de salir. → 4\n"
    ),
    "fi": (
        "Esimerkkejä:\n1. Teksti: Olen niin onnellinen tänään! → 5\n"
        "2. Teksti: Tämä saa minut raivostumaan! → 1\n"
        "3. Teksti: Pelkään mennä ulos. → 4\n"
    ),
}
FEW_SHOT_EXAMPLES["default"] = FEW_SHOT_EXAMPLES["en"]

# Lead-in sentence placed before the ID -> emotion-name table ("instruction").
INSTRUCTION_INTRO = {
    "en": (
        "You are an expert emotion classifier. Use the following XED label ID "
        "to Emotion Name mapping:"
    ),
    "fr": (
        "Vous êtes un classificateur d'émotions expert. Utilisez la "
        "correspondance suivante entre l'ID d'étiquette XED et le nom "
        "de l'émotion :"
    ),
    "es": (
        "Usted es un clasificador de emociones experto. Utilice el "
        "siguiente mapeo de ID de etiqueta XED a Nombre de Emoción:"
    ),
    "fi": (
        "Olet asiantunteva tunneluokittelija. Käytä seuraavaa "
        "XED-etiketti-ID:n ja tunnenimen vastaavuutta:"
    ),
}
INSTRUCTION_INTRO["default"] = INSTRUCTION_INTRO["en"]

# Reminder that the answer must be the numeric XED ID, not the emotion name.
INSTRUCTION_NOTE = {
    "en": (
        "NOTE: The required output is the XED ID (0-8), NOT the emotion name. "
        "For example, 'Anger' should be '1'. Do not output ID 8 for 'Trust' "
        "as 0, as this BERT-specific re-arrangement is NOT required here."
    ),
    "fr": (
        "NOTE: La sortie requise est l'ID XED (0-8), PAS le nom de "
        "l'émotion. Par exemple, 'Colère' doit être '1'. Ne pas sortir "
        "l'ID 8 pour 'Confiance' comme 0, car ce réarrangement spécifique "
        "à BERT N'EST PAS requis ici."
    ),
    "es": (
        "NOTA: La salida requerida es el ID XED (0-8), NO el nombre de "
        "la emoción. Por ejemplo, 'Ira' debe ser '1'. No emita el ID 8 para "
        "'Confianza' como 0, ya que este rearreglo específico de BERT NO es "
        "requerido aquí."
    ),
    "fi": (
        "HUOMAUTUS: Vaadittu tuloste on XED ID (0-8), EI tunnenimi. "
        "Esimerkiksi 'Viha' tulee olla '1'. Älä tulosta ID 8:aa 'Luottamus' "
        "kohdalla 0:na, koska tätä BERT-spesifistä uudelleenjärjestelyä EI "
        "vaadita tässä."
    ),
}
INSTRUCTION_NOTE["default"] = INSTRUCTION_NOTE["en"]

# str.format templates; "{lang}" is filled with the raw language code.
SENTENCE_WRAPPER = {
    "en": "Given this {lang} sentence:",
    "fr": "Étant donné cette phrase {lang}:",
    "es": "Dada esta oración {lang}:",
    "fi": "Tämä {lang} lause huomioiden:",
}
SENTENCE_WRAPPER["default"] = SENTENCE_WRAPPER["en"]


def parse_label_column(label_series: pd.Series) -> List[List[int]]:
    """
    Parse a Series of comma-separated label cells into lists of integers.

    A cell such as ``"1, 3"`` becomes ``[1, 3]``; missing cells (NaN/None)
    become ``[]``. A token is accepted when it is a non-negative ASCII
    integer, optionally written with a zero fractional part (``"5.0"``).
    The latter matters because pandas reads an integer column that contains
    missing values as float64, and ``str(5.0) == "5.0"`` would otherwise be
    rejected, silently dropping gold labels. Any other token (empty,
    non-numeric, negative, or a non-ASCII digit such as ``"²"``, which
    ``str.isdigit`` accepts but ``int`` cannot parse) is skipped.

    Args:
        label_series: Column of raw label cells (strings, numbers or NaN).

    Returns:
        One list of integer labels per cell, in the Series' order.
    """
    # ASCII digits only, with an optional ".0" / ".00..." suffix.
    token_pattern = re.compile(r"([0-9]+)(?:\.0+)?")

    parsed_labels: List[List[int]] = []
    for val in label_series:
        if pd.isna(val):
            parsed_labels.append([])
            continue
        labels = []
        for token in str(val).split(","):
            match = token_pattern.fullmatch(token.strip())
            if match:
                labels.append(int(match.group(1)))
        parsed_labels.append(labels)
    return parsed_labels


def build_prompt(text: str, lang: str, strategy: str) -> str:
    """
    Assemble a Mistral ``[INST] ... [/INST]`` prompt for a single text.

    The localized task instruction always comes first. What follows depends
    on ``strategy``:
      * ``"zero"``        - just the labelled input text;
      * ``"few"``         - localized demonstrations, then the input text;
      * ``"instruction"`` - intro, the full XED ID -> name table (sorted by
                            ID), the output note, and the quoted input text;
      * anything else     - the raw text on its own.
    Language codes without a dedicated translation use the "default" entries.
    """
    key = "default" if lang not in BASE_INSTRUCTIONS else lang
    header = BASE_INSTRUCTIONS[key]

    if strategy == "instruction":
        id_table = "\n".join(
            f"ID {label_id}: {XED_EMOTION_MAPPING[label_id]}"
            for label_id in sorted(XED_EMOTION_MAPPING)
        )
        wrapper = SENTENCE_WRAPPER[key].format(lang=lang)
        sections = [
            INSTRUCTION_INTRO[key] + "\n" + id_table,
            INSTRUCTION_NOTE[key],
            wrapper + "\n" + '"' + text + '"',
        ]
        body = "\n\n".join(sections)
    elif strategy == "few":
        body = FEW_SHOT_EXAMPLES[key] + "\n" + f"Text ({lang}): {text}"
    elif strategy == "zero":
        body = f"Text ({lang}): {text}"
    else:
        body = text

    return "[INST] " + header + "\n\n" + body + " [/INST]"


def parse_model_output(output: str) -> List[int]:
    """
    Pull XED label IDs out of the model's raw generated text.

    Every run of decimal digits is read as one integer; values outside the
    0-8 range are discarded. Order of appearance and duplicates are kept.
    """
    found: List[int] = []
    for match in re.finditer(r"\d+", output):
        value = int(match.group())
        if 0 <= value <= 8:
            found.append(value)
    return found


def compute_multilabel_metrics(
    preds: List[List[int]], golds: List[List[int]], lang: str, strategy: str
) -> Dict[str, Any]:
    """
    Compute multi-label classification metrics for one (language, strategy) run.

    Predictions and gold labels are encoded into a binary indicator matrix
    with one column per XED label ID 0-8 (always all nine columns). Using a
    fixed column space keeps ``hamming`` - which averages over every column -
    comparable across languages and strategies instead of depending on which
    labels happened to occur in this run. Macro-averaged precision/recall/F1
    are averaged only over labels that occur in ``preds`` or ``golds``
    (labels absent from both would otherwise each contribute 0 under
    ``zero_division=0``). Micro-F1, subset accuracy and sample-averaged
    Jaccard are unaffected by all-zero columns.

    Args:
        preds: Predicted label IDs, one list per example.
        golds: Gold label IDs, one list per example (same length as preds).
        lang: Language code stored in the result.
        strategy: Prompting strategy stored in the result.

    Returns:
        Dict with keys language, strategy, accuracy (exact label-set match),
        precision_macro, recall_macro, f1_macro, f1_micro, jaccard
        (averaged over samples) and hamming. If no label in 0-8 occurs at all,
        a sentinel dict with zero scores and ``hamming`` 1.0 is returned.

    Raises:
        ValueError: from scikit-learn if preds and golds differ in length.
    """
    label_space = list(range(9))  # XED IDs 0..8, one matrix column each
    label_to_idx = {lbl: i for i, lbl in enumerate(label_space)}

    observed = sorted(
        {lbl for sublist in (preds + golds) for lbl in sublist
         if lbl in label_to_idx}
    )

    if not observed:
        # Sentinel for "no usable labels anywhere" (e.g. empty input).
        return {
            "language": lang, "strategy": strategy, "accuracy": 0.0,
            "precision_macro": 0.0, "recall_macro": 0.0, "f1_macro": 0.0,
            "f1_micro": 0.0, "jaccard": 0.0, "hamming": 1.0
        }

    def to_indicator_matrix(label_lists: List[List[int]]) -> np.ndarray:
        """Convert lists of label IDs to a (n_examples, 9) 0/1 matrix."""
        mat = np.zeros((len(label_lists), len(label_space)), dtype=int)
        for row, labs in enumerate(label_lists):
            if isinstance(labs, (list, tuple)):
                for lbl in labs:
                    if lbl in label_to_idx:
                        mat[row, label_to_idx[lbl]] = 1
        return mat

    y_true = to_indicator_matrix(golds)
    y_pred = to_indicator_matrix(preds)

    # For multilabel targets scikit-learn interprets ``labels`` as column
    # indices; restrict macro averaging to the labels actually observed.
    macro_cols = [label_to_idx[lbl] for lbl in observed]

    return {
        "language": lang,
        "strategy": strategy,
        "accuracy": accuracy_score(y_true, y_pred),
        "precision_macro": precision_score(
            y_true, y_pred, labels=macro_cols, average="macro",
            zero_division=0,
        ),
        "recall_macro": recall_score(
            y_true, y_pred, labels=macro_cols, average="macro",
            zero_division=0,
        ),
        "f1_macro": f1_score(
            y_true, y_pred, labels=macro_cols, average="macro",
            zero_division=0,
        ),
        "f1_micro": f1_score(
            y_true, y_pred, average="micro", zero_division=0
        ),
        "jaccard": jaccard_score(
            y_true, y_pred, average="samples", zero_division=0
        ),
        "hamming": hamming_loss(y_true, y_pred),
    }


def load_metrics_summary() -> pd.DataFrame:
    """
    Return the previously saved metrics table, or an empty one to start from.

    If the summary CSV is missing, or exists but has no content (a notice is
    printed in that case), an empty DataFrame with the nine summary columns
    is returned instead.
    """
    summary_columns = [
        "language", "strategy", "accuracy", "precision_macro",
        "recall_macro", "f1_macro", "f1_micro", "jaccard", "hamming",
    ]

    if not os.path.exists(ALL_METRICS_FILE):
        return pd.DataFrame(columns=summary_columns)

    try:
        return pd.read_csv(ALL_METRICS_FILE)
    except pd.errors.EmptyDataError:
        print(f"Metrics file {ALL_METRICS_FILE} is empty. Starting fresh.")
        return pd.DataFrame(columns=summary_columns)


def save_metrics_summary(df: pd.DataFrame):
    """
    Persist the metrics summary to ``ALL_METRICS_FILE`` (creating its folder).

    The table is first written to a sibling ``.tmp`` file and then moved over
    the target with ``os.replace``. This file is the resume checkpoint read
    by ``load_metrics_summary``; writing it in place meant an interruption
    mid-write could leave a truncated summary. ``os.replace`` is atomic on
    POSIX when source and target are on the same filesystem.
    """
    os.makedirs(METRICS_DIR, exist_ok=True)
    tmp_path = ALL_METRICS_FILE + ".tmp"
    df.to_csv(tmp_path, index=False)
    os.replace(tmp_path, ALL_METRICS_FILE)
    print(f"\nUpdated metric summary saved to {ALL_METRICS_FILE}")


def _generate_predictions(
    pipe,
    texts: List[str],
    lang: str,
    strategy: str,
    batch_size: int,
    max_new_tokens: int,
) -> List[List[int]]:
    """Run the pipeline over ``texts`` in batches; return parsed label IDs per text."""
    preds: List[List[int]] = []
    for start in tqdm(
        range(0, len(texts), batch_size),
        desc=f"{lang}-{strategy}"
    ):
        prompts = [
            build_prompt(t, lang, strategy)
            for t in texts[start:start + batch_size]
        ]

        outputs = pipe(
            prompts,
            max_new_tokens=max_new_tokens,
            # Explicit greedy decoding for reproducible scores. `temperature`
            # only applies when `do_sample=True`; passing it alone left the
            # decoding mode to the model's generation config and made
            # transformers emit a warning.
            do_sample=False,
            batch_size=batch_size,
            pad_token_id=pipe.tokenizer.eos_token_id,
            return_full_text=False,
        )

        for result in outputs:
            # One entry per prompt: a list of candidate dicts; use the first.
            if not result or not result[0] or "generated_text" not in result[0]:
                generated = ""
            else:
                generated = result[0]["generated_text"]
            preds.append(parse_model_output(generated))
    return preds


def run_all_prompt_tests(
    pipe,
    all_datasets: Dict[str, pd.DataFrame],
    strategies: List[str],
    batch_size: int = 4,
    max_new_tokens: int = 16,
) -> pd.DataFrame:
    """
    Evaluate every prompt strategy on every dataset with a text-generation pipeline.

    (language, strategy) pairs already present in the saved metrics summary
    are skipped, so an interrupted run can be resumed. For each remaining
    pair, predictions are written to ``PRED_DIR/preds_{lang}_{strategy}.csv``
    and the metrics row is appended to the summary, which is saved after
    every pair.

    Args:
        pipe: transformers text-generation pipeline (must expose ``tokenizer``).
        all_datasets: Language code -> DataFrame with "text" and "labels".
        strategies: Prompt strategies passed to ``build_prompt``.
        batch_size: Prompts per pipeline call (default 4).
        max_new_tokens: Generation budget per prompt (default 16).

    Returns:
        The updated metrics summary DataFrame.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    print("\n=== STARTING PROMPT ENGINEERING EVALUATION (MISTRAL-7B) ===")

    metrics_df = load_metrics_summary()

    for strategy in strategies:
        print(f"\n--- Strategy: {strategy.upper()} ---")

        for lang, df in all_datasets.items():
            already_done = (metrics_df["language"] == lang) & (
                metrics_df["strategy"] == strategy)

            if already_done.any():
                print(f"Skipping {lang}-{strategy}: Metrics already exist.")
                continue

            print(f"Processing {lang}-{strategy}...")

            try:
                golds = parse_label_column(df["labels"])
                # Missing texts would otherwise reach the prompt as "nan".
                texts = df["text"].fillna("").astype(str).tolist()

                preds = _generate_predictions(
                    pipe, texts, lang, strategy, batch_size, max_new_tokens
                )

                if len(preds) != len(golds):
                    print(
                        f"[WARN] Pred length mismatch for {lang}-{strategy}. "
                        f"Expected {len(golds)}, got {len(preds)}. "
                        "Truncating extra predictions / filling missing "
                        "ones with empty lists."
                    )
                    # Fresh lists per row (`[[]] * k` would alias one list).
                    preds = preds[:len(golds)] + [
                        [] for _ in range(len(golds) - len(preds))
                    ]

                metrics = compute_multilabel_metrics(
                    preds, golds, lang, strategy
                )
                print(f"Metrics: {metrics}")

                os.makedirs(PRED_DIR, exist_ok=True)
                pred_file = os.path.join(
                    PRED_DIR, f"preds_{lang}_{strategy}.csv"
                )
                pd.DataFrame({
                    "pred": [",".join(map(str, p)) for p in preds],
                    "gold": [",".join(map(str, g)) for g in golds],
                    "text": texts,
                }).to_csv(pred_file, index=False)
                print(f"Predictions saved: {pred_file}")

                new_row_df = pd.DataFrame([metrics])
                # Concatenating onto an empty frame triggers a pandas
                # FutureWarning (pandas >= 2.1); start from the new row instead.
                if metrics_df.empty:
                    metrics_df = new_row_df
                else:
                    metrics_df = pd.concat(
                        [metrics_df, new_row_df], ignore_index=True
                    )
                save_metrics_summary(metrics_df)

            except Exception as e:
                # Best effort per pair: report and continue with the next one.
                print(
                    f"\n[CRITICAL ERROR] Failed to process "
                    f"{lang}-{strategy}: {e}"
                )
                traceback.print_exc()

    return metrics_df


def main():
    """
    Load the model and tokenizer, build the pipeline and run all prompt tests.

    The model, tokenizer and pipeline are released, and the CUDA cache is
    emptied, even if the evaluation raises.
    """
    os.makedirs(PRED_DIR, exist_ok=True)
    os.makedirs(METRICS_DIR, exist_ok=True)

    print(f"Loading model: {MODEL_NAME}")
    # Left padding so batched generation continues directly after each prompt.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, padding_side='left')

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        device_map="auto",
        torch_dtype=torch.float16,
    )

    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        device_map="auto"
    )

    try:
        run_all_prompt_tests(pipe, DATASETS, STRATEGIES)
    finally:
        # Drop every reference to the weights before clearing the cache.
        del model, tokenizer, pipe
        gc.collect()
        torch.cuda.empty_cache()
    print("\nExperiment stages complete and resources cleaned up.")


if __name__ == "__main__":
    main()