# Overview

This notebook compares DeepL and Aya translations over a selected subset of SNOMED concepts.

1. We run DeepL over the entire subset.
2. We evaluate Aya with a simple prompt over the entire subset.
3. We evaluate Aya with a richer prompt, constructed using RAG techniques over the terminology.
4. We perform a minimal fine-tune of Aya and evaluate whether translation quality has improved.
5. We export a grid of results for analysis in Excel.

In [None]:
# Set to True if we want to make translation evaluation case-insensitive
ignore_case = True

In [None]:
# set to True if running locally, False if running on AWS SageMaker
local_run = True

In [None]:
# Set to True if we want to re-use an existing fine-tuning dataset
use_existing_sft_dataset = True

# 1. Setup

Install required dependencies

In [None]:
if not local_run:
    !pip install transformers
    !pip install deepl
    !pip install tqdm
    !pip install evaluate
    !pip install termcolor
    !pip install Levenshtein
    !pip install nltk
    !pip install cer
    !pip install accelerate
    !pip install wandb
    !pip install scikit-learn

In [None]:
if not local_run:
    # Redirect cache so we don't fill the disk
    import os
    os.environ['HF_HOME'] = '/home/ec2-user/SageMaker/cache/'

In [None]:
import pandas as pd
import torch
from transformers import (
    AutoModelForSeq2SeqLM, AutoTokenizer, BitsAndBytesConfig, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, 
    Seq2SeqTrainer, TrainingArguments, Trainer
)
from datasets import Dataset, load_dataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from snomed_graph import *
import getpass
import deepl
from tqdm.notebook import tqdm
import json
import numpy as np
import evaluate
from termcolor import colored
from collections import namedtuple
import wandb
from copy import deepcopy
from operator import __or__
from functools import reduce
from ast import literal_eval
from Levenshtein import ratio
from itertools import chain
from functools import partial
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder
from sklearn.ensemble import RandomForestRegressor
from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer
from sklearn.metrics import mean_squared_error

In [None]:
# Enter the DeepL API key here
DEEPL_AUTH_KEY = getpass.getpass()

In [None]:
# Filepaths
PATH_TO_SERIALIZED_SNOMED_GRAPH = "./data/snomed_graph/full_concept_graph.gml"
PATH_TO_TRANSLATION_SAMPLES = "./data/prepared_translation_data/samples.csv"
PATH_TO_ALL_TRANSLATION_REFERENCES = "./data/prepared_translation_data/all_translations.csv"
PATH_TO_DEEPL_TRANSLATION_RESULTS = "./data/cache/deepl_results.json"
PATH_TO_AYA_VANILLA_TRANSLATION_RESULTS = "./data/cache/aya_results_vanilla.json"
PATH_TO_AYA_ENRICHED_TRANSLATION_RESULTS = "./data/cache/aya_results_enriched.json"
TRANSLATIONS_OUTPUT_PATH = "./data/translation_outputs/translations.csv"
GRID_OUTPUT_PATH = "./data/translation_outputs/grid.csv"
PATH_TO_SFT_DATASET = "./data/sft_dataset"
ADAPTOR_PATH = "./models/adaptors/aya_finetuned"

# The Aya checkpoint on HuggingFace
AYA_CHECKPOINT = "CohereForAI/aya-101"

# Reproducible experiments
RANDOM_SEED = 42

# The number of examples to use to evaluate the fine-tuning results.
SFT_EVAL_EXAMPLES = 500

# The fraction (0-1) of selected fine-tuning examples to reserve for testing.
SFT_TEST_PCT = 0.9

# The languages to fine-tune over
SFT_LANGS = ["Dutch"]

This is the reference data we'll use:

In [None]:
langcodes = {
    "Dutch": "NL",
    "Estonian": "ET",
    "Korean": "KO",
    "Swedish": "SV",
}

In [None]:
hierarchies_in_use = [
    "substance",
    "body structure",
    "finding",
    "disorder",
    "procedure",
    "morphologic abnormality"
]

Note that some "low value" attributes have been removed because the relationships are unlikely to yield any value to a translator.

