### 1. Import necessary libraries

In [None]:
import json
import glob
import demjson3
from langchain_ollama import OllamaLLM
from comet.models import download_model, load_from_checkpoint
import os
import re
import ast
from ftfy import fix_text
from langchain.prompts import PromptTemplate

from entity_extraction import (
    extract_capitalized_phrases,
    extract_after_prepositions,
    extract_quoted_entities,
    extract_hyphenated_entities,
    extract_entities_with_numbers_or_roman,
    validate_entities
)
from framework import extract_entity_translation, calculate_comet_scores, calculate_meta_score

### 2. Create prompt to extract named entities

In [None]:
# NER prompt: asks the model to return a JSON array of objects of the form
# {"Source": <sentence>, "Entities": [<verbatim substrings of the sentence>]}.
# extract_named_entities_from_model fills `texts` with a single source sentence;
# the main loop accepts either a single object or a list of objects back.
entity_extraction_prompt = PromptTemplate(
    input_variables=["texts"],
    template='''You are a named entity recognition (NER) expert.
 
For each of the following English sentences, extract all named entities (e.g., people, places, organizations, TV series, movies, books).
 
Instructions:
- ONLY extract named entities that appear EXACTLY and VERBATIM in the sentence.
- DO NOT return alternate names, inferred references, or canonical forms.
- DO NOT perform translation, rewriting, or guessing.
- DO NOT infer likely entities or use context to deduce names.
- An entity is valid ONLY if it is an exact substring match found in the sentence.
- If an entity is not present word-for-word in the sentence, DO NOT include it.
- DO NOT return partial entities or reformatted names.
 
Output format:
- Return a single JSON array.
- Each item must be an object with these two fields:
  - "Source": the original sentence.
  - "Entities": a list of string values. Each string must be copied directly from the sentence. Each must be an exact substring of the Source.
 
Rules:
- DO NOT include any reasoning or commentary.
- DO NOT include any Markdown formatting like triple backticks.
- The output MUST be valid JSON that can be parsed by Python's json.loads().
 
Texts:
{texts}
'''
)

### 3. Create prompt to verify and refine locally extracted candidate entities from a sentence

In [None]:
# Verification prompt for one rule-based candidate: the model should return
# {"sentence": ..., "entities": [<most complete entity containing the candidate>]}
# or an empty "entities" list when the candidate is not part of any entity.
# Used by refine_locally_extracted_entities for candidates the model-based
# extraction did not already produce.
entity_rethinking_prompt = PromptTemplate(
    input_variables=["sentence", "candidate"],
    template="""
You are an expert in Named Entity Recognition (NER). Named entities can be people, places, organizations, TV series, movies, books, etc.
 
Task:
Given a sentence and a candidate phrase, your goal is to identify the **most complete named entity** from the sentence that includes the candidate. This helps verify whether the candidate is a full entity, a partial one, or invalid.
 
Guidelines:
- If the candidate is a **subset of a longer named entity**, return the **full entity** from the sentence as-is.
- If the candidate **fully matches** a named entity in the sentence, return it.
- If the candidate is **not part of any valid named entity** in the sentence, return an empty list.
- Always return the named entity **verbatim**, exactly as it appears in the sentence (including casing, punctuation, etc.).
- Do not add inferred terms or modify the sentence.
- Named entities should not be a category or type.
 
Response format:
Return a valid JSON object with the following keys:
- "sentence": the original input sentence.
- "entities": a list with either the corrected entity string or an empty list if not found.
 
Constraints:
- Output only the JSON. No markdown, no code blocks, no commentary.
 
Input:
Sentence: "{sentence}"
Candidate: "{candidate}"
 
Output:
"""
)

### 4. Create prompt to translate a sentence into the target language

