In [None]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false" #This should be before import of tokenizer
import sys
from datetime import datetime 
from pathlib import Path
from enum import Enum
from typing import Dict, List, Union
from itertools import product
from collections import defaultdict
import pandas as pd
import random
from transformers import (
    BertForTokenClassification,
    AutoModelForTokenClassification,
    Trainer,
    TrainingArguments,
    AutoTokenizer,
    RobertaTokenizerFast,
    DataCollatorForTokenClassification,
    set_seed,
)
import numpy as np
from datasets import Dataset, DatasetDict
from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score
import json


# Make the project root importable so `src.utils` resolves. Assumes the notebook's working
# directory is a direct child of the project root (e.g. encoder_fine_tuning/ under cross-lingual-idioms/).
current_dir = Path.cwd()
project_root = current_dir.parent
# Guard against piling up duplicate sys.path entries when this cell is re-run in the same kernel.
if str(project_root) not in sys.path:
    sys.path.append(str(project_root))
from src.utils import get_data

# GPU index exposed to this process; CUDA_VISIBLE_DEVICES must be set before CUDA is initialised.
GPU = "0"
os.environ["CUDA_VISIBLE_DEVICES"] = GPU

# ---- Experiment configuration ----
TOTAL_RECORDS_NUM = 60_000  # size of every run's training set, whatever the language mix
SEEDS = [5, 7]  # other seeds used previously: 42 (done), 123, 1773
MODEL_NAME = "bert-base-multilingual-cased"
JUMP = 0.1  # step between odd-one-out shares; usually 0.1 or 0.2
JUMP_STR = f"{JUMP}".replace(".", "_")  # file-path-safe form of JUMP, e.g. "0_1"

print(f"{datetime.now()} Start gpu {GPU}")

In [None]:
########################### Set up ###########################

class TaskConfig(Enum):
    """Dataset sources; the ``.value`` strings are what ``get_data(task=...)`` receives."""

    DODIOM = "dodiom"
    ID10M = "id10m"
    OPEN_MWE = "open_mwe"
    MAGPIE = "magpie"

def get_idiom_only_list() -> List[str]:
    """Return the default source list for a language: a fresh ``["id10m"]`` on every call.

    The elements are ``TaskConfig`` *values* (strings), not enum members. Because the function is
    used as a ``defaultdict`` factory, each language gets its own list, which can be appended to
    without affecting the other languages.
    """
    return [TaskConfig.ID10M.value]


# Full language name -> short upper-case code.
# NOTE(review): Japanese maps to "JP" (a country-style code), not the ISO 639-1 "JA" — confirm this is intended.
LANGUAGE_TO_CODE: Dict[str, str] = dict(
    english="EN",
    spanish="ES",
    german="DE",
    japanese="JP",
    turkish="TR",
    chinese="ZH",
    french="FR",
    polish="PL",
    italian="IT",
    dutch="NL",
    portuguese="PT",
)

def get_language_map() -> Dict[str, str]:
    """Build one lookup that translates names to codes and codes back to names.

    Name -> code entries come first, followed by code -> name entries, in LANGUAGE_TO_CODE order.
    """
    bidirectional = dict(LANGUAGE_TO_CODE)
    bidirectional.update((code, name) for name, code in LANGUAGE_TO_CODE.items())
    return bidirectional

LANGUAGE_MAP = get_language_map()


# Data sources per language: ID10M by default, with language-specific overrides.
LANG_TO_SOURCES = defaultdict(get_idiom_only_list)
LANG_TO_SOURCES["turkish"] = [TaskConfig.DODIOM.value]
LANG_TO_SOURCES["japanese"] = [TaskConfig.OPEN_MWE.value]
# Optional: LANG_TO_SOURCES["english"].append(TaskConfig.MAGPIE.value)

# The held-out ("odd one out") language; every other language only contributes training data.
OOD_ONE_OUT_LANG = "japanese"
OTHER_LANGS = list(LANGUAGE_TO_CODE.keys())
OTHER_LANGS.remove(OOD_ONE_OUT_LANG)  # raises ValueError if OOD_ONE_OUT_LANG is not a known language
# Only the first source of OOD_ONE_OUT_LANG is used; its "train"/"test" splits are read later.
ODD_ONE_DATASET = get_data(lang=OOD_ONE_OUT_LANG, task=LANG_TO_SOURCES[OOD_ONE_OUT_LANG][0])
OTHER_LANGS_TO_DATAFRAMES_TRAIN_ONLY = defaultdict(lambda: pd.DataFrame())
RESULTS_DIR = Path(f"odd_one_out_{OOD_ONE_OUT_LANG}_records_num_{TOTAL_RECORDS_NUM}_jump_{JUMP_STR}_results")
RESULTS_DIR.mkdir(exist_ok=True)