In [None]:
important_attributes = {
    # 'Access (attribute)',
    # 'After (attribute)',
    'Associated finding (attribute)',
    'Associated morphology (attribute)',
    'Associated procedure (attribute)',
    'Associated with (attribute)',
    'Before (attribute)',
    'Causative agent (attribute)',
    'Characterizes (attribute)',
    # 'Clinical course (attribute)',
    'Component (attribute)',
    'Direct device (attribute)',
    'Direct morphology (attribute)',
    'Direct site (attribute)',
    'Direct substance (attribute)',
    'Due to (attribute)',
    'During (attribute)',
    # 'Finding context (attribute)',
    'Finding informer (attribute)',
    'Finding method (attribute)',
    'Finding site (attribute)',
    'Has absorbability (attribute)',
    'Has active ingredient (attribute)',
    'Has basic dose form (attribute)',
    'Has basis of strength substance (attribute)',
    'Has coating material (attribute)',
    'Has compositional material (attribute)',
    'Has concentration strength denominator unit (attribute)',
    'Has concentration strength numerator unit (attribute)',
    'Has device intended site (attribute)',
    'Has disposition (attribute)',
    'Has dose form administration method (attribute)',
    'Has dose form intended site (attribute)',
    'Has dose form release characteristic (attribute)',
    'Has dose form transformation (attribute)',
    'Has filling (attribute)',
    'Has focus (attribute)',
    'Has ingredient qualitative strength (attribute)',
    'Has intent (attribute)',
    # 'Has interpretation (attribute)',
    'Has manufactured dose form (attribute)',
    'Has precise active ingredient (attribute)',
    'Has presentation strength denominator unit (attribute)',
    'Has presentation strength numerator unit (attribute)',
    'Has realization (attribute)',
    'Has specimen (attribute)',
    'Has state of matter (attribute)',
    'Has surface texture (attribute)',
    'Has target population (attribute)',
    'Has unit of presentation (attribute)',
    'Indirect device (attribute)',
    'Indirect morphology (attribute)',
    'Inherent location (attribute)',
    'Inheres in (attribute)',
    'Interprets (attribute)',
    # 'Is a (attribute)',
    'Is modification of (attribute)',
    'Is sterile (attribute)',
    'Laterality (attribute)',
    'Measurement method (attribute)',
    'Method (attribute)',
    'Occurrence (attribute)',
    'Pathological process (attribute)',
    'Plays role (attribute)',
    'Precondition (attribute)',
    'Priority (attribute)',
    'Procedure context (attribute)',
    'Procedure device (attribute)',
    'Procedure morphology (attribute)',
    'Procedure site (attribute)',
    'Procedure site - Direct (attribute)',
    'Procedure site - Indirect (attribute)',
    'Process acts on (attribute)',
    'Process duration (attribute)',
    'Process extends to (attribute)',
    'Process output (attribute)',
    'Property (attribute)',
    'Recipient category (attribute)',
    'Relative to (attribute)',
    'Relative to part of (attribute)',
    'Revision status (attribute)',
    'Route of administration (attribute)',
    # 'Scale type (attribute)',
    # 'Severity (attribute)',
    'Specimen procedure (attribute)',
    'Specimen source identity (attribute)',
    'Specimen source morphology (attribute)',
    'Specimen source topography (attribute)',
    'Specimen substance (attribute)',
    # 'Subject relationship context (attribute)',
    'Surgical approach (attribute)',
    'Technique (attribute)',
    # 'Temporal context (attribute)',
    # 'Temporally related to (attribute)',
    # 'Time aspect (attribute)',
    # 'Units (attribute)',
    'Using access device (attribute)',
    'Using device (attribute)',
    'Using energy (attribute)',
    'Using substance (attribute)'
}

# 2. Load the data

## 2.1 Load the concepts to translate

In [None]:
# Load the sampled concepts to translate, indexed by (sctid, language).
# Columns read later in this notebook include: fsn, hierarchy, depth_tier,
# context_tier, similarity_tier, concept_length_bucket and
# reference_translations.
all_df = (
    pd.read_csv(PATH_TO_TRANSLATION_SAMPLES)
    .set_index(["sctid", "language"])
)

# reference_translations is read from the CSV as text; literal_eval parses it
# back into a Python object (it is indexed with [0] later in the notebook).
all_df.reference_translations = all_df.reference_translations.apply(literal_eval)

# Number of (concept, language) samples loaded.
all_df.shape[0]

## 2.2 Load the full set of reference translations

In [None]:
# Load the full set of reference translations, indexed by (sctid, language).
# Columns read later in this notebook include: fsn, hierarchy, depth_tier,
# context_tier, similarity_tier, concept_length_bucket, has_translation and
# reference_translations.
ref_df = (
    pd.read_csv(PATH_TO_ALL_TRANSLATION_REFERENCES)
    .set_index(["sctid", "language"])
)

# Parse the serialized translations; missing entries become pd.NA.
# pd.isna() is used instead of an identity test against np.nan, because a NaN
# read from disk is not guaranteed to be the np.nan object itself.
ref_df.reference_translations = ref_df.reference_translations.apply(
    lambda x: pd.NA if pd.isna(x) else literal_eval(x)
)

# Number of (concept, language) reference rows loaded.
ref_df.shape[0]

## 2.3 Load the SNOMED graph object

In [None]:
G = SnomedGraph.from_serialized(PATH_TO_SERIALIZED_SNOMED_GRAPH)

# 3. Evaluation Functions

In [None]:
# Google BLEU
# Max of precision, recall of all ngrams (of 1-4 tokens)
# Higher is better
# https://huggingface.co/spaces/evaluate-metric/google_bleu
google_bleu = evaluate.load("google_bleu")

# CharacTER
# Roughly: min # char edits required to match pred to ref, normalized by pred len
# Lower is better
# https://huggingface.co/spaces/evaluate-metric/character
character = evaluate.load("character")

In [None]:
# Matches evaluate library outputs
def exact_match(predictions, references):
    """Return the share of predictions that equal one of their references.

    Args:
        predictions: sequence of predicted strings.
        references: sequence aligned with ``predictions``. Each item is either
            a collection of acceptable reference strings or a single string.

    Returns:
        ``{'exact_match': float}``; 0.0 when ``predictions`` is empty.
    """
    total = len(predictions)
    if total == 0:
        # Nothing to score; avoid dividing by zero.
        return {'exact_match': 0.0}
    hits = 0
    for prediction, refs in zip(predictions, references):
        # A bare string is one reference. Without this, `in` would perform a
        # substring test and count partial matches as exact.
        if isinstance(refs, str):
            refs = [refs]
        if prediction in refs:
            hits += 1
    return {'exact_match': float(hits) / total}