In [None]:
# First-attempt translation prompt (one-shot). The model must answer with a
# JSON object {"translation": str, "entities": [str, ...]} in which every
# provided translated entity appears verbatim in the translation.
# The "Format your response strictly as" skeleton is itself valid JSON (it has a
# comma between the two fields); a skeleton without that comma taught the model
# to emit unparseable output, contradicting rule 8.
translation_prompt = PromptTemplate(
    input_variables=["sentence", "language", "entities", "exampleSentence", "exampleEntities", "exampleTranslation", "exampleTranslatedEntities"], # type: ignore
    template="""
You are a professional translator with expertise in high-fidelity, fluent translations that preserve named entities.
 
Translate the following English sentence into {language}. The translation MUST meet the following constraints:
 
1. The meaning is preserved **accurately** and the sentence reads naturally to native speakers.
2. All named entities from the list {entities} MUST appear in the translated sentence.
3. Use **natural phrasing and correct grammar** in {language}.
4. Avoid literal word-for-word translation and aim for native-like fluency.
5. Do NOT hallucinate or modify entity names. Only translate using the list provided.
6. Do not include any code fragments such as ``` or ```json in the output.
7. Do not include any additional text or explanations.
8. Ensure all JSON fields are correctly separated by commas. Do not omit commas between items or key-value pairs.
9. Ensure consistent determiners (e.g., "la", "l’") and capitalization for entities across translations.
10. Use provided translated entity names exactly; if multiple entities exist, treat each one distinctly.

Example:
sentence: {exampleSentence}
entities: {exampleEntities}
language: {language}

Expected output:
{{
  "translation": {exampleTranslation},
  "entities": {exampleTranslatedEntities}
}}

Format your response strictly as:

{{
  "translation": "<natural and accurate translated sentence>",
  "entities": ["<translated_entity1>", "<translated_entity2>", ...]
}}

Sentence: "{sentence}"
Entities: {entities}
"""
)

### 5. Create prompt to retry translation

In [None]:
# Retry prompt used after an attempt whose translation did not contain every
# required entity. Same inputs and output contract as translation_prompt:
# a JSON object {"translation": str, "entities": [str, ...]}.
# The response skeleton below is valid JSON (comma between the two fields) so
# the model is not shown the exact formatting mistake rule 7 forbids.
translation_retry_prompt = PromptTemplate(
    input_variables=["sentence", "language", "entities", "exampleSentence", "exampleEntities", "exampleTranslation", "exampleTranslatedEntities"], # type: ignore
    template="""
You are a professional translator. The original translation did not accurately preserve named entities.

Retry translating the English sentence below into {language}, ensuring all named entities in {entities} are:
1. Correctly translated into {language} (not hallucinated or omitted).
2. Placed naturally in the sentence with fluent grammar.
3. Return the response only for {sentence}. Do not include translations for any other sentence.
4. The output must be a valid JSON string that can be parsed by the Python json.loads() function.
5. Do not include any code fragments such as ``` or ```json in the output.
6. Do not include any additional text or explanations.
7. Ensure all JSON fields are correctly separated by commas. Do not omit commas between items or key-value pairs.

Example:
sentence: {exampleSentence}
entities: {exampleEntities}
language: {language}

Expected output:
{{
  "translation": {exampleTranslation},
  "entities": {exampleTranslatedEntities}
}}

Format your response strictly as:

{{
  "translation": "<natural and accurate translated sentence>",
  "entities": ["<translated_entity1>", "<translated_entity2>", ...]
}}

Sentence: "{sentence}"
Entities: {entities}
"""
)

### 6. One-shot examples for each target language

