In [None]:
import requests
import pandas as pd
from tqdm import tqdm
import sacrebleu
from datasets import Dataset, load_from_disk
from sacrebleu.metrics import CHRF
from datetime import datetime
import json

############################################################################################################

VLLM_API_URL = "http://localhost:8000/generate"
MAX_LEN = 512  # call_vllm requests MAX_LEN * 2 new tokens
val_dataset_path = "data/training_dataset/dataset_val_300.jsonl"
flore_dataset_path = "data/fake_targets/flores_devtest_arrow"
current_time = datetime.now()
formatted_time = current_time.strftime('%m_%d_%H_%M')
# Output file (written to the working directory) is named after the validation
# dataset plus a timestamp. Strip a trailing "/" (saved datasets are directories)
# and drop only a real ".jsonl" suffix, so non-JSONL dataset paths still get a
# distinct, timestamped ".jsonl" output name.
_val_dataset_name = val_dataset_path.rstrip("/").split("/")[-1]
if _val_dataset_name.endswith(".jsonl"):
    _val_dataset_name = _val_dataset_name[: -len(".jsonl")]
eval_output_path = f"{_val_dataset_name}_{formatted_time}_eval_from_vLLM.jsonl"
sample_num = None  # Number of samples to evaluate per dataset; None evaluates all of them.

src_lng = "English"
src_lng_abr = "sentence_eng_Latn"  # FLORES column holding the English text

tgt_lng = "Luxembourgish"
tgt_lng_abr = "sentence_ltz_Latn"  # FLORES column holding the Luxembourgish text

############################################################################################################

# Load the validation data: a JSONL file, or otherwise a dataset directory
# readable by `load_from_disk`.
if val_dataset_path.endswith(".jsonl"):
    dataset = Dataset.from_json(val_dataset_path)
else:
    dataset = load_from_disk(val_dataset_path)


def _limit_samples(ds, n):
    """Return the first ``n`` rows of ``ds``, or ``ds`` unchanged when ``n`` is None.

    ``n`` is clamped to ``len(ds)`` so a cap larger than the dataset does not
    make ``select`` fail with an out-of-range index.
    """
    if n is None:
        return ds
    return ds.select(range(min(n, len(ds))))


# Keep only rows of the "val" split, then apply the optional sample cap.
val_dataset = _limit_samples(dataset.filter(lambda x: x["split"] == "val"), sample_num)

# Rename the raw columns to the language labels used by the prompt/ground-truth
# lookups. The literals are deliberate: they pin which raw column holds which
# language, independently of the src_lng/tgt_lng translation direction.
val_dataset = val_dataset.rename_columns({
    "input": "Luxembourgish",
    "translated_text": "English",
})

# FLORES columns are renamed to the same language labels; the same sample cap
# applies as for the RTL validation split (previously hard-coded to 10 rows).
val_flores_dataset = _limit_samples(
    load_from_disk(flore_dataset_path).rename_columns({tgt_lng_abr: tgt_lng, src_lng_abr: src_lng}),
    sample_num,
)


def call_vllm(prompt, *, temperature=1.0, timeout=None):
    """Send one generation request to the vLLM server and return its text.

    Posts ``prompt`` to ``VLLM_API_URL`` asking for up to ``MAX_LEN * 2`` new
    tokens and returns the ``"generated_text"`` field of the JSON reply.

    Args:
        prompt: Full prompt string.
        temperature: Sampling temperature forwarded to the server. Defaults to
            1.0, the previous hard-coded value. NOTE(review): sampling at 1.0
            makes evaluation scores non-reproducible; 0.0 is usually preferred
            for evaluation, so confirm intent.
        timeout: Seconds forwarded to ``requests.post``. ``None`` (the default)
            waits indefinitely, as before.

    Returns:
        The generated text. Returns ``""`` when the server answers with a
        non-200 status; a warning is printed in that case so failed requests
        are visible instead of silently scoring as empty translations.

    Raises:
        requests.RequestException: on connection errors or timeouts.
    """
    payload = {"prompt": prompt, "max_tokens": MAX_LEN * 2, "temperature": temperature}
    response = requests.post(VLLM_API_URL, json=payload, timeout=timeout)
    if response.status_code == 200:
        return response.json()["generated_text"]
    print(f"WARNING: vLLM request failed with HTTP {response.status_code}: {response.text[:200]}")
    return ""


def compute_jaccard(prediction: str, reference: str) -> float:
    """Return the Jaccard similarity of the whitespace-token sets of two strings.

    Tokens are produced by ``str.split()``, so duplicates and spacing are ignored.
    Two strings with no tokens at all are considered identical (1.0).
    """
    predicted_tokens, reference_tokens = set(prediction.split()), set(reference.split())
    all_tokens = predicted_tokens | reference_tokens
    # The union is empty only when both inputs have no tokens.
    if not all_tokens:
        return 1.0
    return len(predicted_tokens & reference_tokens) / len(all_tokens)


