# In [None]:
import gc
import os
import re
from typing import List, Tuple, Dict, Any, Optional

import evaluate
import numpy as np
import pandas as pd
import torch
from datasets import Dataset
from peft import LoraConfig, get_peft_model, PeftModel
from sklearn.metrics import f1_score, jaccard_score, hamming_loss
from tqdm import tqdm
from transformers import (
    AutoConfig,
    AutoTokenizer, AutoModelForSeq2SeqLM,
    TrainingArguments, Trainer, DataCollatorForSeq2Seq
)

# --- Data locations, model id and output directories ---
BASE_PATH = "XED/processed"
MODEL_NAME = "google/flan-t5-base"
LANGUAGES = ["en", "fi", "fr", "es"]
SAVE_ROOT_DIR = "./weights"
RESULTS_FILE = "./results_lora_flan_t5_base_5_lang.csv"
TRAIN_ARGS_ROOT_DIR = "./results_lora"
CV_RESULTS_DIR = "./cv_results"
CV_WEIGHTS_DIR = "./cv_weights"

# --- Training / evaluation settings ---
MAX_SAMPLES_EVAL = 500  # only the first N rows of an eval DataFrame are scored
CV_EPOCHS = 2           # epochs per grid point during the LoRA search
FINAL_EPOCHS = 5        # epochs for each final per-language run
BATCH_SIZE = 8
LEARNING_RATE = 2e-4
RANDOM_STATE = 42

# --- LoRA hyperparameter search ---
CV_LANG = "en"  # language whose data is used for the grid search
LORA_GRID = dict(r=[8, 16, 32], lora_alpha=[16, 32, 64])

# Columns of the metric indicator matrices: the label-ID strings "1".."8".
# NOTE(review): presumably XED's eight emotion IDs — confirm against the dataset.
ALL_LABELS = list(map(str, range(1, 9)))

# --- Model and Tokenizer Initialization ---
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Only the model config is needed at module level; it is passed as `config=` to
# every later from_pretrained call. Load it directly rather than instantiating
# the full model (and its weights) just to read `.config`.
BASE_MODEL_CONFIG = AutoConfig.from_pretrained(MODEL_NAME)

# --- Helper Functions ---