for otra_lengua in OTHER_LANGS:
    # Gather every source's train split first and concatenate once. This avoids repeated copying and
    # concatenating onto an empty DataFrame, which pandas deprecates.
    source_frames = [get_data(lang=otra_lengua, task=src)["train"] for src in LANG_TO_SOURCES[otra_lengua]]
    OTHER_LANGS_TO_DATAFRAMES_TRAIN_ONLY[otra_lengua] = pd.concat(source_frames) if source_frames else pd.DataFrame()

In [None]:
########################### auxiliary ###########################
BEST_F1_PER_TRAINING = -1  # best eval macro-F1 of the current training run; reset by the driver loop before each run

def compute_metrics(p):
    """Compute token-level macro precision, recall, F1 and accuracy for ``Trainer`` evaluation.

    ``p`` unpacks to ``(logits, label_ids)``; predictions are ``logits.argmax(-1)``. Positions labelled
    -100 are skipped. Remaining ids are mapped to tag strings through the module-level ``id2label``
    (set by ``finetune``) before scoring.

    Side effect: raises the module-level ``BEST_F1_PER_TRAINING`` to this call's F1 if it is higher.
    NOTE(review): the driver reports this best-over-epochs value on the evaluation (test) split,
    i.e. epoch selection is done on test data — an optimistic estimate.
    """
    global BEST_F1_PER_TRAINING

    logits, gold_ids = p
    pred_ids = logits.argmax(-1)

    kept_pairs = [
        (gold, pred)
        for pred_row, gold_row in zip(pred_ids, gold_ids)
        for pred, gold in zip(pred_row, gold_row)
        if gold != -100
    ]
    true_labels = [id2label[gold] for gold, _ in kept_pairs]
    pred_labels = [id2label[pred] for _, pred in kept_pairs]

    macro_f1 = f1_score(true_labels, pred_labels, average="macro", zero_division=0)
    BEST_F1_PER_TRAINING = max(BEST_F1_PER_TRAINING, macro_f1)

    return {
        "eval_f1": macro_f1,
        "precision": precision_score(true_labels, pred_labels, average="macro", zero_division=0),
        "recall": recall_score(true_labels, pred_labels, average="macro", zero_division=0),
        "accuracy": accuracy_score(true_labels, pred_labels),
    }


def encode_labels(example):
    """Add ``labels``: each tag of ``example["tags"]`` mapped through the module-level ``label2id``.

    Mutates and returns ``example`` (the per-row ``datasets.map`` convention).
    """
    tag_ids = []
    for tag in example["tags"]:
        tag_ids.append(label2id[tag])
    example["labels"] = tag_ids
    return example


def tokenize_and_align_labels(example):
    """Tokenize a pre-split example with the module-level (fast) ``tokenizer`` and align its labels.

    Relies on ``word_ids()``, which fast tokenizers provide. The first sub-word of each word keeps
    the word's label. Special tokens (``word_ids`` is None) and later sub-words get -100 so the
    loss ignores them.
    """
    tokenized = tokenizer(example["tokens"], is_split_into_words=True, truncation=True)

    aligned = []
    prev_word = None
    for word_idx in tokenized.word_ids():
        starts_new_word = word_idx is not None and word_idx != prev_word
        aligned.append(example["labels"][word_idx] if starts_new_word else -100)
        prev_word = word_idx

    tokenized["labels"] = aligned
    return tokenized