In [None]:
# One-shot demonstration per target locale, looked up with
# record['target_locale'] when building the translation prompts.
# Each entry holds an English sentence, the entities it contains, the reference
# translation, and the translated forms of those entities (same order).
examples = {
    "ar": {
        "exampleSentence": "Where is the Burj Khalifa located?",
        "exampleEntities": ["Burj Khalifa"],
        "exampleTranslation": "أين يقع برج خليفة؟",
        "exampleTranslatedEntities": ["برج خليفة"]
    },
    "zh": {
        "exampleSentence": "When was the Great Wall of China built?",
        "exampleEntities": ["Great Wall of China"],
        "exampleTranslation": "中國長城是什麼時候建造的？",
        "exampleTranslatedEntities": ["中國長城"]
    },
    "fr": {
        "exampleSentence": "Who painted the Mona Lisa?",
        "exampleEntities": ["Mona Lisa"],
        "exampleTranslation": "Qui a peint la Joconde ?",
        "exampleTranslatedEntities": ["la Joconde"]
    },
    "de": {
        "exampleSentence": "Which river flows through Berlin?",
        "exampleEntities": ["Berlin"],
        "exampleTranslation": "Welcher Fluss fließt durch Berlin?",
        "exampleTranslatedEntities": ["Berlin"]
    },
    "it": {
        "exampleSentence": "Where is the Colosseum located?",
        "exampleEntities": ["Colosseum"],
        "exampleTranslation": "Dove si trova il Colosseo?",
        "exampleTranslatedEntities": ["Colosseo"]
    },
    "ja": {
        "exampleSentence": "Which city is Mount Fuji near?",
        "exampleEntities": ["Mount Fuji"],
        "exampleTranslation": "富士山はどの都市の近くにありますか？",
        "exampleTranslatedEntities": ["富士山"]
    },
    "ko": {
        "exampleSentence": "Who is the lead actor in Squid Game?",
        "exampleEntities": ["Squid Game"],
        "exampleTranslation": "오징어 게임의 주연 배우는 누구입니까?",
        "exampleTranslatedEntities": ["오징어 게임"]
    },
    "es": {
        "exampleSentence": "Where was Pablo Picasso born?",
        "exampleEntities": ["Pablo Picasso"],
        "exampleTranslation": "¿Dónde nació Pablo Picasso?",
        "exampleTranslatedEntities": ["Pablo Picasso"]
    },
    "th": {
        "exampleSentence": "Where can you see the Grand Palace in Thailand?",
        "exampleEntities": ["Grand Palace", "Thailand"],
        "exampleTranslation": "พระบรมมหาราชวังตั้งอยู่ที่ไหนในประเทศไทย?",
        "exampleTranslatedEntities": ["พระบรมมหาราชวัง", "ประเทศไทย"]
    },
    "tr": {
        "exampleSentence": "In which city is the Hagia Sophia located?",
        "exampleEntities": ["Hagia Sophia"],
        "exampleTranslation": "Ayasofya hangi şehirde bulunur?",
        "exampleTranslatedEntities": ["Ayasofya"]
    }
}

### 7. Load reference files and map language codes to file names

In [None]:
language_filepaths = {}  # language code -> reference file name without extension


def load_all_jsonl_files_by_language(folder_path):
    """Load every ``*.jsonl`` file in ``folder_path``, grouped by language code.

    The language code is the part of the file name before the first "_"
    (e.g. ``fr_FR.jsonl`` -> ``"fr"``). As a side effect, the module-level
    ``language_filepaths`` dict maps each code to its file name without the
    extension; the translation loop reuses that name for prediction files.

    Args:
        folder_path: Directory containing the reference JSONL files.

    Returns:
        Dict mapping language code to the list of decoded JSON records.
        Blank lines in the files are skipped.
    """
    lang_data = {}

    # glob order is filesystem-dependent; sorting makes record order and the
    # language_filepaths mapping reproducible between runs.
    for file_path in sorted(glob.glob(f"{folder_path}/*.jsonl")):
        file_name = os.path.basename(file_path)
        lang_code = file_name.split("_")[0]
        stem = os.path.splitext(file_name)[0]

        if lang_code in lang_data:
            # Records are merged under one code, but only one file name can
            # be remembered for naming the prediction output.
            print(f"Warning: several files map to language '{lang_code}'; "
                  f"using '{stem}' as its file name")
        language_filepaths[lang_code] = stem

        records = lang_data.setdefault(lang_code, [])
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                # json.loads rejects empty strings, so tolerate blank lines
                # (e.g. a trailing empty line at the end of the file).
                if line.strip():
                    records.append(json.loads(line))

    return lang_data


