# Import packages

In [None]:
# Standard library
import re
import os
import json
import gc
# Third-party libraries
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt

# Machine learning and NLP
import evaluate
from sentence_transformers import SentenceTransformer, util
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoModel, AutoTokenizer, pipeline

# API clients
from openai import OpenAI


# The main scoring section

## Set up GPT-as-judge for human-like scoring

In [None]:
# Path to a text file containing only the OpenAI API key. It can be overridden
# with the OPENAI_API_KEY_FILE environment variable so the notebook runs on
# machines other than the original author's; the default is the original path.
API_KEY_PATH = os.environ.get(
    "OPENAI_API_KEY_FILE", "/home/NE6131039/Desktop/Confidential_Key.txt"
)

with open(API_KEY_PATH, "r", encoding="utf-8") as f:
    api_key = f.read().strip()

client = OpenAI(api_key=api_key)
GPT_MODEL = "gpt-4o"  # judge model used by gpt_score()

In [None]:
def build_gpt_score_prompt(prediction, reference, question):
    """Build the GPT-as-judge prompt for scoring one TEM image-analysis answer.

    The prompt instructs GPT to:
    - act as a TEM domain expert,
    - evaluate the prediction against the reference answer for the question,
    - consider scientific accuracy, completeness, and technical precision,
    - return a numerical score (0.00-1.00) only.

    NOTE: the authors' exact prompt wording was withheld ("omitted for
    proprietary reasons; contact authors for detailed evaluation
    methodology"). The released stub returned only this description and never
    interpolated its arguments, so every call sent GPT the same prompt with no
    question, reference or prediction in it. This version embeds all three.
    Scores will not necessarily match those produced with the authors'
    original wording.

    Args:
        prediction (str): Model's predicted answer.
        reference (str): Ground-truth reference answer.
        question (str): Original question asked.

    Returns:
        str: Formatted evaluation prompt.
    """
    return (
        "You are a TEM (transmission electron microscopy) domain expert.\n"
        "Evaluate the predicted answer against the reference answer for the "
        "question below, considering scientific accuracy, completeness, and "
        "technical precision.\n\n"
        f"Question:\n{question}\n\n"
        f"Reference answer:\n{reference}\n\n"
        f"Predicted answer:\n{prediction}\n\n"
        "Respond with ONLY a single numerical score between 0.00 and 1.00 "
        "(two decimal places), where 1.00 means fully correct and 0.00 means "
        "completely incorrect."
    )


# First decimal number in a reply. The `\d*\.\d+` branch comes first so that
# "0.85" and ".85" are read whole (the old pattern read ".85" as 85).
_SCORE_RE = re.compile(r"\d*\.\d+|\d+")


def parse_gpt_score(content):
    """Extract a 0-1 score from a GPT judge reply.

    Takes the first number in *content* (e.g. "0.85", ".85", "Score: 1"),
    clamps it to [0, 1] and rounds it to 2 decimals.

    Args:
        content (str | None): Raw text returned by the judge model.

    Returns:
        float | None: The parsed score, or None if *content* is empty or
        contains no number.
    """
    if not content:
        return None
    match = _SCORE_RE.search(content)
    if match is None:
        return None
    score = float(match.group(0))
    return round(min(1.0, max(0.0, score)), 2)


def gpt_score(prediction, reference, question, *, default=0.0):
    """Score one prediction against its reference using GPT as the judge.

    Args:
        prediction (str): Model's predicted answer.
        reference (str): Ground-truth reference answer.
        question (str): Original question asked.
        default (float): Value returned when the API call fails or the reply
            contains no number. It defaults to 0.0 (the historical
            behavior). Pass float("nan") to keep failures distinguishable
            from a genuine 0.0 score.

    Returns:
        float: Score in [0, 1] rounded to 2 decimals, or *default* on failure.
    """
    prompt = build_gpt_score_prompt(prediction, reference, question)
    try:
        response = client.chat.completions.create(
            model=GPT_MODEL,
            messages=[
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": prompt}
            ],
            temperature=0,
        )
        # `content` can be None, so coerce to "" before stripping.
        content = (response.choices[0].message.content or "").strip()
    except Exception as e:
        # Best-effort scoring: log API/network errors and fall back.
        print(f"[GPT ERROR] {e}")
        return default

    score = parse_gpt_score(content)
    if score is None:
        print(f"[GPT PARSE WARNING] Could not parse score from: {content}")
        return default
    return score

