In [None]:
# 1 Setup
import sys
import os
sys.path.append(os.path.abspath("."))
from model_3b import generate_translation

import comet
from tqdm import tqdm
import pandas as pd
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM

import torch
import torch.nn.functional as F
import math

PyTorch version 2.6.0+cu126 available.


In [None]:
# 2 Data
# FLORES-200, pinned to revision refs/pr/7 of the HF dataset repo; only the
# "dev" split is used below.
# NOTE(review): trust_remote_code=True lets `datasets` run the repo's loading
# script; pinning the revision narrows but does not remove that risk.
dataset = load_dataset("Muennighoff/flores200", "all", revision="refs/pr/7", trust_remote_code=True)
dev_set = dataset["dev"]

# Direction label -> (source column, target column).
_pair_columns = {
    "zho-eng": ("sentence_zho_Hans", "sentence_eng_Latn"),
    "eng-zho": ("sentence_eng_Latn", "sentence_zho_Hans"),
    "fra-eng": ("sentence_fra_Latn", "sentence_eng_Latn"),
    "eng-fra": ("sentence_eng_Latn", "sentence_fra_Latn"),
    "nld-eng": ("sentence_nld_Latn", "sentence_eng_Latn"),
    "eng-nld": ("sentence_eng_Latn", "sentence_nld_Latn"),
    "khk-eng": ("sentence_khk_Cyrl", "sentence_eng_Latn"),
    "eng-khk": ("sentence_eng_Latn", "sentence_khk_Cyrl"),
}


def _both_present(src_col, tgt_col):
    """Row predicate that is truthy only when both columns are non-empty."""
    return lambda row: row[src_col] and row[tgt_col]


# One filtered dev split per translation direction (insertion order preserved).
lang_pairs = {
    direction: dev_set.filter(_both_present(src_col, tgt_col))
    for direction, (src_col, tgt_col) in _pair_columns.items()
}
print(f"Number of language pairs: {len(lang_pairs)}")
print(f"Number of examples in each language pair: {[len(lang_pairs[lp]) for lp in lang_pairs]}")

Number of language pairs: 8
Number of examples in each language pair: [997, 997, 997, 997, 997, 997, 997, 997]


In [None]:
# 3 BLEU and METEOR
import sacrebleu

def compute_bleu(predictions, references, tokenize=None):
    """Sentence-level BLEU (sacrebleu) for each prediction/reference pair.

    Args:
        predictions: a hypothesis string or a list of hypothesis strings.
        references: a reference string, a list of reference strings (one per
            prediction), or a list of lists of reference strings.
        tokenize: optional sacrebleu tokenizer name, e.g. "zh" when the
            references are Chinese (eng-zho). None keeps sacrebleu's default.

    Returns:
        A list with one BLEU score per pair; extra items in the longer input
        are ignored (zip semantics). Empty input returns [].
    """
    if isinstance(predictions, str):
        predictions = [predictions]
    # A bare reference string is one reference, not a sequence of characters.
    if isinstance(references, str):
        references = [references]
    # sentence_bleu expects a list of reference strings for each hypothesis.
    # Per-element check also avoids indexing references[0] on empty input.
    references = [[ref] if isinstance(ref, str) else ref for ref in references]

    bleu_kwargs = {} if tokenize is None else {"tokenize": tokenize}
    return [
        sacrebleu.sentence_bleu(pred, refs, **bleu_kwargs).score
        for pred, refs in zip(predictions, references)
    ]

from nltk.translate.meteor_score import meteor_score
from nltk.tokenize import word_tokenize
import nltk

# punkt / punkt_tab back word_tokenize. nltk's meteor_score uses WordNet
# (plus Open Multilingual Wordnet data in some nltk versions) for synonym
# matching. Without these downloads a fresh environment raises LookupError on
# the first METEOR call.
nltk.download("punkt")
nltk.download("punkt_tab")
nltk.download("wordnet")
nltk.download("omw-1.4")

def compute_meteor(predictions, references):
    if isinstance(predictions, str):
        predictions = [predictions]
    if isinstance(references, str):
        references = [references]

    scores = []
    for pred, ref in zip(predictions, references):
        score = meteor_score([word_tokenize(ref)], word_tokenize(pred))
        scores.append(score)
    return scores

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\gerri\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package punkt_tab to
[nltk_data]     C:\Users\gerri\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


In [None]:
# 4 COMET
from comet import download_model, load_from_checkpoint