def get_language_name(short_code):
    """Return the English language name for ``short_code``, or the code itself if unknown."""
    lang_map = {
        'ar': 'Arabic', 'zh': 'Chinese (Traditional)', 'fr': 'French', 'de': 'German',
        'it': 'Italian', 'ja': 'Japanese', 'ko': 'Korean', 'es': 'Spanish',
        'th': 'Thai', 'tr': 'Turkish', 'en': 'English'
    }
    return lang_map.get(short_code, short_code)

### 8. Load data from the JSONL reference files

In [None]:
# Reference (validation) data: one JSONL file per target language.
jsonl_folder = "data/references/validation"

# language code -> list of reference records
all_lang_data = load_all_jsonl_files_by_language(jsonl_folder)

### 9. Verify the loaded files

In [None]:
# Sanity check: how many reference records were loaded per language.
for lang_code in all_lang_data:
    records = all_lang_data[lang_code]
    print(f"Loaded {len(records)} records for {get_language_name(lang_code)} ({lang_code})")

### 10. Repair missing commas in model JSON output before parsing

In [None]:
def fix_missing_commas(raw_text):
    """Insert commas the model forgot between consecutive JSON values.

    The text is scanned while tracking whether the cursor is inside a string
    literal (honouring backslash escapes). Outside strings, when a new string
    starts right after a completed value (a closing quote, ``]``, ``}``, a
    number, or ``true``/``false``/``null``) with only whitespace in between,
    a comma is inserted directly before the new string's opening quote.

    String contents are never modified, so empty (``""``) or whitespace-only
    strings stay intact. A simple regex over quote pairs could not guarantee that.

    Args:
        raw_text: Raw model output that is expected to be (almost) JSON.

    Returns:
        The text with missing separators inserted; valid JSON is returned unchanged.
    """
    out = []
    in_string = False
    escaped = False
    # True when the last non-whitespace character outside a string ended a value.
    after_value = False

    for ch in raw_text:
        if in_string:
            out.append(ch)
            if escaped:
                escaped = False
            elif ch == '\\':
                escaped = True
            elif ch == '"':
                in_string = False
                after_value = True
            continue

        if ch == '"':
            if after_value:
                out.append(',')
            in_string = True
            after_value = False
        elif not ch.isspace():
            # Outside string literals, alphanumerics in JSON only come from
            # numbers and the true/false/null literals, i.e. value endings.
            after_value = ch in ']}' or ch.isalnum()
        out.append(ch)

    return ''.join(out)

### 11. Parse model output into JSON objects

In [None]:
def safe_model_output_parse(raw_output):
    """Best-effort conversion of a model response into a Python object.

    Dicts are returned unchanged. Bytes are decoded as UTF-8, with undecodable
    bytes replaced rather than falling back to a ``"b'...'"`` repr. Any other
    non-string is converted with ``str()``. The text is then normalised with
    ``ftfy.fix_text`` and parsed with ``json.loads``, falling back to the more
    lenient ``demjson3.decode``.

    Args:
        raw_output: dict, str, bytes/bytearray, or any object with a useful ``str()``.

    Returns:
        The parsed object, or ``None`` if the text cannot be parsed (a message
        is printed in that case). Never raises for parse failures.
    """
    if isinstance(raw_output, dict):
        return raw_output

    if isinstance(raw_output, (bytes, bytearray)):
        raw_output = raw_output.decode('utf-8', errors='replace')
    elif not isinstance(raw_output, str):
        raw_output = str(raw_output)

    try:
        # Normalise mojibake / odd unicode once and reuse it for both parsers.
        fixed = fix_text(raw_output.strip())
    except Exception as e:
        print(f"Failed to normalise model output text: {e}")
        return None

    try:
        return json.loads(fixed)
    except ValueError:
        pass

    try:
        return demjson3.decode(fixed)
    except Exception as e:
        print(f"Failed to recover batch with demjson3: {e}")
        return None

### 12. Define LangChain with prompt templates