def tokenize_and_align_labels_slow(example, max_length: int = 128):
    """Tokenize a pre-split example with a *slow* tokenizer and align word labels to sub-words.

    Slow tokenizers do not provide ``word_ids()``, so each word is tokenized on its own to learn
    exactly which sub-words it produced. The first sub-word of a word keeps the word's label. Its
    other sub-words, the special tokens and the padding get -100 so the loss ignores them. This
    matches what ``tokenize_and_align_labels`` does for fast tokenizers.

    Args:
        example: One dataset row with ``tokens`` (list of words) and ``labels`` (one id per word).
        max_length: Sequence length including special tokens; longer inputs are truncated and
            shorter ones padded to exactly this length.

    Returns:
        The encoding returned by the module-level ``tokenizer.prepare_for_model`` (``input_ids``,
        ``attention_mask``, ...) plus a ``labels`` list of the same length.
    """
    # Room left for content once the model's special tokens (e.g. [CLS]/[SEP]) are added.
    max_content_len = max_length - tokenizer.num_special_tokens_to_add(pair=False)

    content_ids = []
    content_labels = []
    for word, word_label in zip(example["tokens"], example["labels"]):
        # Same per-word tokenization that tokenizer(words, is_split_into_words=True) performs.
        piece_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(word, is_split_into_words=True))
        if not piece_ids:
            continue  # the word produced no sub-words, so there is no position to carry its label
        content_ids.extend(piece_ids)
        content_labels.append(word_label)
        content_labels.extend([-100] * (len(piece_ids) - 1))

    # Truncate ids and labels together so they stay aligned.
    content_ids = content_ids[:max_content_len]
    content_labels = content_labels[:max_content_len]

    tokenized_inputs = tokenizer.prepare_for_model(
        content_ids,
        add_special_tokens=True,
        padding="max_length",
        max_length=max_length,
        truncation=False,
        return_special_tokens_mask=True,
    )
    # special_tokens_mask marks special tokens and padding with 1 and our content tokens (in order)
    # with 0. Drop it so it is not passed on as a model input.
    special_tokens_mask = tokenized_inputs.pop("special_tokens_mask")
    content_label_iter = iter(content_labels)
    tokenized_inputs["labels"] = [
        -100 if is_special else next(content_label_iter) for is_special in special_tokens_mask
    ]
    return tokenized_inputs

def get_tokenized_dataset(lang: str, model_name: str, dataset: DatasetDict) -> DatasetDict:
    """Tokenize every split of ``dataset`` and attach aligned ``labels``.

    Japanese (by language or by model name) uses the manual-alignment tokenization, applied row by
    row. Everything else uses the ``word_ids()``-based alignment.
    """
    needs_manual_alignment = "japanese" in model_name or lang == "japanese"
    if not needs_manual_alignment:
        return dataset.map(tokenize_and_align_labels)
    return dataset.map(tokenize_and_align_labels_slow, batched=False)


def get_model_by_name(
    lang: str,
    model_name: str,
    label2id: Dict[str, int],
    id2label: Dict[int, str],
    hidden_dropout_prob: float = 0.5,
    attention_probs_dropout_prob: float = 0.5,
) -> Union[BertForTokenClassification, AutoModelForTokenClassification]:
    """Load a pretrained token-classification model with a head sized for ``label2id``.

    RoBERTa-family checkpoints and Japanese (by language or model name) go through
    ``AutoModelForTokenClassification``; all other names are loaded as ``BertForTokenClassification``.
    Both dropout probabilities are forwarded to the model config.
    """
    is_roberta_family = model_name in ["FacebookAI/xlm-roberta-base", "FacebookAI/roberta-base"]
    is_japanese = "japanese" in model_name or lang == "japanese"
    model_cls = AutoModelForTokenClassification if (is_roberta_family or is_japanese) else BertForTokenClassification

    return model_cls.from_pretrained(
        model_name,
        num_labels=len(label2id),
        id2label=id2label,
        label2id=label2id,
        hidden_dropout_prob=hidden_dropout_prob,
        attention_probs_dropout_prob=attention_probs_dropout_prob,
    )

def get_tokenizer_by_model_name(lang: str, model_name: str) -> Union[AutoTokenizer, RobertaTokenizerFast]:
    """Load the tokenizer matching ``model_name``.

    - ``FacebookAI/roberta-base``: fast RoBERTa tokenizer with ``add_prefix_space=True`` (needed to
      accept pre-split words).
    - Japanese (by language or model name): slow tokenizer (``use_fast=False``).
    - Otherwise: the default ``AutoTokenizer``.
    """
    if model_name == "FacebookAI/roberta-base":
        return RobertaTokenizerFast.from_pretrained('roberta-base', add_prefix_space=True)

    wants_slow = "japanese" in model_name or lang == "japanese"
    if wants_slow:
        return AutoTokenizer.from_pretrained(model_name, use_fast=False)
    return AutoTokenizer.from_pretrained(model_name)

In [None]:
########################### Logic ###########################