In [None]:
# Levenshtein Ratio
# 1 - [Levenshtein Dist] / [Sum of lengths]
# Higher is better
# https://rapidfuzz.github.io/Levenshtein/levenshtein.html#ratio
def levenshtein_ratio(predictions, references):
    """Return the mean best Levenshtein ratio between predictions and references.

    Args:
        predictions: sequence of predicted strings.
        references: sequence aligned with ``predictions``. Each item is either
            a collection of acceptable reference strings or a single string.

    Returns:
        ``{'levenshtein_ratio': float}``. A prediction with no references
        scores 0.0, and an empty ``predictions`` yields 0.0.
    """
    ratios = []
    for prediction, refs in zip(predictions, references):
        # A bare string is one reference. Iterating it directly would compare
        # the prediction against individual characters.
        if isinstance(refs, str):
            refs = [refs]
        # default=0.0 covers an empty reference collection.
        ratios.append(max((ratio(prediction, r) for r in refs), default=0.0))
    if not ratios:
        return {'levenshtein_ratio': 0.0}
    return {'levenshtein_ratio': np.mean(ratios)}

In [None]:
def evaluate_translations(row_or_df, target_column, ignore_case):
    """Score candidate translations against their reference translations.

    Args:
        row_or_df: a whole DataFrame (corpus-level scores) or a single row
            (row-level scores). Must expose ``reference_translations`` and
            ``target_column``.
        target_column: name of the column holding the candidate translations.
        ignore_case: when true, lower-case candidates and references first.

    Returns:
        One dict merging the outputs of ``exact_match``,
        ``levenshtein_ratio``, ``google_bleu`` and ``character``.
    """
    if isinstance(row_or_df, pd.DataFrame):
        assert target_column in row_or_df.columns
        # Read the column directly. DataFrame.to_dict() keys values by index
        # label, so duplicated labels would collapse and leave the candidates
        # misaligned with the references.
        candidates = row_or_df[target_column].tolist()
        references = row_or_df.reference_translations.tolist()
    else:
        candidates = [getattr(row_or_df, target_column)]
        references = [row_or_df.reference_translations]
    if ignore_case:
        candidates = [c.lower() for c in candidates]
        references = [[r.lower() for r in refs] for refs in references]
    merged = {}
    for metric_result in (
        exact_match(predictions=candidates, references=references),
        levenshtein_ratio(predictions=candidates, references=references),
        google_bleu.compute(predictions=candidates, references=references),
        character.compute(predictions=candidates, references=references),
    ):
        # Later metrics win on key clashes, as with the original dict union.
        merged.update(metric_result)
    return merged

# 4. Generate baseline translations with DeepL

To save costs, we cache translations.  This means that whenever we re-run the code, we only translate new samples.

In [None]:
translator = deepl.Translator(DEEPL_AUTH_KEY)

def translate_with_deepl(df, G):
    """Yield one DeepL translation per row of ``df``, re-using cached results.

    Translations are cached in the JSON file at
    ``PATH_TO_DEEPL_TRANSLATION_RESULTS`` under the key ``"<sctid>_<language>"``.
    Only rows missing from the cache are sent to DeepL.

    Args:
        df: DataFrame indexed by (sctid, language).
        G: graph object exposing ``get_concept_details(sctid)``.

    Yields:
        The translated text for each row, in row order.

    Note:
        This is a generator. The final cache write only happens once it has
        been consumed to the end.
    """
    try:
        with open(PATH_TO_DEEPL_TRANSLATION_RESULTS, "r") as f:
            deepl_results = json.load(f)
    except FileNotFoundError:
        # First run: no cache has been written yet.
        deepl_results = dict()

    dirty = False

    def flush_cache():
        with open(PATH_TO_DEEPL_TRANSLATION_RESULTS, "w") as f:
            json.dump(deepl_results, f)

    for it, row in enumerate(tqdm(df.itertuples(), total=df.shape[0])):
        sctid, language = row.Index
        key = str(sctid) + "_" + language
        if key not in deepl_results:
            # Only look up the concept and call the API on a cache miss.
            source_concept = G.get_concept_details(sctid)
            source_preferred_term = source_concept.fsn.replace(f"({source_concept.hierarchy})", "").strip()
            deepl_result = translator.translate_text(source_preferred_term, target_lang=langcodes[language])
            deepl_results[key] = deepl_result.text
            dirty = True
        yield deepl_results[key]
        # Periodic checkpoint, skipped when nothing new has been translated.
        if it % 100 == 0 and dirty:
            flush_cache()
            dirty = False

    if dirty:
        flush_cache()

In [None]:
all_df["deepl_translation"] = list(translate_with_deepl(all_df, G))

In [None]:
evaluate_translations(all_df, "deepl_translation", ignore_case=ignore_case)

# 5. Load Aya

In [None]:
tokenizer = AutoTokenizer.from_pretrained(AYA_CHECKPOINT)

N.B. T5-derivative models (like Aya) do not support FlashAttention

In [None]:
if local_run:    
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16
    )
    
    aya_model = AutoModelForSeq2SeqLM.from_pretrained(
        AYA_CHECKPOINT, 
        device_map="cuda", 
        quantization_config=bnb_config,
    )
else:
    aya_model = AutoModelForSeq2SeqLM.from_pretrained(AYA_CHECKPOINT, device_map="auto")