In [None]:
# Use Ollama
# A single local "mistral" model served through Ollama backs all four chains;
# each prompt template is piped into it (`prompt | llm`) to form a runnable.
# NOTE(review): invoking a `PromptTemplate | OllamaLLM` sequence presumably
# yields a plain str rather than an LLMChain-style {"text": ...} dict —
# confirm downstream parsing handles the installed LangChain's output type.
llm = OllamaLLM(model="mistral")
chain_extract = entity_extraction_prompt | llm  # sentence -> JSON entity list
chain_rethink = entity_rethinking_prompt | llm  # (sentence, candidate) -> refined entity
chain_translate = translation_prompt | llm  # initial one-shot translation prompt
chain_retry_translate = translation_retry_prompt | llm  # retry prompt for failed entity checks

### 13. Extract named entities with the LLM

In [None]:
def extract_named_entities_from_model(source):
    """Ask the LLM for the named entities in ``source`` and parse its JSON answer.

    Args:
        source: English sentence to analyse.

    Returns:
        The parsed response — per the prompt, a list of
        ``{"Source": ..., "Entities": [...]}`` objects, though a single object
        may come back — or ``None`` if the request or parsing fails.
    """
    try:
        raw_entities = chain_extract.invoke({"texts": source})
    except Exception as e:
        # Without this, a failed request left `raw_entities` unbound in the
        # fallback handler.
        print(f"Entity extraction request failed: {e}")
        return None

    # A `prompt | llm` runnable returns the completion as a plain string;
    # LLMChain-style chains returned {"text": ...}. Accept both.
    if isinstance(raw_entities, dict):
        text = raw_entities.get('text', '')
    else:
        text = str(raw_entities)

    try:
        return json.loads(text)
    except ValueError:
        try:
            return demjson3.decode(text)
        except Exception as e2:
            print(f"Failed to recover batch with demjson3: {e2}")
            return None

### 14. Extract named entities locally

In [None]:
def extract_named_entities_locally(source):
    """Run every rule-based extractor on ``source`` and pool the candidates.

    Args:
        source: English sentence to scan.

    Returns:
        A set with the union of all candidate strings the extractors produced.
    """
    capitalized = extract_capitalized_phrases(source)
    after_prepositions = extract_after_prepositions(source)
    quoted = extract_quoted_entities(source)
    hyphenated = extract_hyphenated_entities(source)
    numbered = extract_entities_with_numbers_or_roman(source)

    pooled = capitalized + after_prepositions + quoted + hyphenated + numbered
    return set(pooled)

### 15. Refine locally extracted entities

In [None]:
def refine_locally_extracted_entities(source, local_entities, cleaned_entity_list):
    """Let the LLM verify rule-based candidates and merge them with accepted entities.

    For each candidate in ``local_entities`` not already present, ``chain_rethink``
    is asked for the most complete entity in ``source`` that contains it. When
    the model returns at least one entity string, those strings and the
    candidate itself are added to the result.

    Args:
        source: The English sentence.
        local_entities: Candidate strings from the rule-based extractors.
        cleaned_entity_list: Entities already accepted from the model-based
            extraction. The list is not modified; a new list is returned.

    Returns:
        Stripped, non-empty entity strings without duplicates, in first-seen order.
    """
    refined = list(cleaned_entity_list)

    # Sorted so the sequence of LLM calls is reproducible: iteration order of a
    # set of strings changes between interpreter runs (hash randomisation).
    for entity in sorted(local_entities):
        if entity in refined:
            continue

        try:
            correction = chain_rethink.invoke({"sentence": source, "candidate": entity})
        except Exception as e:
            print(f"Entity refinement request failed for {entity!r}: {e}")
            continue

        # `prompt | llm` runnables return a str; LLMChain-style chains return {"text": ...}.
        if isinstance(correction, dict):
            text = correction.get('text', '')
        else:
            text = str(correction)

        try:
            new_data = json.loads(text)
        except ValueError:
            try:
                new_data = demjson3.decode(text)
            except Exception as e2:
                print(f"Failed to recover batch with demjson3: {e2}")
                continue

        # The prompt asks for a JSON object; anything else (e.g. a bare list)
        # previously crashed the whole run on `.get`.
        if not isinstance(new_data, dict):
            continue

        suggested = new_data.get('entities') or []
        if not isinstance(suggested, list):
            # A bare string would otherwise be extended character by character.
            suggested = [suggested]
        suggested = [s for s in suggested if isinstance(s, str)]

        if suggested:
            refined.extend(suggested)
            refined.append(entity)

    # dict.fromkeys de-duplicates while keeping a deterministic order.
    return list(dict.fromkeys(
        x.strip() for x in refined if isinstance(x, str) and x.strip()
    ))