def finetune(train_df: pd.DataFrame, cur_seed: int, model_name: str = MODEL_NAME):
    """Fine-tune a token-classification model on ``train_df`` and evaluate on the odd-one-out test split.

    Evaluation runs every epoch on ``ODD_ONE_DATASET["test"]``. ``compute_metrics`` records the best
    macro-F1 in the module-level ``BEST_F1_PER_TRAINING``. No checkpoints are saved.

    Args:
        train_df: Training rows; must contain ``tokens`` and ``tags`` columns.
        cur_seed: Seed passed to ``set_seed`` and ``TrainingArguments``.
        model_name: Model id used for the tokenizer, the tokenization path and the model.

    Side effects:
        Rebinds the module-level ``label2id``, ``id2label`` and ``tokenizer``, which
        ``encode_labels``, ``compute_metrics`` and the tokenize helpers read.
    """
    global label2id, id2label, tokenizer

    set_seed(cur_seed)
    # preserve_index=False: the pandas index is not a feature. A non-range index would otherwise
    # become an extra "__index_level_0__" column.
    dataset = DatasetDict({
        "train": Dataset.from_pandas(train_df[["tokens", "tags"]], preserve_index=False),
        "test": Dataset.from_pandas(ODD_ONE_DATASET["test"][["tokens", "tags"]], preserve_index=False),
    })

    # Build the label vocabulary from both splits. A tag seen only in the test split would otherwise
    # make encode_labels raise KeyError. The sorted order keeps ids deterministic.
    unique_tags = sorted({tag for split in ("train", "test") for row in dataset[split]["tags"] for tag in row})
    label2id = {tag: i for i, tag in enumerate(unique_tags)}
    id2label = {i: tag for tag, i in label2id.items()}

    # Honour the model_name argument consistently for tokenizer, tokenization path and model.
    tokenizer = get_tokenizer_by_model_name(OOD_ONE_OUT_LANG, model_name)
    dataset = dataset.map(encode_labels)
    tokenized_dataset = get_tokenized_dataset(OOD_ONE_OUT_LANG, model_name, dataset)
    model = get_model_by_name(OOD_ONE_OUT_LANG, model_name, label2id, id2label)

    training_args = TrainingArguments(
        output_dir="odd_one_out_out",
        eval_strategy="epoch",
        save_strategy="no",  # no checkpoints are written
        learning_rate=2e-5,
        per_device_train_batch_size=32,
        per_device_eval_batch_size=32,
        num_train_epochs=20,
        weight_decay=0.01,
        warmup_steps=0,
        logging_dir="./odd_one_out_logs",
        logging_strategy="epoch",
        logging_steps=1000,  # only used with logging_strategy="steps"
        seed=cur_seed,
        dataloader_num_workers=4,
        disable_tqdm=False,
        report_to=[],
        gradient_accumulation_steps=4,  # effective per-device batch: 32 * 4
        gradient_checkpointing=False,
        fp16=True,
        optim="adamw_torch",
        adam_beta1=0.9,
        adam_beta2=0.999,
        adam_epsilon=1e-8,
        max_grad_norm=1.0,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset["train"],
        eval_dataset=tokenized_dataset["test"],
        processing_class=tokenizer,
        data_collator=DataCollatorForTokenClassification(tokenizer),
        compute_metrics=compute_metrics,
    )

    trainer.train()

def decimal_to_percentage(decimal: float) -> int:
    """Convert a fraction in [0, 1] to the nearest whole percentage (e.g. 0.30000000000000004 -> 30).

    Raises:
        ValueError: if ``decimal`` is outside [0, 1] (NaN also fails the range check).
    """
    in_range = 0 <= decimal <= 1
    if in_range:
        return int(round(decimal * 100))
    raise ValueError("Input must be between 0 and 1")

def get_run_result_path(seed: int, percentage: int) -> Path:
    """Return the JSON file that stores the best F1 of one (seed, percentage) run.

    Args:
        seed: Training seed of the run; embedded in the file name.
        percentage: Whole-number share (0-100) of ``OOD_ONE_OUT_LANG`` samples in the training set.

    Returns:
        ``RESULTS_DIR / f"seed_{seed}_odd_one_out_{OOD_ONE_OUT_LANG}_{percentage}_results.json"``.
    """
    # Use the `seed` argument, not a global loop variable, so the path is correct for any caller.
    return Path(RESULTS_DIR, f"seed_{seed}_odd_one_out_{OOD_ONE_OUT_LANG}_{percentage}_results.json")


In [None]:
def _split_evenly(total: int, langs: List[str], rng: random.Random) -> Dict[str, int]:
    """Split ``total`` samples across ``langs`` as evenly as possible.

    Every language gets ``total // len(langs)``. The ``total % len(langs)`` leftover samples go one
    each to languages drawn without replacement by ``rng``.
    """
    base_num, remainder = divmod(total, len(langs))
    counts = {lang: base_num for lang in langs}
    for lang in rng.sample(langs, remainder):
        counts[lang] += 1
    return counts


results_summary = {}

# Build the grid of odd-one-out shares from an integer step count. np.arange with a float step can
# gain or lose an endpoint through rounding.
n_steps = int(round(1 / JUMP))
assert np.isclose(n_steps * JUMP, 1.0), f"JUMP={JUMP} must divide 1 evenly"
perc_range = np.linspace(0.0, 1.0, n_steps + 1)