# Load each COMET checkpoint at most once per kernel session.
# The globals() guard makes this cell idempotent: re-running it reuses the
# models already in memory instead of downloading/loading them again.
# To force a reload (e.g. after changing a model name), delete the variable or
# restart the kernel.

# Reference-based COMET (scored on src, mt and ref)
if "comet_ref_model" not in globals():
    comet_ref_model_path = download_model("Unbabel/wmt22-comet-da")
    comet_ref_model = load_from_checkpoint(comet_ref_model_path)

# Reference-free COMET / CometKiwi (scored on src and mt only)
if "cometkiwi_model" not in globals():
    cometkiwi_model_path = download_model("Unbabel/wmt22-cometkiwi-da")
    cometkiwi_model = load_from_checkpoint(cometkiwi_model_path)

# Compute COMET scores
def _comet_gpus():
    """GPU count for COMET's predict(): 1 when CUDA is available, else 0 (CPU)."""
    return 1 if torch.cuda.is_available() else 0


def compute_comet_ref(srcs, mts, refs):
    """Reference-based COMET score for each (src, mt, ref) triple.

    Uses the global `comet_ref_model`. Returns one score per segment. Empty
    input returns [] without starting a prediction run.

    Best-effort: any exception is printed and a list of NaNs (one per source)
    is returned, so a long evaluation is not aborted.
    """
    if len(srcs) == 0:
        return []
    try:
        data = [{"src": s, "mt": m, "ref": r} for s, m, r in zip(srcs, mts, refs)]
        return comet_ref_model.predict(data, gpus=_comet_gpus()).scores
    except Exception as e:
        print(f"[COMET-REF ERROR] {e}")
        return [float("nan")] * len(srcs)


def compute_cometkiwi(srcs, mts):
    """Reference-free COMET (CometKiwi) score for each (src, mt) pair.

    Uses the global `cometkiwi_model`. Returns one score per segment. Empty
    input returns [] without starting a prediction run.

    Best-effort: any exception is printed and a list of NaNs (one per source)
    is returned.
    """
    if len(srcs) == 0:
        return []
    try:
        data = [{"src": s, "mt": m} for s, m in zip(srcs, mts)]
        return cometkiwi_model.predict(data, gpus=_comet_gpus()).scores
    except Exception as e:
        print(f"[COMET-KIWI ERROR] {e}")
        return [float("nan")] * len(srcs)

Fetching 5 files:   0%|          | 0/5 [00:00<?, ?it/s]

Lightning automatically upgraded your loaded checkpoint from v1.8.3.post1 to v2.5.1. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint C:\Users\gerri\.cache\huggingface\hub\models--Unbabel--wmt22-comet-da\snapshots\2760a223ac957f30acfb18c8aa649b01cf1d75f2\checkpoints\model.ckpt`
Encoder model frozen.
C:\Users\gerri\AppData\Roaming\Python\Python312\site-packages\pytorch_lightning\core\saving.py:195: Found keys that are not in the model state dict but in the checkpoint: ['encoder.model.embeddings.position_ids']


Fetching 5 files:   0%|          | 0/5 [00:00<?, ?it/s]

Lightning automatically upgraded your loaded checkpoint from v1.8.2 to v2.5.1. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint C:\Users\gerri\.cache\huggingface\hub\models--Unbabel--wmt22-cometkiwi-da\snapshots\1ad785194e391eebc6c53e2d0776cada8f83179a\checkpoints\model.ckpt`
Encoder model frozen.


In [None]:
# 5 Results + getting translations
def _build_inputs(examples, source_field, target_field, prompt_template):
    """Return parallel lists (prompts, sources, references) built from `examples`.

    Each example is indexed by `source_field` / `target_field`. The prompt is
    `prompt_template` with its `{source}` placeholder filled in.
    """
    prompts, sources, references = [], [], []
    for ex in examples:
        source = ex[source_field]
        prompts.append(prompt_template.format(source=source))
        sources.append(source)
        references.append(ex[target_field])
    return prompts, sources, references


def _translate_all(prompts, strategy):
    """Translate each prompt via generate_translation(prompt, strategy).

    Returns parallel lists (translations, log_probs, perplexities). A prompt
    whose generation raises is logged and recorded as ("", nan, nan), so one
    failure does not abort a long run.
    """
    translations, log_probs, perplexities = [], [], []
    for prompt in tqdm(prompts, desc=f"Translating ({strategy})"):
        try:
            translation, log_prob, ppl = generate_translation(prompt, strategy)
        except Exception as e:
            print(f"[ERROR] Strategy {strategy}: {e}")
            translation, log_prob, ppl = "", float("nan"), float("nan")
        translations.append(translation)
        log_probs.append(log_prob)
        perplexities.append(ppl)
    return translations, log_probs, perplexities