### 16. Remove duplicate entities and entities contained in longer ones

In [None]:
def remove_duplicate_entities(cleaned_entity_list):
    """Drop repeated entities and entities that are substrings of another entity.

    For example, ``["New York", "York"]`` keeps only ``"New York"``. Exact
    repeats are collapsed to one copy first. Without that step two identical
    entries each count as a substring of the other, and both would be dropped.

    Args:
        cleaned_entity_list: Entity strings.

    Returns:
        A new list in first-occurrence order, containing only entities not
        contained in any other distinct entity.
    """
    unique = list(dict.fromkeys(cleaned_entity_list))
    return [
        ent for ent in unique
        if not any(ent != other and ent in other for other in unique)
    ]

### 17. Translate entities via retrieval (RAG component), including best-match entity lookup

In [None]:
def translate_named_entities(final_entity_list, lang_code):
    """Look up target-language names for entities with ``extract_entity_translation``.

    Args:
        final_entity_list: English entity strings found in the source sentence.
        lang_code: Target locale, passed through to ``extract_entity_translation``.

    Returns:
        Translated entity names for the entities whose lookup produced a truthy
        ``'qid'``. The list is in input order, without duplicates or empty values.
    """
    model_entities = []

    for item in final_entity_list:
        ent = extract_entity_translation(item, lang_code)

        # NOTE(review): assumes a mapping with 'qid' / 'translated' keys, as the
        # original indexing implied; a falsy result is treated as "not found".
        if not ent or not ent.get('qid'):
            continue

        translated = ent.get('translated')
        # An empty/None translation would later break ", ".join(...) when the
        # entities are inserted into the translation prompt.
        if translated and translated not in model_entities:
            model_entities.append(translated)

    return model_entities

### 18. Perform translation

In [None]:
def _parse_translation_output(raw_output):
    """Parse one translation response into a dict with a string "translation", or return None."""
    # `prompt | llm` runnables return a str; LLMChain-style chains return {"text": ...}.
    if isinstance(raw_output, dict):
        text = raw_output.get('text', '')
    else:
        text = str(raw_output)
    text = text.strip()

    # Drop any surrounding prose or ``` fences around the JSON object.
    start, end = text.find('{'), text.rfind('}')
    if start != -1 and end > start:
        text = text[start:end + 1]

    repaired = fix_missing_commas(text)
    parsed = None
    for candidate in (text, repaired):
        try:
            parsed = json.loads(candidate)
            break
        except ValueError:
            pass
        try:
            # Python-literal style output, e.g. {'translation': '...'}. Quotes are
            # not rewritten beforehand, so apostrophes inside text stay intact.
            parsed = ast.literal_eval(candidate)
            break
        except Exception:
            pass
    else:
        # Last resort: ftfy normalisation + lenient demjson3 parsing.
        parsed = safe_model_output_parse(repaired)

    if isinstance(parsed, dict) and isinstance(parsed.get('translation'), str):
        return parsed
    return None