for i_seed, cur_seed in enumerate(SEEDS):
    for i_perc, odd_one_percentage in enumerate(perc_range):
        percentage = decimal_to_percentage(odd_one_percentage)
        summary_key = f"seed_{cur_seed}_odd_lang_{OOD_ONE_OUT_LANG}_percentage_{percentage}"
        print(f"{datetime.now()} Odd one out language {OOD_ONE_OUT_LANG} percentage {odd_one_percentage} ({i_perc+1}/{len(perc_range)}) cur_seed {cur_seed} ({i_seed+1}/{len(SEEDS)})")
        run_result_path = get_run_result_path(cur_seed, percentage)
        if os.path.exists(run_result_path) and os.path.getsize(run_result_path) > 0:
            print(f"{datetime.now()} Non-empty results file exists for this test {run_result_path} Skipping")
            # Carry the earlier result into the summary. Otherwise a resumed run would overwrite the
            # summary file with only the newly computed runs.
            try:
                with open(run_result_path) as json_file:
                    results_summary[summary_key] = json.load(json_file)
            except json.JSONDecodeError:
                print(f"{datetime.now()} Could not parse {run_result_path}; leaving it out of the summary")
            continue

        # Round rather than truncate so float noise in the share can never drop a sample.
        odd_one_samples_num = int(round(odd_one_percentage * TOTAL_RECORDS_NUM))
        assert odd_one_samples_num <= len(ODD_ONE_DATASET["train"]), f'{OOD_ONE_OUT_LANG} has {len(ODD_ONE_DATASET["train"])} samples but {odd_one_samples_num} needed'
        odd_one_out_samples = ODD_ONE_DATASET["train"].sample(n=odd_one_samples_num, random_state=cur_seed)
        other_langs_samples_count = TOTAL_RECORDS_NUM - len(odd_one_out_samples)

        # Use a dedicated per-run RNG. The global `random` state is reseeded by set_seed() inside
        # finetune() and depends on which runs were skipped, so it cannot make this split reproducible.
        run_rng = random.Random(f"{cur_seed}_{percentage}")
        lang_to_samples_num = _split_evenly(other_langs_samples_count, OTHER_LANGS, run_rng)
        # Fail fast with a clear message if a language cannot supply its share.
        # TODO: redistribute the shortfall to languages with spare samples instead of failing.
        for other_lang, needed in lang_to_samples_num.items():
            available = len(OTHER_LANGS_TO_DATAFRAMES_TRAIN_ONLY[other_lang])
            assert needed <= available, f"{other_lang} has {available} samples but {needed} needed"

        print(f"{datetime.now()} Number of samples per lang: {OOD_ONE_OUT_LANG} {len(odd_one_out_samples)} {lang_to_samples_num}")

        train_dfs = [
            OTHER_LANGS_TO_DATAFRAMES_TRAIN_ONLY[other_lang].sample(n=lang_to_samples_num[other_lang], random_state=cur_seed)
            for other_lang in OTHER_LANGS
        ]
        train_dfs.append(odd_one_out_samples)
        # Shuffle so languages are interleaved instead of appearing in concatenation order.
        concatenated_train_df = pd.concat(train_dfs).sample(frac=1, random_state=cur_seed).reset_index(drop=True)

        assert len(concatenated_train_df) == TOTAL_RECORDS_NUM, f"Bug in code! Should've {TOTAL_RECORDS_NUM} samples, but have {len(concatenated_train_df)}"
        BEST_F1_PER_TRAINING = -1  # raised by compute_metrics during training
        finetune(concatenated_train_df, cur_seed)
        with open(run_result_path, 'w') as json_file:
            json.dump(BEST_F1_PER_TRAINING, json_file, indent=4)

        results_summary[summary_key] = BEST_F1_PER_TRAINING

if len(results_summary) > 0:
    # NOTE(review): the file name carries only the last seed although the summary covers all SEEDS.
    # It is kept unchanged so existing consumers of this path keep working.
    with open(Path(RESULTS_DIR, f"odd_one_out_seed_{cur_seed}_{OOD_ONE_OUT_LANG}_results.json"), 'w') as json_file:
        json.dump(results_summary, json_file, indent=4)

In [None]:
# Show the last 20 samples of the most recent training set. The frame only exists if at least one
# run was actually trained in this kernel session (all-skipped sessions never create it).
if "concatenated_train_df" in globals():
    for _, row in concatenated_train_df.tail(20).iterrows():
        print(row["sentence"])
else:
    print("No training run was executed in this session; nothing to show.")

In [None]:
finished_at = datetime.now()
print(f"{finished_at} FIN")