In [None]:
# Prevents generation from being prematurely truncated.
# (Otherwise we'll end up using the model default, which could be too small.)
# Currently commented out because it leads to decoding errors during generation.
# Also, most (if not all) responses are 20 tokens or fewer.

# generation_config = deepcopy(aya_model.generation_config)
# generation_config.update(max_new_tokens = 64)
# generation_config.validate()

## 5.1 Pre and post-processing wrapper functions

In [None]:
def aya_postprocessor(result):
    """Clean a decoded Aya generation.

    Removes the tokenizer's EOS and PAD token strings and every "." character
    (not only a trailing one), then strips surrounding whitespace.
    """
    cleaned = result
    for unwanted in (tokenizer.eos_token, tokenizer.pad_token, "."):
        cleaned = cleaned.replace(unwanted, "")
    return cleaned.strip()

In [None]:
def translate_with_aya(df, prompt_col, results_filepath=None, rebuild=False, save=False):
    """Yield one Aya translation per row of ``df``, optionally cached on disk.

    Args:
        df: DataFrame indexed by (sctid, language).
        prompt_col: name of the column holding the prompt for each row.
        results_filepath: JSON cache file keyed by ``"<sctid>_<language>"``.
        rebuild: when true, ignore any existing cache and start empty.
        save: when true, write the cache to ``results_filepath`` every 100
            rows and once more after the last row.

    Yields:
        The post-processed translation for each row, in row order.

    Raises:
        ValueError: if ``save`` is true but no ``results_filepath`` is given.
    """
    if save and results_filepath is None:
        raise ValueError("results_filepath is required when save=True")

    results = dict()
    if not rebuild and results_filepath is not None:
        try:
            with open(results_filepath, "r") as f:
                results = json.load(f)
        except FileNotFoundError:
            # No cache written yet; start from scratch.
            results = dict()

    def save_results():
        with open(results_filepath, "w") as f:
            json.dump(results, f)

    for i, row in tqdm(enumerate(df.itertuples()), total=df.shape[0]):
        sctid, language = row.Index
        key = str(sctid) + "_" + language
        if key not in results:
            prompt = getattr(row, prompt_col)
            # Send the inputs to wherever the model lives rather than
            # assuming "cuda".
            input_ids = tokenizer.encode(prompt, return_tensors="pt").to(aya_model.device)
            # Pass input_ids by keyword so the call also works once the model
            # has been wrapped by PEFT.
            output = aya_model.generate(input_ids=input_ids, max_new_tokens=256)
            results[key] = aya_postprocessor(tokenizer.decode(output[0]))
        yield results[key]

        if save and i % 100 == 0:
            save_results()

    # Persist the rows processed since the last periodic checkpoint.
    if save:
        save_results()

# 6. Evaluate Aya Translations

We use our final prompt template.

## 6.1 Prepare the search index for similar concept retrieval

In [None]:
def generate_similarity_search_keys(df):
    """Build a per-language n-gram search index over the concept terms in ``df``.

    For each language found in the second index level, the FSN of every row
    (with its "(hierarchy)" suffix removed) is vectorized into binary word
    n-grams of length 2 to 10.

    Returns:
        dict mapping language -> (fitted vectorizer, term/n-gram matrix,
        tuple of sctids aligned with the matrix rows).
    """
    language_level = df.index.get_level_values(1)
    keys = {}
    for language in language_level.unique():
        sctids = []
        terms = []
        for row in df[language_level == language].itertuples():
            sctids.append(row.Index[0])
            terms.append(row.fsn.replace(f"({row.hierarchy})", "").strip())
        vectorizer = CountVectorizer(lowercase=True, stop_words=None, ngram_range=(2,10), binary=True)
        key_matrix = vectorizer.fit_transform(tuple(terms))
        keys[language] = (vectorizer, key_matrix, tuple(sctids))
    return keys

In [None]:
keys = generate_similarity_search_keys(ref_df)

In [None]:
def find_similar(keys, row, G, ref_df, k=5, min_score=2, remove_descendants=True, remove_parents=True):
    """Retrieve translated concepts whose terms share n-grams with ``row``'s term.

    Args:
        keys: per-language index from ``generate_similarity_search_keys``,
            mapping language -> (vectorizer, key matrix, sctids).
        row: a row tuple with ``Index`` = (sctid, language), ``fsn`` and
            ``hierarchy``.
        G: graph object exposing ``get_descendants``, ``get_parents`` and
            ``get_concept_details``.
        ref_df: reference DataFrame indexed by (sctid, language) with
            ``has_translation`` and ``reference_translations`` columns.
        k: the ``k + 1`` best-scoring candidates are considered (presumably
            one extra so the query concept itself can be dropped).
        min_score: minimum number of shared n-grams for a candidate to count.
        remove_descendants: drop descendants of the query concept.
        remove_parents: drop parents of the query concept.

    Returns:
        List of (English term, first reference translation) pairs; empty when
        nothing qualifies.
    """
    sctid, language = row.Index[0], row.Index[1]
    vectorizer, key_matrix, values = keys[language]
    term = row.fsn.replace(f"({row.hierarchy})", "").strip()
    query = vectorizer.transform([term])
    # toarray() is the supported way to densify; the `.A` shorthand is not
    # available on every SciPy sparse type/version.
    search = key_matrix.dot(query.T).toarray().ravel()
    top_k = np.argsort(-search)[0:k+1]
    top_k = top_k[search[top_k] >= min_score]
    # Index the id tuple directly instead of converting it to an array on
    # every call.
    results = {values[i] for i in top_k}
    # Explicit tuple key: unambiguous (row-label) lookup on the MultiIndex.
    results = {r for r in results if ref_df.loc[(r, language)].has_translation}
    excluded = {sctid}
    if remove_descendants:
        excluded |= {c.sctid for c in G.get_descendants(sctid)}
    if remove_parents:
        excluded |= {c.sctid for c in G.get_parents(sctid)}
    results -= excluded
    if not results:
        return list()
    concepts = [G.get_concept_details(r) for r in results]
    preferred_terms = [c.fsn.replace(f"({c.hierarchy})", "").strip() for c in concepts]
    reference_translations = [ref_df.loc[(r, language)].reference_translations[0] for r in results]
    return list(zip(preferred_terms, reference_translations))