def translate_sentence(source, language, model_entities, record, max_attempts=5):
    """Translate ``source`` into ``language``, retrying until all entities appear.

    The first attempt uses ``chain_translate`` and later attempts use
    ``chain_retry_translate``. An attempt is accepted when every entity the
    model reports and every entity in ``model_entities`` occurs verbatim in
    the translated sentence.

    Args:
        source: English sentence to translate.
        language: Target language name inserted into the prompt.
        model_entities: Translated entity names that must appear in the output.
        record: Dataset record; its ``target_locale`` selects the one-shot example.
        max_attempts: Maximum number of model calls (default 5, as before).

    Returns:
        The parsed model output, a dict with a string ``"translation"`` and
        usually an ``"entities"`` list, from the first attempt that passes the
        entity check. If none passes, the last successfully parsed output is
        returned as a best effort, so the sentence still gets a prediction.
        ``None`` is returned only when no attempt could be parsed.

    Raises:
        KeyError: if no one-shot example exists for the record's target locale.
    """
    locale = record['target_locale']
    # Fall back to the base language code (e.g. "fr" for "fr_FR") when the
    # full locale has no example; a locale with neither still raises KeyError.
    example = examples.get(locale) or examples[locale.split('-')[0].split('_')[0]]

    # Only non-empty strings can be joined into the prompt and searched for.
    required_entities = [e for e in model_entities if isinstance(e, str) and e]

    prompt_inputs = {
        "sentence": source,
        "language": language,
        "entities": ", ".join(required_entities),
        "exampleSentence": example['exampleSentence'],
        # JSON-encode the example values so the one-shot "Expected output" block
        # is valid JSON. A raw str renders unquoted, and a Python list renders
        # with single quotes.
        "exampleEntities": json.dumps(example['exampleEntities'], ensure_ascii=False),
        "exampleTranslation": json.dumps(example['exampleTranslation'], ensure_ascii=False),
        "exampleTranslatedEntities": json.dumps(example['exampleTranslatedEntities'], ensure_ascii=False),
    }

    best_effort = None
    for attempt in range(max_attempts):
        chain = chain_translate if attempt == 0 else chain_retry_translate
        try:
            raw_output = chain.invoke(prompt_inputs)
        except Exception as e:
            print(f"Translation request failed on attempt {attempt + 1}: {e}")
            continue

        parsed = _parse_translation_output(raw_output)
        if parsed is None:
            print("Failed to parse model output, retrying...")
            continue

        best_effort = parsed
        translation = parsed['translation']

        generated_entities = parsed.get('entities') or []
        if not isinstance(generated_entities, list):
            generated_entities = [generated_entities]

        entities_found = (
            all(ent in translation for ent in generated_entities if isinstance(ent, str))
            and all(ent in translation for ent in required_entities)
        )
        if entities_found:
            return parsed

    return best_effort

### 19. Run one-shot + RAG translation with LangChain for each language

In [None]:
# For every language: extract entities (LLM + rules), translate them via
# Wikidata, translate the sentence with the entity constraint, and stream the
# predictions to one JSONL file per language.
for lang_code, records in all_lang_data.items():
    language = get_language_name(lang_code)

    output_file = f"data/predictions/mistral7b/validation/one_shot_rag_wikidata/{language_filepaths[lang_code]}.jsonl"
    # The prediction folder may not exist yet (fresh checkout / new template).
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    results = []

    # Open once per language and write each prediction as soon as it exists.
    # Progress survives a crash without rewriting the whole file after every
    # record. Opening with 'w' up front also clears predictions left over from
    # a previous run, even when every record ends up skipped.
    with open(output_file, 'w', encoding='utf-8') as f:
        for record in records:
            source = record['source']

            # Extract named entities with the LLM; skip unparseable responses.
            entity_data = extract_named_entities_from_model(source)
            if entity_data is None:
                continue

            local_entities = extract_named_entities_locally(source)

            cleaned_entity_list = []
            if isinstance(entity_data, dict):
                cleaned_entity_list.extend(validate_entities(entity_data.get('Entities', []), source))
            elif isinstance(entity_data, list):
                for item in entity_data:
                    # Ignore stray non-object items instead of crashing on .get().
                    if isinstance(item, dict):
                        cleaned_entity_list.extend(validate_entities(item.get('Entities', []), source))

            # Refine entities
            cleaned_entity_list = refine_locally_extracted_entities(source, local_entities, cleaned_entity_list)

            # Remove duplicate entries
            final_entity_list = remove_duplicate_entities(cleaned_entity_list)

            # Translate named entities using Wikidata
            model_entities = translate_named_entities(final_entity_list, record['target_locale'])

            # Translate sentence with constraint
            raw_translated = translate_sentence(source, language, model_entities, record)
            if raw_translated is None:
                continue

            result = {
                "id": record['id'],
                "text": source,
                "source_language": record['source_locale'],
                "target_language": record['target_locale'],
                "prediction": raw_translated['translation']
            }
            results.append(result)

            f.write(json.dumps(result, ensure_ascii=False) + '\n')
            f.flush()