def _score_comet(sources, translations, references):
    """Return (comet_ref, cometkiwi) score lists aligned with `translations`.

    Empty or whitespace-only translations are not sent to COMET and keep NaN.
    """
    n = len(translations)
    comet_refs = [float("nan")] * n
    comet_kiwi = [float("nan")] * n

    valid_indices = [i for i, t in enumerate(translations) if t.strip()]
    if not valid_indices:
        return comet_refs, comet_kiwi

    try:
        valid_sources = [sources[i] for i in valid_indices]
        valid_refs = [references[i] for i in valid_indices]
        valid_trans = [translations[i] for i in valid_indices]

        ref_scores = compute_comet_ref(valid_sources, valid_trans, valid_refs)
        kiwi_scores = compute_cometkiwi(valid_sources, valid_trans)

        # Scatter the compact score lists back to their original positions.
        for j, idx in enumerate(valid_indices):
            comet_refs[idx] = ref_scores[j]
            comet_kiwi[idx] = kiwi_scores[j]
    except Exception as e:
        print(f"[COMET ERROR] {e}")
    return comet_refs, comet_kiwi


def get_results_batched(examples, source_field, target_field, prompt_template, direction, results_list,
                        strategies=("greedy",)):
    """Translate `examples`, score them, and append one record per example and strategy.

    Args:
        examples: iterable of mapping-like rows (e.g. a `datasets.Dataset`).
        source_field: column holding the source sentence.
        target_field: column holding the reference translation.
        prompt_template: str.format template with a `{source}` placeholder.
        direction: language-pair label (e.g. "zho-eng"), stored in every record.
        results_list: list that records are appended to (mutated in place;
            nothing is returned).
        strategies: decoding strategy names passed to generate_translation.

    Each record has keys: source, reference, strategy, translation,
    total_log_probs, perplexity, bleu, meteor (sentence-level), comet_ref,
    comet_wmt22 (the reference-free CometKiwi score, despite the key name) and
    direction. Failed generations or COMET calls are recorded as "" / NaN.

    NOTE(review): the saved previews show some outputs beginning with the
    prompt's source label (e.g. "French text: ..."). That text comes from
    generate_translation and is scored as-is here.
    """
    # Prompts, sources and references do not depend on the strategy. Build them
    # once; this also lets `examples` be a one-shot iterator.
    prompts, sources, references = _build_inputs(examples, source_field, target_field, prompt_template)

    for strategy in strategies:
        print(f"\n[Strategy: {strategy}]")

        translations, log_probs, perplexities = _translate_all(prompts, strategy)

        # Sentence-level BLEU and METEOR, one score per translation.
        bleu_scores = compute_bleu(translations, references)
        meteor_scores = compute_meteor(translations, references)

        comet_refs, comet_kiwi = _score_comet(sources, translations, references)

        for i in range(len(translations)):
            results_list.append({
                "source": sources[i],
                "reference": references[i],
                "strategy": strategy,
                "translation": translations[i],
                "total_log_probs": log_probs[i],
                "perplexity": perplexities[i],
                "bleu": bleu_scores[i],
                "meteor": meteor_scores[i],
                "comet_ref": comet_refs[i],
                "comet_wmt22": comet_kiwi[i],
                "direction": direction,
            })

In [None]:
# Destination results
# Every per-direction CSV below is written into this folder (relative to the
# current working directory); an existing folder is reused.
output_dir = "csv_results_3b"
os.makedirs(name=output_dir, exist_ok=True)

In [None]:
# zho to eng
# Translate the Chinese -> English dev split and collect one result record per
# sentence into results_to_eng.
results_to_eng = []
source_field, target_field = "sentence_zho_Hans", "sentence_eng_Latn"
prompt_zh2en = """Task: Translate the following Chinese text to English.

Chinese text: {source}

English translation:""".strip()

# The full split, in its original order. For a subset, use
# lang_pairs["zho-eng"].select(range(N)).
zho_eng_examples = lang_pairs["zho-eng"]

get_results_batched(
    examples=zho_eng_examples,
    source_field=source_field,
    target_field=target_field,
    prompt_template=prompt_zh2en,
    direction="zho-eng",
    results_list=results_to_eng,
)


[Strategy: greedy]