In [None]:
# Match function signature for the other prompt compiling functions
find_similar_ = partial(find_similar, keys=keys)

## 6.2 Minimal prompt

Provide Aya with nothing further than a straightforward translation request.

In [None]:
def prepare_minimal_aya_prompt(row, G):
    """Build a bare translation request for the concept in ``row``.

    The concept's FSN is looked up in ``G`` and its "(hierarchy)" suffix is
    removed before it is placed in the prompt.

    Args:
        row: a row tuple whose ``Index`` is (sctid, language).
        G: graph object exposing ``get_full_concept(sctid)``.

    Returns:
        The prompt string, ending with a trailing space.
    """
    sctid, language = row.Index
    full_concept = G.get_full_concept(sctid)
    semantic_tag = f"({full_concept.hierarchy})"
    term = full_concept.fsn.replace(semantic_tag, "").strip()
    return f'Translate the following clinical concept into {language}: "{term}". '

In [None]:
# You can test the prompt-compilation here
print(prepare_minimal_aya_prompt(next(all_df.sample(1).itertuples()), G))

In [None]:
all_df["minimal_aya_prompt"] = [prepare_minimal_aya_prompt(row, G) for row in tqdm(all_df.itertuples(), total=all_df.shape[0])]

## 6.3 Combined Prompt-compilation

This prompt uses RAG to improve the translation.  If no suitable exemplars are retrieved, then a "default exemplar" is used to steer behaviour.

In [None]:
def generate_default_exemplar(language):
    """Return the fallback one-shot exemplar for ``language``.

    Exemplars exist for Swedish, Estonian, Korean and Dutch; any other
    language yields ``None``.
    """
    default_exemplars = {
        "Swedish": 'Translate the following clinical concept into Swedish: "Pain disorder with psychological factor". smärtsyndrom med psykologisk faktor.',
        "Estonian": 'Translate the following clinical concept into Estonian: "Osseous choristoma". Luuline koristoom.',
        "Korean": 'Translate the following clinical concept into Korean: "Endoscopic excision of lesion of esophagus". 식도 병변 내시경 절제.',
        "Dutch": 'Translate the following clinical concept into Dutch: "Open repair of lumbar hernia using biological mesh".  open hernioplastiek van hernia lumbalis met biologisch matje.',
    }
    return default_exemplars.get(language)

In [None]:
def prepare_rag_aya_prompt(row, G, ref_df):
    """Build a few-shot translation prompt enriched with retrieved exemplars.

    Exemplars are (English term, reference translation) pairs taken, in this
    order, from the concept's parents, from the targets of its inferred
    relationships whose type is in ``important_attributes``, and from
    ``find_similar_``. When none are found, the language's default exemplar
    is used instead.

    Args:
        row: a row tuple whose ``Index`` is (sctid, language).
        G: graph object exposing ``get_full_concept(sctid)``.
        ref_df: reference DataFrame indexed by (sctid, language) with a
            ``reference_translations`` column.

    Returns:
        The exemplar lines followed by the translation request for the
        concept itself, joined with newlines.
    """
    sctid, language = row.Index
    concept = G.get_full_concept(sctid)
    preferred_term = concept.fsn.replace(f"({concept.hierarchy})", "").strip()

    def translated_pairs(concepts):
        # (English term, first reference translation) for each concept that
        # has a usable reference translation in this language.
        pairs = []
        for c in concepts:
            key = (c.sctid, language)
            if key not in ref_df.index:
                continue
            translations = ref_df.loc[key].reference_translations
            # pd.NA marks "no translation"; an empty list has nothing to quote.
            if translations is pd.NA or len(translations) == 0:
                continue
            pairs.append((c.fsn.replace(f"({c.hierarchy})", "").strip(), translations[0]))
        return pairs

    parent_data = translated_pairs(
        G.get_full_concept(p.sctid) for p in concept.parents
    )
    relationship_data = translated_pairs(
        G.get_full_concept(r.tgt.sctid)
        for g in concept.inferred_relationship_groups
        for r in g.relationships
        if r.type in important_attributes
    )
    similarity_data = find_similar_(row=row, G=G, ref_df=ref_df)
    exemplars = [
        f'Translate the following clinical concept into {language}: "{pt}". {rt}.'
        for pt, rt in chain(parent_data, relationship_data, similarity_data)
    ]
    if not exemplars:
        exemplars = [generate_default_exemplar(language)]
    # The same concept can be reached by several routes (e.g. as a parent and
    # as a similar concept). Keep each exemplar once, in first-seen order.
    exemplars = list(dict.fromkeys(exemplars))
    prompt = '\n'.join(exemplars)
    prompt += f'\nTranslate the following clinical concept into {language}: "{preferred_term}". '
    return prompt