## Load prediction files and scoring metrics

In [None]:
# Prediction CSVs to score, one per training run. The plain
# "finetune_curriculum" run is currently left out of the evaluation.
answers = [
    f"{run}_val_predict.csv"
    for run in ("pretrain", "finetune_no_curriculum", "finetune_curriculum_final")
]

In [None]:
# --- Lexical overlap metrics (Hugging Face `evaluate`) ---
bleu, meteor, rouge = (evaluate.load(name) for name in ("bleu", "meteor", "rouge"))

# --- Semantic similarity models ---
device = torch.device("cuda:2")  # GPU that hosts the SBERT encoder
sbert = SentenceTransformer("sentence-transformers/all-mpnet-base-v2").to(device)
bertscore = evaluate.load("bertscore")

## Score all data

In [None]:
# ==== Scoring configuration ====
SAMPLES_PER_TYPE = 2500  # rows drawn from each "type" group of every file
SAMPLE_SEED = 42  # random_state for GroupBy.sample (reproducible subsample)
BERTSCORE_MODEL = "microsoft/deberta-xlarge-mnli"
BERTSCORE_DEVICE = "cuda:3"  # BERTScore runs on a different GPU from SBERT


def _clip01(value):
    """Clamp *value* into [0, 1] and round to 4 decimals for storage."""
    return round(min(max(value, 0.0), 1.0), 4)


def _gpt_scores(predictions, references, questions):
    """Return the GPT-judge score for every (prediction, reference, question)."""
    return [
        gpt_score(pred, ref, que)
        for pred, ref, que in tqdm(
            zip(predictions, references, questions),
            total=len(predictions),
            desc="GPT Scoring",
        )
    ]


def _lexical_scores(predictions, references):
    """Compute per-sample BLEU / ROUGE / METEOR blends and the lexical score.

    Returns:
        tuple[list[float], list[float], list[float], list[float]]:
        (bleu, rouge, meteor, lexical). Each value is clamped to [0, 1] and
        rounded to 4 decimals. If any metric raises, the lexical score for
        that sample is 0.0. Components not yet computed are 0.0 instead of
        the previous sample's stale values.
    """
    bleu_out, rouge_out, meteor_out, lexical_out = [], [], [], []
    for pred, ref in tqdm(zip(predictions, references), total=len(predictions)):
        # Reset per sample so a failure cannot reuse the previous sample's
        # values (or hit a NameError on the first sample).
        bleu_combined = rouge_combined = meteor_score = 0.0
        try:
            # NOTE: "precisions" are BLEU's smoothed 1- to 4-gram precisions
            # (no brevity penalty), not separate BLEU-1..BLEU-4 scores.
            bleu_result = bleu.compute(
                predictions=[pred], references=[[ref]], max_order=4, smooth=True
            )
            p1, p2, p3, p4 = bleu_result.get("precisions", [0, 0, 0, 0])[:4]
            bleu_combined = 0.4 * p1 + 0.3 * p2 + 0.2 * p3 + 0.1 * p4

            meteor_score = meteor.compute(predictions=[pred], references=[ref])["meteor"]

            rouge_result = rouge.compute(predictions=[pred], references=[ref])
            rouge_l_avg = (
                rouge_result.get("rougeL", 0.0) + rouge_result.get("rougeLsum", 0.0)
            ) / 2
            rouge_combined = (
                rouge_result.get("rouge1", 0.0)
                + rouge_result.get("rouge2", 0.0)
                + rouge_l_avg
            ) / 3

            score = 0.3 * rouge_combined + 0.2 * bleu_combined + 0.5 * meteor_score
        except Exception as e:
            print(f"Error on sample: {e}")
            score = 0.0

        bleu_out.append(_clip01(bleu_combined))
        rouge_out.append(_clip01(rouge_combined))
        meteor_out.append(_clip01(meteor_score))
        lexical_out.append(_clip01(score))
    return bleu_out, rouge_out, meteor_out, lexical_out