### 20. Define folder and file structure to save M-ETA and COMET scores

In [None]:
# Where this model's predictions (and their scores) live.
model_name = "mistral7b"
output_prediction_dir = os.path.join("data/predictions", model_name)
os.makedirs(output_prediction_dir, exist_ok=True)

# Reference files to score against (one JSONL per target language).
input_data_folder = "data/references/validation"
jsonl_files = glob.glob(f"{input_data_folder}/*.jsonl")

# COMET model (Unbabel/wmt22-comet-da) used by calculate_scores.
comet_model_path = download_model("Unbabel/wmt22-comet-da")
comet_model = load_from_checkpoint(comet_model_path)

### 21. COMET and M-ETA scores calculation

In [None]:
def calculate_scores(template_id):
    """Compute COMET and M-ETA scores for each reference file and save them as JSON.

    For every reference JSONL in ``jsonl_files``, the predictions file with the
    same name is scored with ``calculate_comet_scores`` (using ``comet_model``)
    and ``calculate_meta_score``. Results are written to
    ``<output_prediction_dir>/<template_id>/scores/<name>.json``.

    Args:
        template_id: Name of the prompting setup, e.g. "one_shot_rag_wikidata".
    """
    scores_dir = os.path.join(output_prediction_dir, template_id, "scores")
    os.makedirs(scores_dir, exist_ok=True)

    # Data split name, e.g. "validation" for "data/references/validation".
    split_name = os.path.basename(os.path.normpath(input_data_folder))

    for references_path in jsonl_files:
        filename = os.path.basename(references_path)

        # The translation loop in this notebook writes predictions to
        # data/predictions/mistral7b/validation/<template_id>/ (with a split
        # sub-folder). This function originally only looked in
        # <output_prediction_dir>/<template_id>/. Try the original location
        # first (unchanged behaviour when it exists), then the split folder.
        candidate_paths = [
            os.path.join(output_prediction_dir, template_id, filename),
            os.path.join(output_prediction_dir, split_name, template_id, filename),
        ]
        predictions_path = next((p for p in candidate_paths if os.path.exists(p)), None)
        if predictions_path is None:
            # e.g. a language for which every record was skipped.
            print(f"Skipping {filename}: no predictions file found at {candidate_paths}")
            continue

        comet_score = calculate_comet_scores(
            comet_model,
            references_path,
            predictions_path
        )

        correct_instances, total_instances, meta_score = calculate_meta_score(
            references_path,
            predictions_path)

        evaluation_results = {
            "correct_instances": correct_instances,
            "total_instances": total_instances,
            "comet_score": comet_score,
            "meta_score": meta_score
        }

        new_filename = filename.replace(".jsonl", ".json")
        evaluation_output_path = os.path.join(scores_dir, new_filename)
        with open(evaluation_output_path, 'w', encoding='utf-8') as json_file:
            json.dump(evaluation_results, json_file, ensure_ascii=False, indent=4)

### 22. Calculate COMET and M-ETA scores for quality evaluation

In [None]:
# Score the one-shot + RAG (Wikidata) predictions with COMET and M-ETA.
calculate_scores("one_shot_rag_wikidata")