In [1]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false" #This should be before import of tokenizer. Note that EncoderTrainer imports tokenizer
import sys
import contextlib
from transformers import BertForTokenClassification, AutoTokenizer
import json
from datetime import datetime 
import csv
from pathlib import Path
from encoder_trainer import EncoderTrainer
from enum import Enum
from dataclasses import dataclass, field
from typing import ClassVar, Tuple, Dict, List, Optional, Union
import traceback
from itertools import product
import re
from collections import defaultdict
from filelock import FileLock
import functools


# Get the project root directory
current_dir = Path.cwd()
project_root = current_dir.parent  # Go up one level from encoder_fine_tuning to cross-lingual-idioms
sys.path.append(str(project_root))

from src.utils import get_data
GPU = "0"
EXECUTED_TESTS_FILE = "executed_tests.json"
DEFAULT_RESULTS_FILE = "unified_results.csv"

os.environ["CUDA_VISIBLE_DEVICES"] = GPU

model_to_train_args = defaultdict(dict)
model_to_train_args.update({"FacebookAI/xlm-roberta-base": {'learning_rate': 2e-05, 'per_device_train_batch_size': 8,
                                                            'per_device_eval_batch_size': 8,  'lr_scheduler_type': 'linear'}})
print(f"{datetime.now()} Start")

  from .autonotebook import tqdm as notebook_tqdm


2025-07-20 19:50:04.485677 Start