def create_prompt(sample, src_lng, tgt_lng):
    """Build the plain-text translation prompt for one sample.

    The source text is read from the column named ``src_lng.capitalize()`` and
    stripped of surrounding whitespace. The prompt is a fixed system line, a
    translation instruction naming both languages, and the source text,
    separated by blank lines.
    """
    source_text = sample[src_lng.capitalize()].strip()
    header = "You are a helpful AI assistant for translation."
    instruction = f"Translate the {src_lng} input text into {tgt_lng}."
    return "\n\n".join((header, instruction, source_text))


def generate_dataset_responses(dataset, src_lng, tgt_lng):
    """Translate every sample through vLLM, score it, and log the results.

    For each sample, a prompt is built with ``create_prompt`` and sent through
    ``call_vllm``. The stripped response is scored against the ``tgt_lng``
    column with:

    - sentence-level spBLEU (``flores200`` tokenizer),
    - chrF with ``word_order=3``,
    - whitespace-token Jaccard similarity.

    Each result row is appended to ``eval_output_path`` as one JSON line as soon
    as it is computed, so rows already written are kept if the run stops early.
    Sentence-level averages and corpus-level spBLEU/chrF are printed at the end.

    Args:
        dataset: Iterable of dict-like samples holding the source column read by
            ``create_prompt`` and the ``tgt_lng`` reference column.
            ``index_unique`` is recorded when present.
        src_lng: Source-language label (prompt text and source column name).
        tgt_lng: Target-language label (prompt text and reference column name).

    Returns:
        DataFrame with one row per sample and the columns ``LLM_Input``,
        ``LLM_Output``, ``Ground_Truth``, ``index_unique``, ``SPBLEU_Score``,
        ``CharF++_Score`` and ``Jaccard_Score`` (empty when ``dataset`` is).
    """
    columns = [
        "LLM_Input", "LLM_Output", "Ground_Truth", "index_unique",
        "SPBLEU_Score", "CharF++_Score", "Jaccard_Score",
    ]
    # Same configuration for every sample: build the scorer once, not per row.
    # NOTE(review): chrF++ is conventionally word_order=2; 3 is kept so scores
    # stay comparable with earlier runs. Confirm intent.
    chrf_metric = CHRF(word_order=3)
    records = []
    hypotheses, references = [], []

    for sample in tqdm(dataset, desc="Generating responses"):
        input_prompt = create_prompt(sample, src_lng, tgt_lng)
        llm_response = call_vllm(input_prompt).strip()
        # A null reference would crash sacrebleu; treat it as an empty string.
        ground_truth = sample.get(tgt_lng) or ""
        index_unique = sample.get("index_unique", "")

        # Sentence-level scores for this single hypothesis/reference pair.
        spbleu_score = sacrebleu.corpus_bleu([llm_response], [[ground_truth]], tokenize="flores200").score
        charf_score = chrf_metric.sentence_score(llm_response, [ground_truth]).score
        jaccard_score = compute_jaccard(llm_response, ground_truth)

        result = {
            "LLM_Input": input_prompt,
            "LLM_Output": llm_response,
            "Ground_Truth": ground_truth,
            "index_unique": index_unique,
            "SPBLEU_Score": spbleu_score,
            "CharF++_Score": charf_score,
            "Jaccard_Score": jaccard_score,
        }
        # Append immediately so completed rows are not lost on interruption.
        pd.DataFrame([result]).to_json(eval_output_path, orient="records", lines=True, mode="a")
        records.append(result)
        hypotheses.append(llm_response)
        references.append(ground_truth)

    if not records:
        print("No samples to evaluate.")
        return pd.DataFrame(columns=columns)

    # Build the frame once instead of concatenating inside the loop (quadratic).
    df_results = pd.DataFrame(records, columns=columns)

    print(f"Average SPBLEU Score: {df_results['SPBLEU_Score'].mean():.2f}")
    print(f"Average CharF++ Score: {df_results['CharF++_Score'].mean():.2f}")
    print(f"Average Jaccard Score: {df_results['Jaccard_Score'].mean():.2f}")
    # Corpus-level scores pool n-gram statistics over all samples, which is how
    # spBLEU/chrF are normally reported; they differ from the sentence means above.
    corpus_spbleu = sacrebleu.corpus_bleu(hypotheses, [references], tokenize="flores200").score
    corpus_chrf = chrf_metric.corpus_score(hypotheses, [references]).score
    print(f"Corpus SPBLEU Score: {corpus_spbleu:.2f}")
    print(f"Corpus CharF++ Score: {corpus_chrf:.2f}")
    return df_results

# Evaluate the RTL validation split, then FLORES-200, tagging each result set.
print("Validation RTL Results")
df_RTL_results = generate_dataset_responses(val_dataset, src_lng, tgt_lng)
df_RTL_results["Dataset"] = "RTL"

print("FLORES 200 Results")
df_flores_results = generate_dataset_responses(val_flores_dataset, src_lng, tgt_lng)
df_flores_results["Dataset"] = "FLORES"

# The per-sample lines appended during generation carry no "Dataset" label, so
# rewrite the output file with the combined, labeled results. Otherwise RTL and
# FLORES rows would be indistinguishable in the saved file.
df_results = pd.concat([df_RTL_results, df_flores_results], axis=0, ignore_index=True)
df_results.to_json(eval_output_path, orient="records", lines=True)
print(f"Results saved to {eval_output_path}")