Translating (greedy): 100%|██████████| 997/997 [21:27<00:00,  1.29s/it]
You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 4060 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Predicting DataLoader 0: 100%|██████████| 63/63 [05:46<00:00,  5.50s/it]
You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: Fa

In [None]:
# Store in CSV
# Persist the zho->eng records and print a short preview.
zho_eng_csv_path = os.path.join(output_dir, "zho-eng_bloomz-3b_flores200_results.csv")
results_to_eng_df = pd.DataFrame(results_to_eng)
results_to_eng_df.to_csv(zho_eng_csv_path, index=False)
print(results_to_eng_df.head())

                                              source  \
0  周一，斯坦福大学医学院的科学家宣布，他们发明了一种可以将细胞按类型分类的新型诊断工具：一种可...   
1  主要研究人员表示，这可以让低收入国家/地区的患者尽早发现癌症、肺结核、艾滋病和疟疾。在这些国...   
2  当地时间上午 9:30 左右 (UTC 0230)，JAS 39C 鹰狮战斗机撞上跑道并发生...   
3            涉事飞行员是空军中队长迪罗里·帕塔维 (Dilokrit Pattavee)。   
4                           当地媒体报道，一辆机场消防车在响应火警时翻了车。   

                                           reference strategy  \
0  On Monday, scientists from the Stanford Univer...   greedy   
1  Lead researchers say this may bring early dete...   greedy   
2  The JAS 39C Gripen crashed onto a runway at ar...   greedy   
3  The pilot was identified as Squadron Leader Di...   greedy   
4  Local media reports an airport fire vehicle ro...   greedy   

                                         translation  total_log_probs  \
0  Monday, Stanford University's medical scientis...       -22.328125   
1  The main researchers said that this can help l...       -23.593750   
2  Local time today is 9:30 AM (0230 UTC). JAS 39... 

In [None]:
# fra to eng
# Translate the French -> English dev split; records go into results_to_eng.
results_to_eng = []
source_field = "sentence_fra_Latn"
target_field = "sentence_eng_Latn"
prompt_fr2en = """Task: Translate the following French text to English.

French text: {source}

English translation:""".strip()

# Use this direction's own split (sized by itself, not by another pair).
# For a subset, use lang_pairs["fra-eng"].select(range(N)).
get_results_batched(
    examples=lang_pairs["fra-eng"],
    source_field=source_field,
    target_field=target_field,
    prompt_template=prompt_fr2en,
    direction="fra-eng",
    results_list=results_to_eng
)


[Strategy: greedy]


Translating (greedy): 100%|██████████| 997/997 [22:34<00:00,  1.36s/it]
You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Predicting DataLoader 0: 100%|██████████| 63/63 [05:57<00:00,  5.68s/it]
You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Predicting DataLoader 0: 100%|██████████| 63/63 [03:57<00:00,  3.77s/it]


In [None]:
# Store in CSV
# Persist the fra->eng records and print a short preview.
fra_eng_csv_path = os.path.join(output_dir, "fra-eng_bloomz-3b_flores200_results.csv")
results_to_eng_df = pd.DataFrame(results_to_eng)
results_to_eng_df.to_csv(fra_eng_csv_path, index=False)
print(results_to_eng_df.head())

                                              source  \
0  Des scientifiques de l’école de médecine de l’...   
1  Selon les chercheurs principaux, cela pourrait...   
2  Le JAS 39C Gripen s’est écrasé sur une piste a...   
3  Le pilote a été identifié comme étant le chef ...   
4  La presse locale a rapporté qu'un véhicule de ...   

                                           reference strategy  \
0  On Monday, scientists from the Stanford Univer...   greedy   
1  Lead researchers say this may bring early dete...   greedy   
2  The JAS 39C Gripen crashed onto a runway at ar...   greedy   
3  The pilot was identified as Squadron Leader Di...   greedy   
4  Local media reports an airport fire vehicle ro...   greedy   

                                         translation  total_log_probs  \
0  French text: Scientists from Stanford Universi...       -22.546875   
1  Researchers say this could help detect early c...       -13.906250   
2  The JAS 39C Gripen crashed onto a runway aroun... 

In [None]:
# nld to eng
# Translate the Dutch -> English dev split; records go into results_to_eng.
results_to_eng = []
source_field = "sentence_nld_Latn"
target_field = "sentence_eng_Latn"
prompt_nl2en = """Task: Translate the following Dutch text to English.

Dutch text: {source}

English translation:""".strip()