def _semantic_scores(predictions, references, desc):
    """Compute per-sample BERTScore F1, SBERT cosine similarity and their mean.

    Returns:
        tuple[list[float], list[float], list[float]]: (bert, sbert, semantic).
        Each value is clamped to [0, 1] and rounded to 4 decimals. A failing
        model contributes 0.0 for that sample.
    """
    sbert_device = str(device)  # same GPU the SBERT model was moved to
    bert_out, sbert_out, semantic_out = [], [], []
    for pred, ref in tqdm(zip(predictions, references), total=len(predictions), desc=desc):
        # --- BERTScore ---
        try:
            bert_result = bertscore.compute(
                predictions=[pred],
                references=[ref],
                lang="en",
                model_type=BERTSCORE_MODEL,
                device=BERTSCORE_DEVICE,
            )
            bert_f1 = bert_result["f1"][0]
        except Exception as e:
            print(f"BERTScore error: {e}")
            bert_f1 = 0.0

        # --- SBERT cosine similarity ---
        try:
            emb_pred = sbert.encode(pred, convert_to_tensor=True, device=sbert_device)
            emb_ref = sbert.encode(ref, convert_to_tensor=True, device=sbert_device)
            sim = util.cos_sim(emb_pred, emb_ref).item()
        except Exception as e:
            print(f"SBERT error: {e}")
            sim = 0.0

        # The average uses the unclamped values; each stored value is clamped.
        avg = (bert_f1 + sim) / 2
        bert_out.append(_clip01(bert_f1))
        sbert_out.append(_clip01(sim))
        semantic_out.append(_clip01(avg))

    # Release cached GPU memory once per file rather than after every row.
    torch.cuda.empty_cache()
    gc.collect()
    return bert_out, sbert_out, semantic_out


# ==== Main Loop ====
for answer in answers:
    df = pd.read_csv(answer)

    # Stratified subsample. NOTE: GroupBy.sample(n=...) without replacement
    # raises ValueError if any "type" group has fewer than SAMPLES_PER_TYPE rows.
    df = (
        df.groupby("type")
        .sample(n=SAMPLES_PER_TYPE, random_state=SAMPLE_SEED)
        .reset_index(drop=True)
    )

    references = df["expected"].astype(str).tolist()
    predictions = df["predicted"].astype(str).tolist()
    questions = df["question"].astype(str).tolist()

    # ---- GPT_SCORE ----
    df["gpt_score"] = _gpt_scores(predictions, references, questions)

    # ---- LEXICAL ----
    bleu_col, rouge_col, meteor_col, lexical_col = _lexical_scores(predictions, references)
    df["bleu_scores"] = bleu_col
    df["rouge_scores"] = rouge_col
    df["meteor_scores"] = meteor_col
    df["lexical_scores"] = lexical_col

    # ---- SEMANTIC ----
    bert_col, sbert_col, semantic_col = _semantic_scores(
        predictions, references, desc=f"Processing {answer}"
    )
    df["bert_scores"] = bert_col
    df["sbert_scores"] = sbert_col
    df["semantic_scores"] = semantic_col

    # ---- FINAL SCORE: equal-weight mean of lexical and semantic scores ----
    df["final_scores"] = 0.5 * np.array(lexical_col) + 0.5 * np.array(semantic_col)

    # Optionally drop raw text columns before saving:
    # df.drop(columns=["image", "question", "expected", "predicted"], inplace=True)

    # splitext (not str.replace) so only the extension changes. An input whose
    # name lacks ".csv" can then never be overwritten by its own output.
    stem, _ext = os.path.splitext(answer)
    df.to_csv(f"{stem}_scored.csv", index=False)