In [None]:
# Importing necessary libraries

import json
import glob
import demjson3
from langchain_ollama import OllamaLLM
from langchain.chains import LLMChain
from comet.models import download_model, load_from_checkpoint
import os

from entity_extraction import (
    extract_capitalized_phrases,
    extract_after_prepositions,
    extract_quoted_entities,
    extract_hyphenated_entities,
    extract_entities_with_numbers_or_roman,
    validate_entities
)
from framework import extract_entity_translation, fetch_wikidata_label, calculate_comet_scores, calculate_meta_score
from prompt_templates import (
    entity_extraction_prompt,
    entity_rethinking_prompt,
    translation_prompt
)

In [None]:
# Loading reference records from JSONL files, grouped by language

# Maps a language code to the extension-less name of its reference file.
# Filled in as a side effect of load_all_jsonl_files_by_language().
language_filepaths = {}

def load_all_jsonl_files_by_language(folder_path):
    """Load every ``*.jsonl`` file in ``folder_path``, grouped by language code.

    The language code is the part of the file name (without extension) that
    precedes the first ``_``. Besides returning the records, the function
    stores each file's extension-less name in the module-level
    ``language_filepaths`` dict; if several files share a code, the last one
    visited wins there while their records are concatenated.

    Args:
        folder_path: Directory that directly contains the ``.jsonl`` files.

    Returns:
        dict mapping language code -> list of decoded JSON records. Files are
        visited in sorted order so the result is reproducible across runs.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    lang_data = {}

    # glob yields paths in arbitrary order; sort for deterministic output.
    for file_path in sorted(glob.glob(f"{folder_path}/*.jsonl")):
        file_name = os.path.basename(file_path)
        file_stem = os.path.splitext(file_name)[0]
        # Split the stem, not the full name, so "fr.jsonl" maps to "fr"
        # instead of "fr.jsonl" when there is no "_" in the name.
        lang_code = file_stem.split("_")[0]
        language_filepaths[lang_code] = file_stem

        records = lang_data.setdefault(lang_code, [])
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                # Blank lines (e.g. a trailing newline) would make json.loads
                # raise, so they are skipped.
                if line.strip():
                    records.append(json.loads(line))

    return lang_data


def get_language_name(short_code):
    """Return the display name for a short language code.

    Codes that have no entry in the table are handed back unchanged.
    """
    display_names = {
        'ar': 'Arabic',
        'zh': 'Chinese (Traditional)',
        'fr': 'French',
        'de': 'German',
        'it': 'Italian',
        'ja': 'Japanese',
        'ko': 'Korean',
        'es': 'Spanish',
        'th': 'Thai',
        'tr': 'Turkish',
        'en': 'English',
    }
    if short_code in display_names:
        return display_names[short_code]
    return short_code

In [None]:
# Folder scanned for the validation reference *.jsonl files.
jsonl_folder = "data/references/validation"
# Records grouped by language code (the file-name prefix before the first "_").
all_lang_data  = load_all_jsonl_files_by_language(jsonl_folder)

In [None]:
# Sanity check: report how many reference records were loaded per language.
for lang_code in all_lang_data:
    records = all_lang_data[lang_code]
    print(f"Loaded {len(records)} records for {get_language_name(lang_code)} ({lang_code})")

In [None]:
# Defining LangChain chains with prompt templates

# Use Ollama
llm = OllamaLLM(model="mistral")
chain_extract = LLMChain(llm=llm, prompt=entity_extraction_prompt)
chain_rethink = LLMChain(llm=llm, prompt=entity_rethinking_prompt)
chain_translate = LLMChain(llm=llm, prompt=translation_prompt)

# Directory that receives one prediction file per language.
predictions_dir = "data/predictions/mistral7b/zero_shot"


def _invoke_json(chain, inputs):
    """Run ``chain`` on ``inputs`` and decode the ``'text'`` of its reply as JSON.

    Strict ``json`` is tried first, then the lenient ``demjson3`` decoder.

    Returns:
        ``(True, decoded_value)`` on success, or ``(False, None)`` after
        printing the reason when the call or both decoders fail. The flag is
        separate from the value because a valid reply can itself decode to
        ``None``.
    """
    # The call and the parsing are kept apart so that a failed call can never
    # fall through to decoding a reply left over from an earlier record.
    try:
        reply_text = chain.invoke(inputs)['text']
    except Exception as err:
        print(f"LLM call failed: {err}")
        return False, None

    try:
        return True, json.loads(reply_text)
    except Exception:
        pass

    try:
        return True, demjson3.decode(reply_text)
    except Exception as err:
        print(f"Failed to recover batch with demjson3: {err}")
        return False, None


def _collect_entities(source):
    """Return the named entities to constrain the translation of ``source``.

    Combines the entities proposed by the extraction chain with rule-based
    candidates that the rethinking chain confirms, then drops blanks,
    duplicates and any entity that is contained in a longer one.

    Returns:
        list of entity strings, or ``None`` when the extraction reply could
        not be obtained or decoded (the record should then be skipped).
    """
    ok, entity_data = _invoke_json(chain_extract, {"texts": source})
    if not ok:
        return None

    # The model may answer with one object or a list of objects; list items
    # that are not objects are ignored instead of raising AttributeError.
    if isinstance(entity_data, dict):
        replies = [entity_data]
    elif isinstance(entity_data, list):
        replies = [item for item in entity_data if isinstance(item, dict)]
    else:
        replies = []

    entities = []
    for reply in replies:
        entities.extend(validate_entities(reply.get('Entities', []), source))

    # Rule-based candidates found directly in the source sentence.
    local_entities = set(
        extract_capitalized_phrases(source) +
        extract_after_prepositions(source) +
        extract_quoted_entities(source) +
        extract_hyphenated_entities(source) +
        extract_entities_with_numbers_or_roman(source)
    )

    # Rethink entities: ask the model about candidates it did not propose.
    for candidate in local_entities:
        if candidate in entities:
            continue

        ok, verdict = _invoke_json(chain_rethink, {"sentence": source, "candidate": candidate})
        if not ok or not isinstance(verdict, dict):
            continue

        suggested = verdict.get('entities')
        if not suggested:
            continue

        if isinstance(suggested, list):
            entities.extend(suggested)
        elif isinstance(suggested, str):
            # A bare string must be appended whole; extend() would add it
            # one character at a time.
            entities.append(suggested)
        entities.append(candidate)

    # Strip, drop blanks/non-strings and de-duplicate.
    unique_entities = list({
        ent.strip() for ent in entities
        if isinstance(ent, str) and ent.strip()
    })

    # Drop every entity that is a substring of another entity, so only the
    # longest surface form of overlapping mentions is kept.
    return [
        ent for i, ent in enumerate(unique_entities)
        if not any(i != j and ent in other for j, other in enumerate(unique_entities))
    ]


def _translate_record(record, language):
    """Translate one reference record with entity constraints.

    Args:
        record: dict with at least ``id``, ``source``, ``source_locale`` and
            ``target_locale``.
        language: Display name of the target language passed to the prompt.

    Returns:
        The prediction dict to store, or ``None`` when any model reply could
        not be obtained or decoded, or carries no ``'translation'`` field.
    """
    source = record['source']

    final_entity_list = _collect_entities(source)
    if final_entity_list is None:
        return None

    # Translate named entities using Wikidata
    model_entities = []
    for item in final_entity_list:
        ent = extract_entity_translation(item, record['target_locale'])
        if ent['qid']:
            model_entities.append(ent['translated'])

    # Translate sentence with constraint
    ok, reply = _invoke_json(chain_translate, {
        "sentence": source,
        "language": language,
        "entities": ", ".join(model_entities)
    })
    if not ok:
        return None
    if not isinstance(reply, dict) or 'translation' not in reply:
        # A well-formed reply without the expected field must not abort the
        # whole run with a KeyError.
        print(f"No 'translation' field in the reply for record {record['id']}; skipping")
        return None

    return {
        "id": record['id'],
        "text": source,
        "source_language": record['source_locale'],
        "target_language": record['target_locale'],
        "prediction": reply['translation'],
    }


def _load_done_ids(path):
    """Return the ``id`` values of the predictions already stored in ``path``.

    Returns an empty set when the file does not exist. A malformed line
    raises, so a damaged prediction file is noticed rather than appended to.
    """
    done_ids = set()
    if os.path.exists(path):
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                if line.strip():
                    done_ids.add(json.loads(line)['id'])
    return done_ids


os.makedirs(predictions_dir, exist_ok=True)

# Predictions produced by this run for the language processed last.
results = []

for lang_code, records in all_lang_data.items():
    language = get_language_name(lang_code)
    output_file = f"{predictions_dir}/{language_filepaths[lang_code]}.jsonl"

    # Resume support: records whose id is already in the output file are
    # skipped and new predictions are appended. An interrupted run can simply
    # be re-executed without hard-coded language skips or slice offsets, and
    # earlier predictions are never overwritten.
    done_ids = _load_done_ids(output_file)
    if done_ids:
        print(f"{language}: skipping {len(done_ids)} predictions already in {output_file}")

    results = []

    with open(output_file, 'a', encoding='utf-8') as f:
        for record in records:
            if record['id'] in done_ids:
                continue

            result = _translate_record(record, language)
            if result is None:
                continue

            results.append(result)
            # Write and flush one line per record so progress survives an
            # interruption without rewriting the whole file each time.
            f.write(json.dumps(result, ensure_ascii=False) + '\n')
            f.flush()

In [None]:
# Calculate COMET and M-ETA scores for quality evaluation

# Fetch the "Unbabel/wmt22-comet-da" checkpoint and load it; calculate_scores()
# hands the loaded model to calculate_comet_scores().
comet_model_path = download_model("Unbabel/wmt22-comet-da")
comet_model = load_from_checkpoint(comet_model_path)
model_name = "mistral7b"
# Root folder for this model's outputs; created up front if it is missing.
output_prediction_dir = os.path.join("data/predictions", model_name)
os.makedirs(output_prediction_dir, exist_ok=True)

# Reference files to score against: every *.jsonl in the validation folder.
input_data_folder = "data/references/validation"
jsonl_files = glob.glob(f"{input_data_folder}/*.jsonl")

def calculate_scores(template_id):
    """Score each reference file against its prediction file and save the result.

    For every path in the module-level ``jsonl_files`` this looks for a
    prediction file of the same name under
    ``<output_prediction_dir>/<template_id>/``, computes the COMET score with
    ``calculate_comet_scores`` and the M-ETA counts/score with
    ``calculate_meta_score``, and writes them as one JSON object to
    ``<output_prediction_dir>/<template_id>/scores/<same file name>``.

    Reference files without a matching prediction file are reported and
    skipped, so one missing language does not abort the others.

    Args:
        template_id: Name of the prediction sub-folder to evaluate,
            e.g. ``"zero_shot"``.
    """
    scores_dir = os.path.join(output_prediction_dir, template_id, "scores")
    # exist_ok=True already covers the "directory exists" case.
    os.makedirs(scores_dir, exist_ok=True)

    for references_path in jsonl_files:
        filename = os.path.basename(references_path)
        predictions_path = os.path.join(output_prediction_dir, template_id, filename)

        if not os.path.isfile(predictions_path):
            print(f"Skipping {filename}: no predictions found at {predictions_path}")
            continue

        comet_score = calculate_comet_scores(
            comet_model,
            references_path,
            predictions_path
        )

        correct_instances, total_instances, meta_score = calculate_meta_score(
            references_path,
            predictions_path
        )

        evaluation_results = {
            "correct_instances": correct_instances,
            "total_instances": total_instances,
            "comet_score": comet_score,
            "meta_score": meta_score
        }

        # NOTE(review): the score file reuses the reference's ".jsonl" name
        # although it holds a single indented JSON object.
        evaluation_output_path = os.path.join(scores_dir, filename)
        with open(evaluation_output_path, 'w', encoding='utf-8') as json_file:
            json.dump(evaluation_results, json_file, ensure_ascii=False, indent=4)

In [None]:
# Evaluate the predictions stored under the "zero_shot" template folder.
calculate_scores("zero_shot")