# Use this direction's own split (sized by itself, not by another pair).
# For a subset, use lang_pairs["nld-eng"].select(range(N)).
get_results_batched(
    examples=lang_pairs["nld-eng"],
    source_field=source_field,
    target_field=target_field,
    prompt_template=prompt_nl2en,
    direction="nld-eng",
    results_list=results_to_eng
)


[Strategy: greedy]


Translating (greedy): 100%|██████████| 997/997 [23:47<00:00,  1.43s/it]
You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 4060 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Predicting DataLoader 0: 100%|██████████| 63/63 [05:57<00:00,  5.68s/it]
You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: Fa

In [None]:
# Store in CSV
# Persist the nld->eng records and print a short preview.
nld_eng_csv_path = os.path.join(output_dir, "nld-eng_bloomz-3b_flores200_results.csv")
results_to_eng_df = pd.DataFrame(results_to_eng)
results_to_eng_df.to_csv(nld_eng_csv_path, index=False)
print(results_to_eng_df.head())

                                              source  \
0  Op maandag kondigden wetenschappers van de Sta...   
1  Hoofdonderzoekers zeggen dat dit kan leiden to...   
2  De JAS 39C Gripen stortte rond 09.30 uur lokal...   
3  De piloot werd geïdentificeerd als majoor Dilo...   
4  De lokale media meldt dat er tijdens een actie...   

                                           reference strategy  \
0  On Monday, scientists from the Stanford Univer...   greedy   
1  Lead researchers say this may bring early dete...   greedy   
2  The JAS 39C Gripen crashed onto a runway at ar...   greedy   
3  The pilot was identified as Squadron Leader Di...   greedy   
4  Local media reports an airport fire vehicle ro...   greedy   

                                         translation  total_log_probs  \
0  On Monday, scientists from the Stanford Univer...       -33.750000   
1  Dutch text: Researchers say that this is the f...       -36.718750   
2  Dutch text: The JAS 39C Gripen crashed around ... 

In [None]:
# khk to eng
# Translate the Mongolian (khk, Cyrillic) -> English dev split; records go into
# results_to_eng.
results_to_eng = []
source_field = "sentence_khk_Cyrl"
target_field = "sentence_eng_Latn"
prompt_kh2en = """Task: Translate the following Mongolian text to English.

Mongolian text: {source}

English translation:""".strip()

# Use this direction's own split (sized by itself, not by another pair).
# For a subset, use lang_pairs["khk-eng"].select(range(N)).
get_results_batched(
    examples=lang_pairs["khk-eng"],
    source_field=source_field,
    target_field=target_field,
    prompt_template=prompt_kh2en,
    direction="khk-eng",
    results_list=results_to_eng
)


[Strategy: greedy]


Translating (greedy): 100%|██████████| 997/997 [25:09<00:00,  1.51s/it]
You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Predicting DataLoader 0: 100%|██████████| 63/63 [08:03<00:00,  7.67s/it]
You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Predicting DataLoader 0: 100%|██████████| 63/63 [05:53<00:00,  5.62s/it]


In [None]:
# Store in CSV
# Persist the khk->eng records and print a short preview.
khk_eng_csv_path = os.path.join(output_dir, "khk-eng_bloomz-3b_flores200_results.csv")
results_to_eng_df = pd.DataFrame(results_to_eng)
results_to_eng_df.to_csv(khk_eng_csv_path, index=False)
print(results_to_eng_df.head())

                                              source  \
0  Даваа гарагт Стэнфордын Их Сургуулийн Анагаахы...   
1  Гол судлаачдын зүгээс энэ нь хөхний хорт хавда...   
2  ЖАС 39Си Грипен нь орон нутгийн цагаар өглөөни...   
3  Нисгэгч нь Эскадрилийн аххлагч Дилокрит Паттав...   
4  Нисэх онгоцны буудлын галын машин өнхөрсөн тал...   

                                           reference strategy  \
0  On Monday, scientists from the Stanford Univer...   greedy   
1  Lead researchers say this may bring early dete...   greedy   
2  The JAS 39C Gripen crashed onto a runway at ar...   greedy   
3  The pilot was identified as Squadron Leader Di...   greedy   
4  Local media reports an airport fire vehicle ro...   greedy   

                                         translation  total_log_probs  \
0  Mongolian text: Davaa Gari Gari - The Mongolia...        -36.56250   
1  Mongolian text: The Mongolian word for "good l...        -32.18750   
2  Mongolian text: Mongolian text: 39S.G. Gripen ... 