In [2]:
def file_mutex(lock_path_func):
    """
    Decorator to ensure file-based mutual exclusion (process and thread safe).
    lock_path_func should return the path to the lock file and accept the same
    arguments as the decorated function.
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            lock_path = lock_path_func(*args, **kwargs)
            with FileLock(lock_path):
                return func(*args, **kwargs)
        return wrapper
    return decorator


In [3]:
class TaskConfig(Enum):
    DODIOM = "dodiom"
    ID10M = "id10m"
    OPEN_MWE = "open_mwe"
    MAGPIE = "magpie"

LANGUAGE_TO_CODE: Dict[str, str] = {
    "english": "EN",
    "spanish": "ES",
    "german": "DE",
    "japanese": "JP",
    "turkish": "TR",
    "chinese": "ZH",
    "french": "FR",
    "polish": "PL",
    "italian": "IT",
    "dutch": "NL",
    "portuguese": "PT"
}
    
def get_language_map() -> Dict[str, str]:
    
    CODE_TO_LANGUAGE: Dict[str, str] = {v: k for k, v in LANGUAGE_TO_CODE.items()}
    
    # Merge into a single bidirectional object
    return {**LANGUAGE_TO_CODE, **CODE_TO_LANGUAGE}

LANGUAGE_MAP = get_language_map()

@file_mutex(lambda directory_path, safe_model_name, decimal_round, output_csv_path: output_csv_path + ".lock")
def process_results_directory_to_csv(directory_path: str, safe_model_name: str, decimal_round: int , output_csv_path: str):
    ordered_test_keys = [
        "EN_id10m",
        "EN_magpie",
        "DE_id10m",
        "IT_id10m",
        "IT_dodiom",
        "ES_id10m",
        "JP_open_mwe",
        "TR_dodiom"
    ]
    
    valid_lang_codes = set(LANGUAGE_TO_CODE.values())
    
    # Regex to parse filenames
    # Example filename: train_lang_zh_src_id10m_FacebookAI_xlm-roberta-base_seed_7.json
    pattern = re.compile(
        r"train_lang_([a-z]{2})_src_([a-zA-Z0-9]+)_([a-zA-Z0-9_\-]+)_seed_(\d+)\.json$",
        re.IGNORECASE
    )

    rows = []
    all_keys = set()

    for filename in os.listdir(directory_path):
        match = pattern.match(filename)
        if not match:
            continue

        train_lang, src, model_name, seed = match.groups()

        if "open_mwe" in safe_model_name:
            #Dirty workaround since mwe_ was incorrectly inserted as a prefix to the model name
            src = "open_mwe"
            model_name = model_name.replace("mwe_", "")
        
        train_lang = train_lang.upper()
        seed = int(seed)

        # Check if train_lang is valid
        if train_lang not in valid_lang_codes:
            print(f"{datetime.now()} Unknown train_lang {train_lang} Skipping")
            continue

        file_path = os.path.join(directory_path, filename)
        with open(file_path, "r", encoding="utf-8") as f:
            json_data = json.load(f)

        # Find target language code in model name, if any
        model_name_lower = model_name.lower()
        model_languages = []
        for lang_name, lang_code in LANGUAGE_TO_CODE.items():
            if lang_name in model_name_lower:
                model_languages.append(lang_code)
                
        # data is a dict like {"EN_id10m": {"eval_f1": 0.329}, ...}
        # Extract keys and their metric values (assume single metric per key)
        row = {
            "train_lang": train_lang,
            "src": src,
            "model_name": model_name,
            "seed": seed,
        }

        # Extract available metrics
        available_metrics = {}
        for key, metric_dict in json_data.items():
            if isinstance(metric_dict, dict):
                val = next(iter(metric_dict.values()))
                if decimal_round>0:
                    available_metrics[key] = round(float(val), decimal_round)
                else:
                    available_metrics[key] = float(val)

        for test_key in ordered_test_keys:
            if test_key in available_metrics:
                row[test_key] = available_metrics[test_key]
            else:
                # Set -1 only if model is specialized (e.g., "turkish") and test_key is for another language
                test_lang = test_key.split("_")[0]  # e.g., EN from "EN_id10m"
                if model_languages and test_lang not in model_languages:
                    row[test_key] = -1
                else:
                    row[test_key] = ""

        rows.append(row)

    csv_columns = ["train_lang", "src", "model_name", "seed"] + ordered_test_keys

    with open(output_csv_path, "w", newline="", encoding="utf-8") as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=csv_columns)
        writer.writeheader()
        for row in rows:
            writer.writerow(row)

    print(f"{datetime.now()} CSV file saved to {output_csv_path}")

def general_process_results_directory_to_csv(directory_path: str, safe_model_name: str, output_csv_path: str):
    name, ext = os.path.splitext(output_csv_path)
    raw_output_csv_path = f"{name}_raw{ext}"
    process_results_directory_to_csv(directory_path=directory_path, safe_model_name=safe_model_name, decimal_round=-1 , output_csv_path=raw_output_csv_path)
    process_results_directory_to_csv(directory_path=directory_path, safe_model_name=safe_model_name, decimal_round=2 , output_csv_path=output_csv_path)

In [4]:
@dataclass(frozen=True)
class LanguageTestSetup:
    """
    Represents a unique (language_code, source) test configuration.

    Instances are created via the `get()` factory method, which ensures
    that the same instance is reused for each (language_code, source) pair.
    """

    language_code: str
    source: str

    # Internal cache to store/reuse instances
    _instances: ClassVar[Dict[Tuple[str, str], "LanguageTestSetup"]] = {}

    def __post_init__(self):
        # Ensure normalized uppercase language codes
        if self.language_code != self.language_code.upper():
            raise ValueError("language_code must be uppercase (e.g., 'EN')")

    @classmethod
    def get(cls, language_code: str, source: str) -> "LanguageTestSetup":
        """
        Factory method to return a cached instance based on input.
        Ensures reuse and avoids duplicates.
        """
        language_code = language_code.upper()
        key = (language_code, source)
        if key not in cls._instances:
            instance = cls(language_code, source)
            cls._instances[key] = instance
        return cls._instances[key]

    def __str__(self):
        return f"{self.language_code}_{self.source}"

DEFAULT_TESTS: List[LanguageTestSetup] = [
    LanguageTestSetup.get("EN", TaskConfig.ID10M.value),
    LanguageTestSetup.get("EN", TaskConfig.MAGPIE.value),
    LanguageTestSetup.get("DE", TaskConfig.ID10M.value),
    LanguageTestSetup.get("IT", TaskConfig.ID10M.value),
    LanguageTestSetup.get("IT", TaskConfig.DODIOM.value),
    LanguageTestSetup.get("ES", TaskConfig.ID10M.value),
    LanguageTestSetup.get("JP", TaskConfig.OPEN_MWE.value),
    LanguageTestSetup.get("TR", TaskConfig.DODIOM.value),
]

def is_directory_empty(directory_path: Path):
    """
    Returns True iff the directory is empty
    """
    if not directory_path.is_dir():
        return False 
    return not any(directory_path.iterdir())

def get_only_default_tests_of_specific_language(lang: str) -> List[LanguageTestSetup]:
    assert lang in LANGUAGE_TO_CODE, f"Language name must be one of the following {list(LANGUAGE_TO_CODE.keys())} (case sensitive)"
    return [t for t in DEFAULT_TESTS if t.language_code == LANGUAGE_MAP[lang]]


def is_huggingface_model_folder(folder: Path) -> bool:
    """
    Returns True if the folder contains a minimal, ready-to-run Hugging Face model
    (with config.json and model.safetensors), False otherwise.
    """
    if not isinstance(folder, Path):
        folder = Path(folder)
    if not folder.is_dir():
        return False

    required_files = ["config.json", "model.safetensors"]
    return all((folder / fname).is_file() for fname in required_files)

def replace_separators_with_underscore(s):
    """Replaces all / and \\ in the string with _ in one pass."""
    return re.sub(r'[\\/]', '_', s)

@dataclass
class Test:
    train_language: str  # e.g., "english"
    source: str
    model_name: str
    safe_model_name: str #This is 'model_name' with slashes and backslashes replaced with an underscore (to use in paths)
    results_dir: Path = Path("results")
    seed: int = -1
    test_setups: List[LanguageTestSetup] = None
    encoder_trainer: EncoderTrainer = field(default=None, repr=False)
    train_out_path: Path = field(default=None, repr=False)

    def __init__(self, **kwargs):
        if 'results_dir' in kwargs:
            self.results_dir = Path(kwargs.pop('results_dir'))
        
        for key, value in kwargs.items():
            setattr(self, key, value)

        self.safe_model_name = replace_separators_with_underscore(self.model_name)

        if "test_setups" not in kwargs:
            if self.train_language in self.model_name:
                print(f"Language {self.train_language} detected in model name {self.model_name} - thus only testing on this language. If you want to change this behavior, explictly set the variable test_setups")
                self.set_only_tests_of_train_language()
            else:
               self.test_setups = DEFAULT_TESTS

        #Note this should be after setting relevant fields (e.g. train_language) otherwise __str_() would fail
        if 'train_out_path' in kwargs:
            self.train_out_path = Path(kwargs.pop('train_out_path'))
        else:
            self.update_train_out_path_to_default()

    def __str__(self) -> str:
        """
        Returns a string representation for the test.
    
        ⚠ Note:
        - This is used, amongt other things, for storing file paths and checking execution.
        """
        return f"train_lang_{LANGUAGE_MAP[self.train_language].lower()}_src_{self.source}_{self.safe_model_name}_seed_{self.seed}"
        
    def to_comparison_str(self):
        """
        Returns a file-safe string identifier for the test.
    
        ⚠ Note:
        - This is used for checking execution.
        - Slashes and backslashes in `model_name` are replaced with underscores.
        - Since the original model name cannot be perfectly recovered, this method should NOT be used for deserialization.
        """
        return f"{self.__str__()}_test_setups_{str(self.test_setups)}"


    def __hash__(self):
        return hash(self.__str__())
    
    def __eq__(self, other):
        if not isinstance(other, Test):
            raise TypeError("Can only compare instances of same type")
        return self.to_comparison_str() == other.to_comparison_str()

    def update_train_out_path_to_default(self):
        self.train_out_path = Path(self.__str__())
       
    def set_only_tests_of_train_language(self):
        self.test_setups =[t for t in DEFAULT_TESTS if t.language_code == LANGUAGE_MAP[self.train_language]]

    def train(self, should_delete_checkpoints: bool = True):
        assert self.seed>0, "Need to set seed before calling train()"
        print(f"{datetime.now()} Training gpu {GPU} {self}")
        data = get_data(lang=self.train_language, task=self.source)
        train_data, test_data = data["train"], data["test"]
        if test_data.empty:
            print(f"{datetime.now()} Got empty test dataset, using validation set instead")
            test_data = data["validation"]

        self.encoder_trainer = EncoderTrainer(
            seed=self.seed,
            model_name=self.model_name,
            output_dir=self.train_out_path,
            train_data=train_data,
            test_data=test_data,
            # data=data,
            )
        self.encoder_trainer.train(should_delete_checkpoints=should_delete_checkpoints, custom_train_args=model_to_train_args[self.model_name])

    def test(self):
        print(f"{datetime.now()} Testing: {self.test_setups}")
        results = {}
        for test_setup in self.test_setups:
            try:
                test_data = get_data(lang=LANGUAGE_MAP[test_setup.language_code], task=test_setup.source)["test"]
                results[str(test_setup)] = self.encoder_trainer.test(test_data=test_data, return_metrics=["eval_f1"])
            except Exception as e:
                print(f"{datetime.now()} An exception occurred in {test_setup}. {e}\n{traceback.format_exc()}")
        
        result_filename = f"{self.__str__()}.json" 
        try:
            self.results_dir.mkdir(parents=True, exist_ok=True)
            result_out_file_path = Path(self.results_dir, result_filename)
        except Exception as e:
            print(f"{datetime.now()} An exception occurred while trying to create {self.results_dir}. Writing to local dir {os.getcwd()}. Exception: {e}\n{traceback.format_exc()}")
            result_out_file_path = os.getcwd()
        
        print(f"{datetime.now()} Saving results to {result_out_file_path}")
        with open(result_out_file_path, 'w') as f:
            json.dump(results, f, indent=4)

        general_process_results_directory_to_csv(str(self.results_dir), self.safe_model_name, DEFAULT_RESULTS_FILE)
                
    def set_existing_model(self):
        #TODO - If keeping this, need to change because it's not alwys BertForTokenClassification (probably better to delete this logic all together, just don't check for existing models)
        self.encoder_trainer = EncoderTrainer(
            model_name=self.model_name,
            trained_model=BertForTokenClassification.from_pretrained(self.train_out_path)
            )

    def run(self):
        if self.test_setups is None:
            if self.train_language in self.model_name:
                print(f"Language {self.train_language} detected in model name {self.model_name} - thus only testing on this language. If you want to change this behavior, explictly set the variable test_setups")
                self.set_only_tests_of_train_language()
            else:
               self.test_setups = DEFAULT_TESTS

        print(f"{datetime.now()} Running (gpu {GPU}) test {self}")

        if is_huggingface_model_folder(self.train_out_path):
            print(f"{datetime.now()} Found an existing model for this test, so skipping train. Model location: {self.train_out_path}")
            self.set_existing_model()
        else:
            self.train()
        self.test()
        if is_directory_empty(self.train_out_path):
            os.rmdir(self.train_out_path)

In [5]:
@file_mutex(lambda: EXECUTED_TESTS_FILE + ".lock")
def get_executed_tests_ids() -> set[str]:
    """
    Loads previously executed Test instances from disk
    Returns an empty set if the file does not exist.
    """
    executed_tests_file_type_path = Path(EXECUTED_TESTS_FILE)
    if executed_tests_file_type_path.exists():
        return set(json.loads(executed_tests_file_type_path.read_text()))
    return set()

def save_executed_tests(executed_tests_ids: set[str]):
    """
    Saves the current set of executed Test ids to disk.
    """
    executed_tests_file_type_path = Path(EXECUTED_TESTS_FILE)
    executed_tests_file_type_path.write_text(json.dumps(list(executed_tests_ids), indent=2))

@file_mutex(lambda *args, **kwargs: EXECUTED_TESTS_FILE + ".lock")
def add_test_to_executed(cur_test: Test, executed_tests_ids: set[str]):
    executed_tests_ids.add(cur_test.to_comparison_str())
    save_executed_tests(executed_tests_ids)

def was_test_executed(cur_test: Test, executed_tests_ids: set[str]) -> bool:
    return cur_test.to_comparison_str() in executed_tests_ids

SEEDS = [5, 7, 42, 123, 1773]
EN_TESTS = set([Test(train_language="english", source=src, model_name=m)
                for src, m in product([TaskConfig.ID10M.value, TaskConfig.MAGPIE.value],
                                               ["bert-base-multilingual-cased", "FacebookAI/xlm-roberta-base"])])
EN_TESTS.update(set([Test(train_language="english", source=src, model_name="FacebookAI/roberta-base",
                test_setups = get_only_default_tests_of_specific_language("english")) for src in [TaskConfig.ID10M.value, TaskConfig.MAGPIE.value]]))

DE_TESTS = set([Test(train_language="german", source=TaskConfig.ID10M.value, model_name=m) for m in 
                ["bert-base-multilingual-cased", "FacebookAI/xlm-roberta-base", "google-bert/bert-base-german-cased"]])

# We stopped using DODIOM after surveying early results, keeping it commented for now. need to delete in the future
# IT_TESTS = set([Test(train_language="italian", source=src, model_name=m)
#                 for src, m in product([TaskConfig.ID10M.value, TaskConfig.DODIOM.value],
#                                       ["bert-base-multilingual-cased", "FacebookAI/xlm-roberta-base", "dbmdz/bert-base-italian-cased"])])

IT_TESTS = set([Test(train_language="italian", source=TaskConfig.ID10M.value, model_name=m)
                for m in ["bert-base-multilingual-cased", "FacebookAI/xlm-roberta-base", "dbmdz/bert-base-italian-cased"]])

ES_TESTS = set([Test(train_language="spanish", source=TaskConfig.ID10M.value, model_name=m)
                for m in ["bert-base-multilingual-cased", "FacebookAI/xlm-roberta-base", "dccuchile/bert-base-spanish-wwm-cased"]])

# We stopped using JP (IDIOM) after surveying early
JP_TESTS = set([Test(train_language="japanese", source=TaskConfig.OPEN_MWE.value, model_name=m)
                for m in ["bert-base-multilingual-cased", "FacebookAI/xlm-roberta-base", "tohoku-nlp/bert-base-japanese"]])

# We stopped using DODIOM after surveying early results, keeping it commented for now. need to delete in the future
# TR_TESTS = set([Test(train_language="turkish", source=TaskConfig.DODIOM.value, model_name=m)
#                 for m in ["bert-base-multilingual-cased", "FacebookAI/xlm-roberta-base", "dbmdz/bert-base-turkish-cased"]])

#Languages we dropped due to early results: dutch, chinese,
OTHER_TESTS = set([Test(train_language=lang, source=TaskConfig.ID10M.value, model_name=m)
                for lang, m in product(["french", "polish", "portuguese"],
                                      ["bert-base-multilingual-cased", "FacebookAI/xlm-roberta-base"])])

 
TESTS = set()
TESTS.update(EN_TESTS, JP_TESTS, IT_TESTS, ES_TESTS, DE_TESTS, OTHER_TESTS)#) # TR_TESTS


executed_tests_ids = get_executed_tests_ids()
print(f"{datetime.now()} Loaded {len(executed_tests_ids)} executed tests")
failed_tests = []

tests_to_run = TESTS
for i_seed, cur_seed in enumerate(SEEDS):
    for i_tst, cur_test in enumerate(tests_to_run):
        print(f"{datetime.now()} Seed {i_seed+1}/{len(SEEDS)} Test {i_tst+1}/{len(tests_to_run)}")
        cur_test.seed = cur_seed
        cur_test.update_train_out_path_to_default()
        if was_test_executed(cur_test, executed_tests_ids):
            print(f"{datetime.now()} Skipping test because already executed {cur_test}")
            continue
        try:
            cur_test.run()
            add_test_to_executed(cur_test, executed_tests_ids)
        except Exception as e:
            print(f"{datetime.now()} An exception occurred in {cur_test}. {e}\n{traceback.format_exc()}")
            failed_tests.append(cur_test)

Language german detected in model name google-bert/bert-base-german-cased - thus only testing on this language. If you want to change this behavior, explictly set the variable test_setups
Language italian detected in model name dbmdz/bert-base-italian-cased - thus only testing on this language. If you want to change this behavior, explictly set the variable test_setups
Language spanish detected in model name dccuchile/bert-base-spanish-wwm-cased - thus only testing on this language. If you want to change this behavior, explictly set the variable test_setups
Language japanese detected in model name tohoku-nlp/bert-base-japanese - thus only testing on this language. If you want to change this behavior, explictly set the variable test_setups
2025-07-20 19:50:04.536411 Loaded 157 executed tests
2025-07-20 19:50:04.536612 Seed 1/5 Test 1/24
2025-07-20 19:50:04.536661 Skipping test because already executed train_lang_en_src_magpie_FacebookAI_xlm-roberta-base_seed_5
2025-07-20 19:50:04.536675

In [6]:
print(f"{datetime.now()} {len(failed_tests)} Tests failed")
if len(failed_tests)>0:
    print(f"Failed tests: {failed_tests}")

2025-07-20 19:50:04.544218 0 Tests failed


In [7]:
print(f"{datetime.now()} FIN")

2025-07-20 19:50:04.548377 FIN
