## Test the trained model using the trained model



In [1]:
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer
)
import pandas as pd
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
)
import torch

import pandas as pds
from tqdm import tqdm
import sacrebleu
from datasets import Dataset
from datasets import load_from_disk

MAX_LEN = 512
train_ratio = 0.01
model_name = "/home/snt/llm_models/Llama-3.2-1B-Instruct"
checkpoint = "/home/snt/projects_lujun/mt_luxembourgish/logs/fit_1734873129.3395984/checkpoint-2000"
val_dataset_path = "/home/snt/projects_lujun/mt_luxembourgish/data/processed/dataset_merged_nllb_fake_targets_with_split.jsonl"


# Load dataset
if val_dataset_path.endswith(".jsonl"):
    dataset = Dataset.from_json(val_dataset_path)  # Ensure correct format
else:
    dataset = load_from_disk(val_dataset_path)

# Filter by split
train_dataset = dataset.filter(lambda x: x["split"] == "train")
val_dataset = dataset.filter(lambda x: x["split"] == "val")


# Select subset
train_dataset = train_dataset.select(range(int(len(train_dataset) * train_ratio)))
val_dataset = val_dataset.select(range(int(len(val_dataset) * train_ratio)))  # Avoid out-of-range error

# Rename columns
train_dataset = train_dataset.rename_columns({
    "input": "Luxembourgish",
    "translated_text": "English",
})

val_dataset = val_dataset.rename_columns({
    "input": "Luxembourgish",
    "translated_text": "English",
})


tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# Reload model in FP16 and merge it with LoRA weights (was previously converted to 4 bits)
base_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    low_cpu_mem_usage=True,
    return_dict=True,
    torch_dtype=torch.float16,
    device_map="cuda:0",
)

model = base_model
# model = PeftModel.from_pretrained(base_model, checkpoint)

# model = AutoModelForCausalLM.from_pretrained(
#     checkpoint,
#     device_map="cuda:0",
# )

# Function to generate from the model
def generate_response(prompt, model):
    encoded_input = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
    model_inputs = encoded_input.to("cuda")
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=MAX_LEN * 2,
        do_sample=True,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    decoded_output = tokenizer.batch_decode(generated_ids)
    return decoded_output[0].replace(prompt, "")


def create_prompt(
    sample, mode="train", src_lng="Luxembourgish", tgt_lng="English", tokenizer=None
):
    """
    Create a prompt using the model's EOS token.

    Args:
        sample (dict): A dictionary containing source and target text.
        mode (str): The mode, either 'train' or 'test'.
        src_lng (str): Source language name.
        tgt_lng (str): Target language name.
        tokenizer: The tokenizer associated with the model (required to fetch EOS token).

    Returns:
        dict: A dictionary with the constructed prompt.
    """
    # Validate the tokenizer input
    if tokenizer is None or tokenizer.eos_token is None:
        raise ValueError("A tokenizer with a defined EOS token is required.")

    # Define the system message template.
    system_message = f"Translate the {src_lng} input text into {tgt_lng}.".upper()
    input_text = sample[src_lng.capitalize()].strip()  # Extract the input text.
    response = (
        sample[tgt_lng.capitalize()].strip() if tgt_lng.capitalize() in sample else ""
    )  # Extract the target text.

    # Get the EOS token from the tokenizer.
    eos_token = tokenizer.eos_token

    # Construct the full prompt.
    full_prompt = (
        "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
        + system_message
        + "<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
    )
    full_prompt += (
        input_text + "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
    )
    if mode == "train":
        full_prompt += response + eos_token
    return {"prompt_response": full_prompt}


