In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
import os
os.chdir('/content/drive/MyDrive/IISc/DL')

In [3]:
from google.colab import userdata

# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env
# vars as appropriate for your system.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

In [4]:
os.environ["KERAS_BACKEND"] = "jax"  # Or "tensorflow" or "torch".
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00"

In [5]:
!pip install -q -U keras-hub
!pip install -q -U keras

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/876.5 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━━━━━━[0m [32m471.0/876.5 kB[0m [31m13.9 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m876.5/876.5 kB[0m [31m17.5 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
keras-nlp 0.18.1 requires keras-hub==0.18.1, but you have keras-hub 0.21.1 which is incompatible.[0m[31m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.4/1.4 MB[0m [31m26.3 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
keras-nlp 0.18.1 requires keras

In [6]:
import keras
import keras_hub
import requests
from typing import List, Dict, Tuple
import time
import json
import glob
import tqdm
import pandas as pd

In [7]:
gemma_lm = keras_hub.models.Gemma3CausalLM.from_preset("gemma3_instruct_4b_text")

Downloading from https://www.kaggle.com/api/v1/models/keras/gemma3/keras/gemma3_instruct_4b_text/3/download/config.json...


100%|██████████| 968/968 [00:00<00:00, 1.87MB/s]


Downloading from https://www.kaggle.com/api/v1/models/keras/gemma3/keras/gemma3_instruct_4b_text/3/download/task.json...


100%|██████████| 3.23k/3.23k [00:00<00:00, 5.39MB/s]


Downloading from https://www.kaggle.com/api/v1/models/keras/gemma3/keras/gemma3_instruct_4b_text/3/download/assets/tokenizer/vocabulary.spm...


100%|██████████| 4.47M/4.47M [00:00<00:00, 9.93MB/s]


Downloading from https://www.kaggle.com/api/v1/models/keras/gemma3/keras/gemma3_instruct_4b_text/3/download/model.weights.h5...


100%|██████████| 7.23G/7.23G [02:18<00:00, 56.1MB/s]


In [8]:
def load_wikidata_cache(cache_path):
    if os.path.exists(cache_path):
        try:
            df = pd.read_csv(cache_path)
            cache = {(row['entity'], row['target_lang']): row['translation'] for _, row in df.iterrows()}
            return cache
        except Exception:
            return {}
    return {}

In [9]:
def save_wikidata_cache(cache, cache_path):
    # cache: dict[(entity, target_lang)] -> translation
    rows = [
        {'entity': k[0], 'target_lang': k[1], 'translation': v}
        for k, v in cache.items()
    ]
    df = pd.DataFrame(rows)
    df.to_csv(cache_path, index=False)

In [10]:
def query_wikidata(entity: str, target_lang: str = "fr", cache=None) -> str:
    if cache is not None and (entity, target_lang) in cache:
        return cache[(entity, target_lang)]
    url = "https://www.wikidata.org/w/api.php"
    params = {
        "action": "wbsearchentities",
        "search": entity,
        "language": "en",
        "format": "json"
    }

    response = requests.get(url, params=params)
    results = response.json().get("search", [])

    if not results:
        translation = entity
    else:
        entity_id = results[0]["id"]
        label_url = f"https://www.wikidata.org/wiki/Special:EntityData/{entity_id}.json"
        label_resp = requests.get(label_url).json()
        try:
            labels = label_resp["entities"][entity_id]["labels"]
            translation = labels[target_lang]["value"] if target_lang in labels else entity
        except Exception:
            translation = entity
    if cache is not None:
        cache[(entity, target_lang)] = translation
    return translation

In [11]:
def extract_entities(text: str) -> List[Dict[str, str]]:
    # Prompt for entity extraction
    prompt_template = '''<start_of_turn>user
    Extract all named entities from the following text.
    For each entity, output a JSON object with keys: text, type (PER, LOC, ORG, MISC), and score (confidence 0-1).
    Output a JSON array.
    Do not include ```json or ``` in the output.
    Text: {text} <end_of_turn>
    <start_of_turn>model
    '''
    response = gemma_lm.generate(prompt_template.format(text=text), max_length=500)

    start_tag = "<start_of_turn>model"
    end_tag = "<end_of_turn>"

    start_index = response.find(start_tag)
    end_index = response.find(end_tag, start_index + len(start_tag))

    if start_index != -1 and end_index != -1:
        extracted_text = response[start_index + len(start_tag):end_index]
    else:
        extracted_text = response.strip()
    extracted = extracted_text.strip()

    try:
        entities = json.loads(extracted)
    except Exception:
        entities = []

    return entities

In [12]:
def enrich_entities(entities: List[Dict[str, str]], target_lang: str, cache=None) -> List[Tuple[str, str]]:
    enriched = []
    for ent in entities:
        translated = query_wikidata(ent["text"], target_lang, cache)
        enriched.append((ent["text"], translated))
    return enriched

In [13]:
def create_translation_prompt(text: str, enriched_entities: List[Tuple[str, str]], target_lang: str) -> str:
    entity_list = "\n".join([f"{orig} → {trans}" for orig, trans in enriched_entities if orig != trans])

    prompt_template = '''<start_of_turn>user
    Translate the following sentence to {target_lang}.
    Use the following known entity translations:
    {entity_list}

    Text: {text}
    Only output the translated text.
    Do not include any additional text or explanations.<end_of_turn>
    <start_of_turn>model'''

    prompt = prompt_template.format(text=text, target_lang=target_lang, entity_list=entity_list)
    return prompt

In [14]:
def translate_with_gemma(prompt: str) -> str:
    response = gemma_lm.generate(prompt,  max_length=500)

    start_tag = "<start_of_turn>model"
    end_tag = "<end_of_turn>"

    start_index = response.find(start_tag)
    end_index = response.find(end_tag, start_index + len(start_tag))

    if start_index != -1 and end_index != -1:
        extracted_text = response[start_index + len(start_tag):end_index]
    else:
        extracted_text = response.strip()
    extracted = extracted_text.strip()
    return extracted

In [15]:
def get_language_name(short_code):
    lang_map = {
        'ar': 'Arabic',
        'zh': 'Chinese (Traditional)',
        'fr': 'French',
        'de': 'German',
        'it': 'Italian',
        'ja': 'Japanese',
        'ko': 'Korean',
        'es': 'Spanish',
        'th': 'Thai',
        'tr': 'Turkish',
        'en': 'English',
    }
    return lang_map.get(short_code, short_code)

In [16]:
input_data_folder = "./data/references/validation/"
jsonl_files = glob.glob(f"{input_data_folder}/*.jsonl")
model_name = "gemma3_instruct_4b_text_NER_wikidata"
output_prediction_dir = os.path.join("data/predictions", model_name, "validation")
os.makedirs(output_prediction_dir, exist_ok=True)

wikidata_cache_path = os.path.join("./data", "wikidata_cache.csv")
wikidata_cache = load_wikidata_cache(wikidata_cache_path)

In [17]:
import datetime

log_path = os.path.join(output_prediction_dir, "run.log")
logf = open(log_path, "a", encoding="utf-8")  # Changed to append mode

def log(message: str):
    timestamp = datetime.datetime.now().isoformat()
    logf.write(f"[{timestamp}] {message}\n")


In [18]:
for file_path in jsonl_files:
    filename = os.path.basename(file_path)
    outfile_path = os.path.join(output_prediction_dir, filename)
    data = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            data.append(json.loads(line))
    pbar = tqdm.tqdm(total=len(data))
    results = []
    for idx, record in enumerate(data, 1):
        id = record['id']
        source = record['source']
        source_locale = record['source_locale']
        source_language = get_language_name(source_locale)
        target_locale = record['target_locale']
        target_language = get_language_name(target_locale)

        log(f"\nProcessing ID: {id} | Text : {source} | Source : {source_language} | Target: {target_language}\n")
        # --- Entity-aware translation ---
        entities = extract_entities(source)

        log("\n🔎 Named Entities:\n")
        for e in entities:
            log(f"- {e['text']} ({e['type']})\n")

        enriched = enrich_entities(entities, target_locale[:2], wikidata_cache)  # pass cache

        logf.write("\n🌐 Wikidata Enriched Entities:\n")
        for orig, trans in enriched:
            log(f"- {orig} → {trans}\n")

        prompt = create_translation_prompt(source, enriched, target_language)

        log(f"\n📝 Prompt Sent to Gemma:\n{prompt}\n")

        model_translation = translate_with_gemma(prompt).strip()



        log(f"\n🗣️ Final Translated Output:\n{model_translation}\n")

        results.append({
            "id": id,
            "source_language": source_language,
            "target_language": target_language,
            "text": source,
            "prediction": model_translation,
        })
        pbar.update(1)
        if idx % 10 == 0 or idx == len(data):
            with open(outfile_path, 'w', encoding='utf-8') as f:
                for res in results:
                    f.write(json.dumps(res, ensure_ascii=False) + '\n')
            # Periodically persist cache
            save_wikidata_cache(wikidata_cache, wikidata_cache_path)
    log(f"Translations saved to {outfile_path}\n")
    pbar.close()
save_wikidata_cache(wikidata_cache, wikidata_cache_path)
logf.flush()
logf.close()

100%|██████████| 732/732 [1:00:27<00:00,  4.96s/it]
100%|██████████| 722/722 [59:48<00:00,  4.97s/it]
100%|██████████| 731/731 [59:21<00:00,  4.87s/it]
100%|██████████| 739/739 [1:01:24<00:00,  4.99s/it]


In [1]:
from framework import download_comet_model
comet_model = download_comet_model()

  from .autonotebook import tqdm as notebook_tqdm
  from pkg_resources import DistributionNotFound, get_distribution
2025-06-18 13:55:52.763388: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-06-18 13:55:52.776313: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1750254952.792652    6024 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1750254952.797736    6024 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1750254952.811

In [2]:
import os
import glob
import json
from framework import calculate_comet_scores, calculate_meta_score

model_name = "gemma3_instruct_4b_text"
output_prediction_dir = os.path.join("data/predictions", model_name, "validation")
os.makedirs(output_prediction_dir, exist_ok=True)

input_data_folder = "data/references/validation"
jsonl_files = glob.glob(f"{input_data_folder}/*.jsonl")

def calculate_scores(template_id):
    scores_dir = os.path.join(output_prediction_dir, template_id, "scores")

    if not os.path.exists(scores_dir):
        os.makedirs(scores_dir, exist_ok=True)

    for file_path in jsonl_files:
        references_path = file_path
        filename = os.path.basename(file_path)
        predictions_path = os.path.join(output_prediction_dir, template_id, filename)

        comet_score = calculate_comet_scores(
            comet_model, 
            references_path, 
            predictions_path
        )

        correct_instances, total_instances, meta_score = calculate_meta_score(
            references_path,
            predictions_path)

        evaluation_results = {
            "correct_instances": correct_instances,
            "total_instances": total_instances,
            "comet_score": comet_score,
            "meta_score": meta_score
        }

        evaluation_output_path = os.path.join(scores_dir, f"{os.path.splitext(filename)[0]}.json")
        with open(evaluation_output_path, 'w', encoding='utf-8') as json_file:
            json.dump(evaluation_results, json_file, ensure_ascii=False, indent=4)

In [3]:
calculate_scores("rag-wikidata")

Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.


All references have a corresponding prediction
Created 1177 instances


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA RTX A2000 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Predicting DataLoader 0: 100%|██████████| 37/37 [00:15<00:00,  2.35it/s]
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.


Average COMET score: 91.33
Loaded 722 instances.
Loaded 722 predictions.
All references have a corresponding prediction
Created 1260 instances


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Predicting DataLoader 0: 100%|██████████| 40/40 [00:17<00:00,  2.23it/s]
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.


Average COMET score: 90.97
Loaded 731 instances.
Loaded 731 predictions.
All references have a corresponding prediction
Created 1229 instances


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Predicting DataLoader 0: 100%|██████████| 39/39 [00:17<00:00,  2.20it/s]
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.


Average COMET score: 92.97
Loaded 739 instances.
Loaded 739 predictions.
All references have a corresponding prediction
Created 1316 instances


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Predicting DataLoader 0: 100%|██████████| 42/42 [00:20<00:00,  2.06it/s]
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.


Average COMET score: 90.00
Loaded 724 instances.
Loaded 724 predictions.
All references have a corresponding prediction
Created 1268 instances


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Predicting DataLoader 0: 100%|██████████| 40/40 [00:18<00:00,  2.15it/s]
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.


Average COMET score: 92.23
Loaded 730 instances.
Loaded 730 predictions.
All references have a corresponding prediction
Created 1409 instances


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Predicting DataLoader 0: 100%|██████████| 45/45 [00:22<00:00,  1.99it/s]
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.


Average COMET score: 92.88
Loaded 723 instances.
Loaded 723 predictions.
All references have a corresponding prediction
Created 1660 instances


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Predicting DataLoader 0: 100%|██████████| 52/52 [00:23<00:00,  2.20it/s]
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.


Average COMET score: 91.98
Loaded 745 instances.
Loaded 745 predictions.
All references have a corresponding prediction
Created 1654 instances


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Predicting DataLoader 0: 100%|██████████| 52/52 [00:38<00:00,  1.37it/s]
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.


Average COMET score: 87.19
Loaded 710 instances.
Loaded 710 predictions.
All references have a corresponding prediction
Created 1260 instances


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Predicting DataLoader 0: 100%|██████████| 40/40 [00:16<00:00,  2.38it/s]
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.


Average COMET score: 91.41
Loaded 732 instances.
Loaded 732 predictions.
All references have a corresponding prediction
Created 1544 instances


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Predicting DataLoader 0: 100%|██████████| 49/49 [00:21<00:00,  2.33it/s]


Average COMET score: 90.22
Loaded 722 instances.
Loaded 722 predictions.