In [None]:
# You can test the prompt-compilation here
print(prepare_rag_aya_prompt(next(all_df.sample(1).itertuples()), G, ref_df))

In [None]:
all_df["rag_aya_prompt"] = [prepare_rag_aya_prompt(row, G, ref_df) for row in tqdm(all_df.itertuples(), total=all_df.shape[0])]

In [None]:
# Checkpoint the work
all_df.to_csv(TRANSLATIONS_OUTPUT_PATH)

## 6.4 Run the translations

In [None]:
all_df["rag_aya_translation"] = list(translate_with_aya(
    all_df, "rag_aya_prompt", None, rebuild=True, save=False
))

In [None]:
evaluate_translations(all_df, "rag_aya_translation", ignore_case=ignore_case)

# 7. Score the translations at row-level

In [None]:
# Score every row on its own, then spread the per-row metric dicts into
# "aya_"-prefixed columns on all_df.
aya_row_scores = all_df.apply(
    lambda r: evaluate_translations(r, "rag_aya_translation", ignore_case=ignore_case),
    axis="columns",
)
all_df = all_df.join(aya_row_scores.apply(pd.Series).add_prefix("aya_"))
del aya_row_scores

In [None]:
# Same row-level scoring for the DeepL baseline, stored in "deepl_"-prefixed
# columns.
deepl_row_scores = all_df.apply(
    lambda r: evaluate_translations(r, "deepl_translation", ignore_case=ignore_case),
    axis="columns",
)
all_df = all_df.join(deepl_row_scores.apply(pd.Series).add_prefix("deepl_"))
del deepl_row_scores
all_df.sample(3)

In [None]:
# Checkpoint the work
all_df.to_csv(TRANSLATIONS_OUTPUT_PATH)

# 8. Supervised Fine-tuning

T5 uses a relative attention mechanism so we don't need to truncate our sequences.  That said, $gpumem \propto seqlen^2$ so we need to adjust our batch sizes and/or truncate accordingly.

In [None]:
# Build the fine-tuning dataset from the rows of all_df whose language is in
# SFT_LANGS.
# NOTE(review): this branch runs when use_existing_sft_dataset is True, yet it
# builds the dataset from all_df rather than loading a saved one, and nothing
# in this cell defines ft_data when the flag is False - confirm whether the
# condition is inverted.
if use_existing_sft_dataset:
    ft_data = (
        Dataset
        .from_pandas(
            all_df
            [all_df.index.get_level_values(1).isin(SFT_LANGS)]
            .reset_index()
        )
    )
    
    def ft_preprocess(example):
        """Tokenize one example: the RAG prompt as input, the first reference translation as target."""
        model_inputs = tokenizer(
            example["rag_aya_prompt"], 
            text_target=example["reference_translations"][0],
        )
        return model_inputs
    
    # Adds the tokenizer outputs as new columns on every example.
    ft_data = ft_data.map(ft_preprocess)
    ft_data.shape

In [None]:
# Hold out SFT_TEST_PCT of the examples for testing. The shuffle is seeded so
# that the split (and therefore the evaluation set) is reproducible.
ft_data = ft_data.train_test_split(test_size=SFT_TEST_PCT, seed=RANDOM_SEED)
ft_data

In [None]:
ft_data.save_to_disk(PATH_TO_SFT_DATASET)

In [None]:
wandb.login()

In [None]:
wandb.init(project="snomed_translation_poc", force=True)

In [None]:
if local_run:
    aya_model = prepare_model_for_kbit_training(aya_model)

peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.1,
    r=8,
    bias="none",
)

aya_model = get_peft_model(aya_model, peft_config)

In [None]:
# Fix a bug in Accelerate's device map construction when running on cuda:0
if local_run:
    aya_model.hf_device_map[''] = aya_model.device

In [None]:
# Pads each batch to its longest sequence. Labels are padded with -100 (the
# value compute_metrics already treats as padding) so that padded positions
# are ignored by the loss instead of being learned as real pad tokens. This
# matters as soon as the batch size is raised above 1.
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=aya_model,
    padding="longest",
    label_pad_token_id=-100,
)

In [None]:
def print_trainable_parameters(model):
    """Print how many of ``model``'s parameters are trainable, in absolute and percentage terms."""
    sizes = [
        (parameter.numel(), parameter.requires_grad)
        for _, parameter in model.named_parameters()
    ]
    all_param = sum(count for count, _ in sizes)
    trainable_params = sum(count for count, is_trainable in sizes if is_trainable)
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )

print_trainable_parameters(aya_model)