def generate_dataset_responses(dataset, model, tgt_lng="english", tokenizer=None):
    """Generates prompts and corresponding LLM responses for the "test" split of a dataset,
    and computes the SPBLEU score by comparing the LLM responses to the ground truth."""

    source = []
    predictions = []  # List to store ground truth responses
    targets = []  # List to store LLM generated responses
    index_uniques = []

    for sample in tqdm(dataset, desc="Generating responses"):
        test_prompt = create_prompt(
            sample, mode="test", tgt_lng=tgt_lng, tokenizer=tokenizer
        )[
            "prompt_response"
        ]  # Create the prompt in "test" mode

        llm_response = (
            generate_response(test_prompt, model)
            .replace("<|begin_of_text|>", "")
            .replace("<|eot_id|>", "")
        )

        ground_truth = sample.get(
            "English", ""
        )  # Get the ground truth (adjust field name as needed)

        index_unique = sample.get("index_unique", "")

        # Append the LLM response and ground truth
        predictions.append(llm_response)
        targets.append([ground_truth])  # References should be in list format for SPBLEU
        source.append(sample.get("Luxembourgish"))
        index_uniques.append(index_unique)


    # Convert the results into a DataFrame

    # predictions_cleaned = [
    #     p.replace("<|begin_of_text|>", "").replace("<eot_id|>", "") for p in predictions
    # ]
    # targets_cleaned = [
    #     t[0].replace("<|begin_of_text|>", "").replace("<eot_id|>", "") for t in targets
    # ]

    df_results = pd.DataFrame(
        list(zip(source, predictions, targets, index_uniques)),
        columns=["LLM_Input", "LLM_Output", "Ground_Truth", "index_unique"],
    )

    spbleu_scores = [
        sacrebleu.corpus_bleu([p], [t], tokenize="flores200").score
        for p, t in zip(predictions, targets)
    ]
    df_results["SPBLEU_Score"] = spbleu_scores
    average_spbleu = df_results["SPBLEU_Score"].mean()
    print(f"Average SPBLEU Score: {average_spbleu:.2f}")

    return df_results


pre_finetuned_responses = generate_dataset_responses(
    dataset=val_dataset, model=model, tgt_lng="english", tokenizer=tokenizer
)

  from .autonotebook import tqdm as notebook_tqdm
Generating train split: 207058 examples [00:00, 590664.60 examples/s]
Filter: 100%|██████████| 207058/207058 [00:01<00:00, 162692.18 examples/s]
Filter: 100%|██████████| 207058/207058 [00:01<00:00, 174244.71 examples/s]
Generating responses: 100%|██████████| 36/36 [00:33<00:00,  1.07it/s]


Average SPBLEU Score: 8.14


In [2]:
pre_finetuned_responses

Unnamed: 0,LLM_Input,LLM_Output,Ground_Truth,index_unique,SPBLEU_Score
0,D’Konzept vum Spëtzekandidat ass net onëmstrid...,The concept of the opponent's candidate does n...,[The concept of the frontrunner is not without...,149827,11.75371
1,"Et ginn och Beräicher, wou d’Unioun an d’Staat...",And here is the translation:\n\nThe Luxembourg...,[There are also areas where the Union and the ...,149828,0.549157
2,Da gëtt et déi sougenannte besonnesch Kompeten...,The EU's competencies include coordinating the...,[There are the so-called special competences o...,149829,4.405877
3,"Jiddereen, deen d'Walrecht fir déi Walen huet,...","Jerusalem, the land of the Jews, the people of...","[In Luxembourg, on 9 June, anyone who has the ...",149830,2.884817
4,A verschiddene Länner gouf den Alter fir ze wi...,A Luxembourg has been a member of the European...,"[In some countries, however, the voting age ha...",149831,7.204015
5,D’Wale gi vum 6. bis den 9. Juni. D’Period ass...,"On the 6th June, D’Wale went to school until t...","[The elections will be held from 6 to 9 June, ...",149832,3.44347
6,An de Wochen no der Wal fanne sech déi nei gew...,"Over the weeks, the Luxembourgish deputies and...","[In the weeks following the elections, the new...",149833,5.40988
7,Juli ass déi éischt offiziell Sëtzung vum Parl...,The first official sitting of the Luxembourgis...,[The new Parliament will have to decide what t...,149834,19.76357
8,D'Ekonomie ass an den éischten dräi Méint vum ...,The economy experienced a 0.5 percent increase...,[The economy grew slightly in the first three ...,149835,14.849932
9,"D'Regierung warnt virun Onéierlechen, déi fals...",A warning from the government stating that one...,[The government warns against fraudsters selli...,149836,5.110877


In [3]:
pre_finetuned_responses.to_csv("output_base_model.csv")