def load_data(lang: str) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Read the ``train_{lang}.csv`` / ``test_{lang}.csv`` pair from BASE_PATH.

    The train file is read first, then the test file.

    Returns:
        ``(train_df, test_df)`` as read by ``pd.read_csv``.
    """
    file_names = (f"train_{lang}.csv", f"test_{lang}.csv")
    train_df, test_df = (
        pd.read_csv(os.path.join(BASE_PATH, name)) for name in file_names
    )
    return train_df, test_df

def preprocess_function(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
    """
    Tokenize a batch of rows for seq2seq fine-tuning (used with ``Dataset.map(batched=True)``).

    Each ``text`` is prefixed with ``"emotion classification: "`` and tokenized
    (max 256 tokens, truncated). When the batch has a ``labels`` column, each label
    is stringified and tokenized (max 64 tokens) as the decoder target, and its
    ``input_ids`` are stored under ``"labels"``.

    Args:
        examples: Column-oriented batch with a ``text`` column and optionally ``labels``.

    Returns:
        The tokenizer output for the inputs, plus ``"labels"`` when targets exist.
    """
    texts = [f"emotion classification: {t}" for t in examples["text"]]
    model_inputs = tokenizer(texts, max_length=256, truncation=True)

    # Unlabeled batches get no "labels" entry. An empty list here would produce a
    # labels column whose length does not match the tokenized inputs.
    if "labels" in examples:
        labels_list = [str(label) for label in examples["labels"]]
        targets = tokenizer(labels_list, max_length=64, truncation=True)
        model_inputs["labels"] = targets["input_ids"]

    return model_inputs

def compute_metrics_numeric(
    preds: List[str],
    golds: List[str],
    labels: Optional[List[str]] = None,
) -> Dict[str, float]:
    """
    Compute multi-label F1/Jaccard/Hamming metrics from label-ID strings.

    Gold and predicted strings (e.g. ``'1,4'``, ``'[1, 4]'`` or free-form model
    output) are both scanned for integer tokens with the same regex. Tokens that
    appear in ``labels`` set a 1 in a binary indicator matrix; other tokens are
    ignored. Using one parser for both sides keeps gold strings in a slightly
    different format from silently producing an all-zero reference row.

    Args:
        preds: Decoded model outputs, one per example.
        golds: Gold label strings, one per example. Both matrices have
            ``len(golds)`` rows; if ``preds`` is shorter, the missing rows are
            treated as empty predictions.
        labels: Label IDs (as strings) defining the matrix columns. Defaults to
            ``ALL_LABELS``.

    Returns:
        ``micro_f1``, ``macro_f1``, ``jaccard`` (sample-averaged) and ``hamming``
        (Hamming loss, lower is better).
    """
    label_ids = ALL_LABELS if labels is None else list(labels)
    column_of = {label: j for j, label in enumerate(label_ids)}
    id_pattern = re.compile(r"\d+")

    y_true = np.zeros((len(golds), len(label_ids)))
    y_pred = np.zeros((len(golds), len(label_ids)))

    for i, (g, p) in enumerate(zip(golds, preds)):
        for token in id_pattern.findall(str(g)):
            j = column_of.get(token)
            if j is not None:
                y_true[i, j] = 1
        # Model output is free text, so extract any integers it contains.
        for token in id_pattern.findall(str(p)):
            j = column_of.get(token)
            if j is not None:
                y_pred[i, j] = 1

    metrics = {
        "micro_f1": f1_score(y_true, y_pred, average="micro", zero_division=0),
        "macro_f1": f1_score(y_true, y_pred, average="macro", zero_division=0),
        "jaccard": jaccard_score(y_true, y_pred, average="samples", zero_division=0),
        "hamming": hamming_loss(y_true, y_pred)
    }
    return metrics

def predict_emotions_llm(
        model: PeftModel, 
        tokenizer: AutoTokenizer, 
        df: pd.DataFrame, 
        lang: str
        ) -> Tuple[List[str], List[str]]:
    """
    Run ``model.generate`` over the first MAX_SAMPLES_EVAL rows of ``df``.

    Each row's ``text`` is prompted as ``"emotion classification: <text>"``,
    tokenized (max 256 tokens), moved to the model's device, and decoded
    (special tokens skipped, lower-cased) after generating up to 20 new tokens.

    Args:
        model: Model to put in eval mode and generate with.
        tokenizer: Tokenizer used for both encoding and decoding.
        df: DataFrame with ``text`` and ``labels`` columns.
        lang: Only used in the progress-bar description.

    Returns:
        ``(predictions, gold_label_strings)``, aligned row by row.
    """
    model.eval()
    target_device = model.device
    rows = df.head(MAX_SAMPLES_EVAL)

    predictions: List[str] = []
    references: List[str] = []

    progress = tqdm(rows.iterrows(), total=len(rows), desc=f"Evaluating {lang}")
    for _, record in progress:
        prompt = f"emotion classification: {record['text']}"
        encoded = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=256)
        encoded = encoded.to(target_device)

        with torch.no_grad():
            generated = model.generate(**encoded, max_new_tokens=20)

        decoded = tokenizer.decode(generated[0], skip_special_tokens=True)
        predictions.append(decoded.lower())
        references.append(str(record["labels"]))

    return predictions, references

def run_peft_training_and_evaluation(
    lang: str,
    peft_model: PeftModel,
    tokenizer: AutoTokenizer,
    training_data: pd.DataFrame,
    test_data: pd.DataFrame,
    epochs: int,
    is_multilingual: bool = False,
    val_fraction: float = 0.1,
) -> Dict[str, Any]:
    """
    Fine-tune a LoRA adapter on one language split and evaluate it on the test split.

    A reproducible random ``val_fraction`` of ``training_data`` is held out and
    used as the Trainer's eval set, which drives per-epoch ``eval_loss`` and
    best-checkpoint selection. ``test_data`` is used only for the final
    generation-based evaluation, so checkpoint selection does not see test rows.

    The adapter and tokenizer are saved under ``SAVE_ROOT_DIR/lora_<tag>``. That
    adapter is then reloaded onto a fresh base model for evaluation.

    Args:
        lang: Language code, used for directory names and log messages.
        peft_model: Model already wrapped with ``get_peft_model``; trained in place.
        tokenizer: Tokenizer used for collation, saving and generation.
        training_data: DataFrame with ``text`` and ``labels`` columns.
        test_data: DataFrame with ``text`` and ``labels`` columns.
        epochs: Number of training epochs.
        is_multilingual: If True, artifacts and metrics are tagged ``"ml"``.
        val_fraction: Fraction of ``training_data`` held out for validation.

    Returns:
        Dict with ``micro_f1``, ``macro_f1``, ``jaccard``, ``hamming`` and ``language``.

    Raises:
        ValueError: If the split leaves the training or the validation part empty.
    """
    lang_tag = lang if not is_multilingual else "ml"
    print(f"\n--- Starting Training: {lang_tag.upper()} ({epochs} epochs) ---")

    # With load_best_model_at_end=True the eval set selects the checkpoint, so it
    # must not be the test split. Carve a reproducible validation slice out of
    # the training data instead.
    full_train = training_data.reset_index(drop=True)
    val_data = full_train.sample(frac=val_fraction, random_state=RANDOM_STATE)
    fit_data = full_train.drop(index=val_data.index)
    if val_data.empty or fit_data.empty:
        raise ValueError(
            f"val_fraction={val_fraction} leaves an empty train or validation split "
            f"for {lang_tag} ({len(full_train)} rows)"
        )

    train_dataset = Dataset.from_pandas(fit_data.reset_index(drop=True))
    val_dataset = Dataset.from_pandas(val_data.reset_index(drop=True))

    num_proc = os.cpu_count() or 1
    tokenized_train = train_dataset.map(
        preprocess_function, batched=True, 
        remove_columns=train_dataset.column_names, 
        num_proc=num_proc
        )
    tokenized_val = val_dataset.map(
        preprocess_function, batched=True, 
        remove_columns=val_dataset.column_names, 
        num_proc=num_proc
        )
    
    output_dir_path = os.path.join(TRAIN_ARGS_ROOT_DIR, f"results_lora_{lang_tag}")
    os.makedirs(output_dir_path, exist_ok=True)
    
    # NOTE(review): newer transformers releases rename `evaluation_strategy` to
    # `eval_strategy`; update this if the installed version rejects the old name.
    training_args = TrainingArguments(
        output_dir=output_dir_path,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        learning_rate=LEARNING_RATE,
        num_train_epochs=epochs,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        logging_dir="./logs",
        report_to="none",
        seed=RANDOM_STATE,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
    )

    data_collator = DataCollatorForSeq2Seq(tokenizer, model=peft_model)

    trainer = Trainer(
        model=peft_model,
        args=training_args,
        train_dataset=tokenized_train,
        eval_dataset=tokenized_val,
        tokenizer=tokenizer,
        data_collator=data_collator
    )

    trainer.train()

    save_dir = os.path.join(SAVE_ROOT_DIR, f"lora_{lang_tag}")
    peft_model.save_pretrained(save_dir)
    tokenizer.save_pretrained(save_dir)
    
    print(f"Adapter for {lang_tag.upper()} saved to {save_dir}")

    # Evaluate the adapter as saved on disk, loaded onto a fresh base model.
    device = peft_model.device
    base_model_eval = AutoModelForSeq2SeqLM.from_pretrained(
        MODEL_NAME, 
        config=BASE_MODEL_CONFIG
        ).to(device)
    final_model = PeftModel.from_pretrained(
        base_model_eval, 
        save_dir
        )

    preds, golds = predict_emotions_llm(final_model, tokenizer, test_data, lang_tag)
    final_metrics = compute_metrics_numeric(preds, golds)
    
    final_metrics["language"] = lang_tag
    print(f"Evaluation Metrics for {lang_tag.upper()}:")
    for k, v in final_metrics.items():
        if k != 'language':
            print(f"  {k}: {v:.4f}")
            
    del base_model_eval, final_model, trainer
    torch.cuda.empty_cache()
    gc.collect()
    
    return final_metrics

# --- Cross-Validation Function ---

def run_cross_validation(
        lang: str, 
        tokenizer: AutoTokenizer, 
        grid: Dict[str, List[int]], 
        device: torch.device,
        val_fraction: float = 0.1,
        ) -> Dict[str, int]:
    """
    Grid-search LoRA ``r`` / ``lora_alpha`` on one language using a hold-out split.

    Despite the name, this is a single hold-out search, not k-fold:
    - A reproducible random ``val_fraction`` of the language's *training* data is
      held out.
    - Each grid point is trained on the rest for CV_EPOCHS epochs.
    - Each grid point is scored by micro-F1 on up to MAX_SAMPLES_EVAL held-out rows.

    The test split is not read, so hyperparameter selection cannot leak test data.

    Args:
        lang: Language whose data is searched over.
        tokenizer: Tokenizer for collation and generation.
        grid: ``{"r": [...], "lora_alpha": [...]}``; every combination is tried.
        device: Device the models are moved to.
        val_fraction: Fraction of the training data held out for scoring.

    Returns:
        ``{"r": ..., "lora_alpha": ...}`` with the highest micro-F1 (the first
        combination wins ties), or ``{}`` if ``grid`` yields no combinations.

    Raises:
        ValueError: If the split leaves the training or the validation part empty.
    """
    print(f"\n=== Starting LoRA Cross-Validation on {lang.upper()} dataset ===")
    
    # Score candidates on held-out *training* rows. Selecting on the test split
    # would bias the test metrics of the final models.
    train_df, _ = load_data(lang)
    full_train = train_df.reset_index(drop=True)
    val_df = full_train.sample(frac=val_fraction, random_state=RANDOM_STATE)
    fit_df = full_train.drop(index=val_df.index)
    if val_df.empty or fit_df.empty:
        raise ValueError(
            f"val_fraction={val_fraction} leaves an empty train or validation split "
            f"for {lang} ({len(full_train)} rows)"
        )

    train_dataset = Dataset.from_pandas(fit_df.reset_index(drop=True))
    val_df_subset = val_df.head(MAX_SAMPLES_EVAL).reset_index(drop=True)
    
    # Tokenize the fitting data once; it is shared by every grid point.
    num_proc = os.cpu_count() or 1
    tokenized_train = train_dataset.map(
        preprocess_function, 
        batched=True, 
        remove_columns=train_dataset.column_names, 
        num_proc=num_proc
        )
    
    best_f1 = -1.0
    best_params = {}
    
    for r_val in grid["r"]:
        for alpha_val in grid["lora_alpha"]:
            print(f"\n--- Testing R={r_val}, Alpha={alpha_val} ---")
            
            cv_lora_config = LoraConfig(
                r=r_val, lora_alpha=alpha_val, target_modules=["q", "v"],
                lora_dropout=0.05, bias="none", task_type="SEQ_2_SEQ_LM"
            )
            cv_model = get_peft_model(
                AutoModelForSeq2SeqLM.from_pretrained(
                    MODEL_NAME, config=BASE_MODEL_CONFIG
                    ).to(device), 
                    cv_lora_config
                    )

            cv_args = TrainingArguments(
                output_dir=CV_RESULTS_DIR, per_device_train_batch_size=BATCH_SIZE,
                per_device_eval_batch_size=BATCH_SIZE, learning_rate=LEARNING_RATE,
                num_train_epochs=CV_EPOCHS, evaluation_strategy="no", save_strategy="no",
                logging_dir="./cv_logs", report_to="none", seed=RANDOM_STATE,
            )
            data_collator = DataCollatorForSeq2Seq(tokenizer, model=cv_model)
            
            trainer = Trainer(
                model=cv_model, args=cv_args, train_dataset=tokenized_train,
                tokenizer=tokenizer, data_collator=data_collator
            )
            
            trainer.train()
            temp_save_dir = os.path.join(CV_WEIGHTS_DIR, f"lora_{r_val}_{alpha_val}")
            cv_model.save_pretrained(temp_save_dir)
            
            # Score the adapter as saved, loaded onto a fresh base model.
            eval_base = AutoModelForSeq2SeqLM.from_pretrained(
                MODEL_NAME, 
                config=BASE_MODEL_CONFIG
                ).to(device)
            eval_model = PeftModel.from_pretrained(eval_base, temp_save_dir)
            
            preds, golds = predict_emotions_llm(
                eval_model, 
                tokenizer, 
                val_df_subset, lang)
            metrics = compute_metrics_numeric(preds, golds)
            
            micro_f1 = metrics["micro_f1"]
            print(f"Micro F1: {micro_f1:.4f}")

            if micro_f1 > best_f1:
                best_f1 = micro_f1
                best_params = {"r": r_val, "lora_alpha": alpha_val}
            
            del cv_model, trainer, eval_base, eval_model
            torch.cuda.empty_cache()
            gc.collect()
            
    print("==============================================")
    print(f"Best LoRA Hyperparameters (Micro F1={best_f1:.4f}):")
    print(best_params)
    print("==============================================")
    
    return best_params

# --- Main Orchestration Function ---

def main():
    """
    Run the full pipeline.

    Steps:
    1. Run the LoRA grid search on CV_LANG.
    2. Train and evaluate one adapter per language in LANGUAGES, plus one on
       the "multilingual" split, using the selected hyperparameters.

    The results table is rewritten to RESULTS_FILE after every run, so a crash
    part-way through keeps the metrics of the runs that already finished.
    """
    os.makedirs(SAVE_ROOT_DIR, exist_ok=True)
    os.makedirs(TRAIN_ARGS_ROOT_DIR, exist_ok=True)
    os.makedirs(CV_RESULTS_DIR, exist_ok=True)
    os.makedirs(CV_WEIGHTS_DIR, exist_ok=True)
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    best_lora_params = run_cross_validation(CV_LANG, tokenizer, LORA_GRID, device)
    
    # The search returns {} when the grid has no combinations; fall back to r=16, alpha=32.
    final_lora_config = LoraConfig(
        r=best_lora_params.get("r", 16),
        lora_alpha=best_lora_params.get("lora_alpha", 32),
        target_modules=["q", "v"],
        lora_dropout=0.05,
        bias="none",
        task_type="SEQ_2_SEQ_LM"
    )
    
    all_results: List[Dict[str, Any]] = []
    
    languages_to_train = LANGUAGES + ["multilingual"]
    
    for lang in languages_to_train:
        train_df, test_df = load_data(lang)
        
        final_model = get_peft_model(
            AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, config=BASE_MODEL_CONFIG).to(device),
            final_lora_config
        )
        
        metrics = run_peft_training_and_evaluation(
            lang, 
            final_model, 
            tokenizer, 
            train_df, 
            test_df, 
            epochs=FINAL_EPOCHS,
            is_multilingual=(lang == "multilingual")
        )
        all_results.append(metrics)

        # Checkpoint the results table after every run; each run can take hours.
        pd.DataFrame(all_results).to_csv(RESULTS_FILE, index=False)
        
        del final_model
        torch.cuda.empty_cache()
        gc.collect()

    results_df = pd.DataFrame(all_results)
    print(f"\nFinal combined metrics saved to {RESULTS_FILE}")
    print("\n--- Summary of All Final Results ---")
    try:
        summary = results_df.to_markdown(index=False)
    except ImportError:
        # DataFrame.to_markdown needs the optional `tabulate` package. Fall back
        # to plain text rather than crashing after all training has finished.
        summary = results_df.to_string(index=False)
    print(summary)

# Entry point: run the whole pipeline only when this module is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()