In [None]:
def compute_metrics(outputs):
    """Compute Levenshtein ratio and exact match for an evaluation pass.

    Args:
        outputs: pair of (predicted token ids, label token ids). The
            predictions may be wrapped in a tuple, in which case the first
            element is used.

    Returns:
        dict with ``levenshtein_ratio`` and ``exact_match``.
    """
    aya_translations, reference_translations = outputs
    if isinstance(aya_translations, tuple):
        aya_translations = aya_translations[0]
    # -100 is a padding marker, not a vocabulary id, so swap it for the pad
    # token before decoding. The labels always needed this; the predictions
    # get the same treatment in case they were padded the same way.
    aya_translations = np.where(aya_translations != -100, aya_translations, tokenizer.pad_token_id)
    reference_translations = np.where(reference_translations != -100, reference_translations, tokenizer.pad_token_id)
    decoded_aya_translations = tokenizer.batch_decode(aya_translations, skip_special_tokens=True)
    decoded_reference_translations = tokenizer.batch_decode(reference_translations, skip_special_tokens=True)
    # The metric helpers expect a collection of references per prediction.
    # A bare string would be scored character by character (Levenshtein) or
    # as a substring test (exact match).
    wrapped_references = [[reference] for reference in decoded_reference_translations]
    levenshtein_result = levenshtein_ratio(predictions=decoded_aya_translations, references=wrapped_references)
    exact_result = exact_match(predictions=decoded_aya_translations, references=wrapped_references)
    return {**levenshtein_result, **exact_result}

In [None]:
# So that we can remove unwanted columns before training
sft_remove_cols = [
    k for k in ft_data["train"].features.keys() 
    if not k in ['input_ids', 'attention_mask', 'labels']
]

N.B. tune the batch size to the available resources and throughput characteristics.

A batch size of 1 works best for fine-tuning with a LoRA adaptor on an RTX 4090.

In [None]:
# Keep only the tensors the trainer needs.
trn_data = ft_data["train"].remove_columns(sft_remove_cols)

# Evaluate on at most SFT_EVAL_EXAMPLES shuffled held-out examples. The cap
# prevents an out-of-range selection when the test split is smaller than that.
eval_pool = (
    ft_data["test"]
    .remove_columns(sft_remove_cols)
    .shuffle(seed=RANDOM_SEED)
)
eval_data = eval_pool.select(range(min(SFT_EVAL_EXAMPLES, len(eval_pool))))

# Training configuration: a single epoch with batch size 1 for both training
# and evaluation, metrics reported to Weights & Biases, checkpoints written
# to ./checkpoints (at most 3 kept).
training_args = Seq2SeqTrainingArguments(
    output_dir="./checkpoints",
    learning_rate=2e-5,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    weight_decay=0.01,
    save_total_limit=3,
    eval_strategy="steps",
    num_train_epochs=1,
    report_to="wandb",
    # compute_metrics decodes token ids, so evaluation is set to generate.
    predict_with_generate=True,
    # Unwanted columns were already dropped explicitly via sft_remove_cols.
    remove_unused_columns=False,
    push_to_hub=False,
    load_best_model_at_end=True,
    # Generation config code currently commented out
    # generation_config=generation_config,
)

# Wire the model, datasets, collator and metric function together.
trainer = Seq2SeqTrainer(
    model=aya_model,
    args=training_args,
    train_dataset=trn_data,
    eval_dataset=eval_data,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

In [None]:
trainer.train()

In [None]:
# Saves the adaptor only
aya_model.save_pretrained(ADAPTOR_PATH)

# 9. Finetuning evaluation

1. Filter to some examples that were not used for fine-tuning.
2. Compare fine-tuned and vanilla Aya (with full prompting) across the subset.
3. Add the SFT results to the main dataset.

In [None]:
# Rows held out from fine-tuning, identified by their (sctid, language) index.
eval_idx = list(zip(ft_data["test"]["sctid"], ft_data["test"]["language"]))
held_out_df = all_df.loc[eval_idx]
# Seeded so the evaluation subset is reproducible, and capped so that sampling
# cannot fail when fewer than SFT_EVAL_EXAMPLES rows were held out.
sft_eval_df = held_out_df.sample(
    min(SFT_EVAL_EXAMPLES, len(held_out_df)),
    random_state=RANDOM_SEED,
)
sft_eval_df.shape

In [None]:
sft_eval_df["sft_rag_aya_translation"] = list(translate_with_aya(
    sft_eval_df, "rag_aya_prompt", None, rebuild=True, save=False
))

In [None]:
evaluate_translations(sft_eval_df, "rag_aya_translation", ignore_case=ignore_case)

In [None]:
evaluate_translations(sft_eval_df, "sft_rag_aya_translation", ignore_case=ignore_case)

In [None]:
# Row-level scores for the fine-tuned model, stored in "sft_aya_"-prefixed
# columns on sft_eval_df.
sft_row_scores = sft_eval_df.apply(
    lambda r: evaluate_translations(r, "sft_rag_aya_translation", ignore_case=ignore_case),
    axis="columns",
)
sft_eval_df = sft_eval_df.join(sft_row_scores.apply(pd.Series).add_prefix("sft_aya_"))
del sft_row_scores
sft_eval_df.sample(3)

In [None]:
# Add the final subset of translations
all_df = (
    sft_eval_df
    [['sft_rag_aya_translation', 'sft_aya_exact_match', 'sft_aya_levenshtein_ratio', 'sft_aya_google_bleu', 'sft_aya_cer_score']]
    .join(all_df, how="right")
)

In [None]:
# Checkpoint the work
all_df.to_csv(TRANSLATIONS_OUTPUT_PATH)

# 10. Final Evaluations

We can reload the translations, if we need to.

In [None]:
# all_df = pd.read_csv(TRANSLATIONS_OUTPUT_PATH).set_index(["sctid", "language"])
# all_df.reference_translations = all_df.reference_translations.apply(literal_eval)

First, aggregate the translation results.

In [None]:
# Average every metric per grid cell (hierarchy x tiers x length bucket x
# language) and count how many translations were tested in each cell.
trans_grid_df = (
    all_df
    .reset_index()
    .groupby(["hierarchy", "depth_tier", "context_tier", "similarity_tier", "concept_length_bucket", "language"])
    .agg(
        num_translations_tested=("fsn", "size"),
        aya_exact_match=("aya_exact_match", "mean"),
        aya_levenshtein_ratio=("aya_levenshtein_ratio", "mean"),
        aya_google_bleu=("aya_google_bleu", "mean"),
        aya_cer_score=("aya_cer_score", "mean"),
        deepl_exact_match=("deepl_exact_match", "mean"),
        deepl_levenshtein_ratio=("deepl_levenshtein_ratio", "mean"),
        deepl_google_bleu=("deepl_google_bleu", "mean"),
        deepl_cer_score=("deepl_cer_score", "mean"),
        # The fine-tuned metrics must read the sft_aya_* columns. Two of them
        # previously re-aggregated the vanilla aya_* columns.
        sft_aya_exact_match=("sft_aya_exact_match", "mean"),
        sft_aya_levenshtein_ratio=("sft_aya_levenshtein_ratio", "mean"),
        sft_aya_google_bleu=("sft_aya_google_bleu", "mean"),
        sft_aya_cer_score=("sft_aya_cer_score", "mean"),
    )
)
trans_grid_df.sample(3)

Then, aggregate the reference dataset

In [None]:
ref_grid_df = (
    ref_df
    [ref_df.hierarchy.isin(hierarchies_in_use)]
    .reset_index()
    .groupby(["hierarchy", "depth_tier", "context_tier", "similarity_tier", "concept_length_bucket", "language"])
    .agg(
        num_concepts_in_terminology=("fsn", "size"),
        num_translated_concepts_in_refset=("has_translation", "sum")
    )
)
ref_grid_df["num_untranslated_concepts"] = ref_grid_df.num_concepts_in_terminology - ref_grid_df.num_translated_concepts_in_refset
ref_grid_df.sample(3)

Join the two, and we're done.

In [None]:
grid_df = ref_grid_df.join(trans_grid_df, how="outer")
grid_df.sample(3)

We will be missing some values for the targets where we had no translations to work with

In [None]:
grid_df.apply(pd.isna).sum()

In [None]:
grid_df[grid_df.num_translated_concepts_in_refset > 0].apply(pd.isna).sum()

We can estimate the missing values with a flexible regressor.  

We won't estimate for the fine-tuning examples since we only performed that for a subset of values anyway.

In [None]:
# Define feature columns and target columns
features = ["hierarchy", "depth_tier", "context_tier", "similarity_tier", "concept_length_bucket", "language"]
targets = ["aya_exact_match", "aya_levenshtein_ratio", "aya_google_bleu", "aya_cer_score", "deepl_exact_match", "deepl_levenshtein_ratio", "deepl_google_bleu", "deepl_cer_score"]

# Preprocess the categorical features
preprocessor = ColumnTransformer(
    transformers=[
        ('cat', OneHotEncoder(), features)
    ])

def train_model(df, target):
    """Fit a random-forest regressor predicting ``target`` from the grid features.

    Args:
        df: DataFrame containing the ``features`` columns and ``target``.
        target: name of the metric column to predict.

    Returns:
        The fitted pipeline (one-hot encoding followed by the regressor). The
        mean squared error on a 20% hold-out is printed as a side effect.
    """
    # Rows without an observed value for this target cannot be used for
    # fitting, and the regressor rejects NaN targets.
    known = df.dropna(subset=[target])
    X = known[features]
    y = known[target]
    model = Pipeline(steps=[
        # Each model gets its own encoder. Pipeline fits its steps in place,
        # so a shared instance would be refit by every later model.
        # handle_unknown="ignore" lets the model score categories that did
        # not appear in its training split.
        ('preprocessor', ColumnTransformer(
            transformers=[
                ('cat', OneHotEncoder(handle_unknown="ignore"), features)
            ])),
        # Seeded like the split below so repeated runs give the same model.
        ('classifier', RandomForestRegressor(n_estimators=200, random_state=42))
    ])
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
    model.fit(X_train, y_train)
    y_pred = model.predict(X_test)
    mse = mean_squared_error(y_test, y_pred)
    print(f'Mean Squared Error for {target}: {mse}')
    return model

# Train models for each target variable
models = {}

for target in targets:
    models[target] = train_model(
        grid_df[grid_df.num_translated_concepts_in_refset > 0].reset_index(), 
        target
    )

In [None]:
def interpolate(row):
    """Fill the target metrics of a grid row that has no translated concepts.

    Rows whose ``num_translated_concepts_in_refset`` is 0 get every column in
    ``targets`` replaced by the prediction of the matching model in
    ``models``. All other rows are returned unchanged.
    """
    if row.num_translated_concepts_in_refset == 0:
        # The feature frame is the same for every target; build it once.
        X = pd.DataFrame([{f: row[f] for f in features}])
        for t in targets:
            # Item assignment always writes the column. setattr on a Series
            # silently creates a plain attribute if the label is missing.
            row[t] = models[t].predict(X)[0]
    return row
    
interpolated_grid_df = grid_df.reset_index().apply(interpolate, axis="columns")

In [None]:
interpolated_grid_df.sample(3)

In [None]:
interpolated_grid_df.to_csv(GRID_OUTPUT_PATH, index=False)