# TODO

- Use the preferred terms rather than the first synonyms

In [1]:
!pip install transformers
!pip install deepl
!pip install tqdm
!pip install evaluate
!pip install termcolor
!pip install Levenshtein
!pip install nltk
!pip install cer
!pip install accelerate
!pip install wandb



In [2]:
# import os
# os.environ['HF_HOME'] = '/home/ec2-user/SageMaker/cache/'

In [1]:
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, BitsAndBytesConfig
from snomed_graph import *
import getpass
import deepl
from tqdm.notebook import tqdm
import json
import numpy as np
import evaluate
from termcolor import colored
from collections import namedtuple
from operator import __or__
from functools import reduce
from ast import literal_eval
from Levenshtein import ratio
from itertools import chain

In [2]:
AYA_CHECKPOINT = "CohereForAI/aya-101"
PATH_TO_SERIALIZED_SNOMED_GRAPH = "./data/snomed_graph/full_concept_graph.gml"
PATH_TO_TRANSLATION_SAMPLES = "./data/prepared_translation_data/samples.csv"
PATH_TO_ALL_TRANSLATION_REFERENCES = "./data/prepared_translation_data/all_translations.csv"
PATH_TO_DEEPL_TRANSLATION_RESULTS = "./data/cache/deepl_results.json"
PATH_TO_AYA_VANILLA_TRANSLATION_RESULTS = "./data/cache/aya_results_vanilla.json"
PATH_TO_AYA_ENRICHED_TRANSLATION_RESULTS = "./data/cache/aya_results_enriched.json"
ALL_OUTPUT_PATH = "./data/translation_outputs/all_translations.csv"
CT1_OUTPUT_PATH = "./data/translation_outputs/ct1_translations.csv"
CT2_OUTPUT_PATH = "./data/translation_outputs/ct2_translations.csv"
SIM_OUTPUT_PATH = "./data/translation_outputs/sim_translations.csv"

In [3]:
ignore_case = True

In [6]:
DEEPL_AUTH_KEY = getpass.getpass()

 ········


In [4]:
langcodes = {
    "Dutch": "NL",
    "Estonian": "ET",
    "Korean": "KO",
    "Swedish": "SV",
}

In [5]:
important_attributes = {
    # 'Access (attribute)',
    # 'After (attribute)',
    'Associated finding (attribute)',
    'Associated morphology (attribute)',
    'Associated procedure (attribute)',
    'Associated with (attribute)',
    'Before (attribute)',
    'Causative agent (attribute)',
    'Characterizes (attribute)',
    # 'Clinical course (attribute)',
    'Component (attribute)',
    'Direct device (attribute)',
    'Direct morphology (attribute)',
    'Direct site (attribute)',
    'Direct substance (attribute)',
    'Due to (attribute)',
    'During (attribute)',
    # 'Finding context (attribute)',
    'Finding informer (attribute)',
    'Finding method (attribute)',
    'Finding site (attribute)',
    'Has absorbability (attribute)',
    'Has active ingredient (attribute)',
    'Has basic dose form (attribute)',
    'Has basis of strength substance (attribute)',
    'Has coating material (attribute)',
    'Has compositional material (attribute)',
    'Has concentration strength denominator unit (attribute)',
    'Has concentration strength numerator unit (attribute)',
    'Has device intended site (attribute)',
    'Has disposition (attribute)',
    'Has dose form administration method (attribute)',
    'Has dose form intended site (attribute)',
    'Has dose form release characteristic (attribute)',
    'Has dose form transformation (attribute)',
    'Has filling (attribute)',
    'Has focus (attribute)',
    'Has ingredient qualitative strength (attribute)',
    'Has intent (attribute)',
    # 'Has interpretation (attribute)',
    'Has manufactured dose form (attribute)',
    'Has precise active ingredient (attribute)',
    'Has presentation strength denominator unit (attribute)',
    'Has presentation strength numerator unit (attribute)',
    'Has realization (attribute)',
    'Has specimen (attribute)',
    'Has state of matter (attribute)',
    'Has surface texture (attribute)',
    'Has target population (attribute)',
    'Has unit of presentation (attribute)',
    'Indirect device (attribute)',
    'Indirect morphology (attribute)',
    'Inherent location (attribute)',
    'Inheres in (attribute)',
    'Interprets (attribute)',
    # 'Is a (attribute)',
    'Is modification of (attribute)',
    'Is sterile (attribute)',
    'Laterality (attribute)',
    'Measurement method (attribute)',
    'Method (attribute)',
    'Occurrence (attribute)',
    'Pathological process (attribute)',
    'Plays role (attribute)',
    'Precondition (attribute)',
    'Priority (attribute)',
    'Procedure context (attribute)',
    'Procedure device (attribute)',
    'Procedure morphology (attribute)',
    'Procedure site (attribute)',
    'Procedure site - Direct (attribute)',
    'Procedure site - Indirect (attribute)',
    'Process acts on (attribute)',
    'Process duration (attribute)',
    'Process extends to (attribute)',
    'Process output (attribute)',
    'Property (attribute)',
    'Recipient category (attribute)',
    'Relative to (attribute)',
    'Relative to part of (attribute)',
    'Revision status (attribute)',
    'Route of administration (attribute)',
    # 'Scale type (attribute)',
    # 'Severity (attribute)',
    'Specimen procedure (attribute)',
    'Specimen source identity (attribute)',
    'Specimen source morphology (attribute)',
    'Specimen source topography (attribute)',
    'Specimen substance (attribute)',
    # 'Subject relationship context (attribute)',
    'Surgical approach (attribute)',
    'Technique (attribute)',
    # 'Temporal context (attribute)',
    # 'Temporally related to (attribute)',
    # 'Time aspect (attribute)',
    # 'Units (attribute)',
    'Using access device (attribute)',
    'Using device (attribute)',
    'Using energy (attribute)',
    'Using substance (attribute)'
}

# 1. Load the data

## 1.1 Load the concepts to translate

In [6]:
# Columns are: sctid, fsn, hierarchy, language, context_tier, depth_tier, translations
all_df = (
    pd.read_csv(PATH_TO_TRANSLATION_SAMPLES)
    .set_index(["sctid", "language"])
)

all_df.reference_translations = all_df.reference_translations.apply(literal_eval)

all_df.shape[0]

12640

## 1.2 Load the full set of reference translations

In [7]:
# Columns are: sctid, fsn, hierarchy, language, context_tier, depth_tier, translations
ref_df = (
    pd.read_csv(PATH_TO_ALL_TRANSLATION_REFERENCES)
    .set_index(["sctid", "language"])
)

ref_df.translations = ref_df.translations.apply(literal_eval)

ref_df = ref_df.rename(axis="columns", mapper={"translations": "reference_translations"})

ref_df.shape[0]

651355

## 1.3 Load the SNOMED graph object

In [8]:
G = SnomedGraph.from_serialized(PATH_TO_SERIALIZED_SNOMED_GRAPH)

SNOMED graph has 361179 vertices and 1179749 edges


# 2. Evaluation Harness

In [9]:
# Google BLEU
# Max of precision, recall of all ngrams (of 1-4 tokens)
# Higher is better
# https://huggingface.co/spaces/evaluate-metric/google_bleu
google_bleu = evaluate.load("google_bleu")

# CharacTER
# Roughly: min # char edits required to match pred to ref, normalized by pred len
# Lower is better
# https://huggingface.co/spaces/evaluate-metric/character
character = evaluate.load("character")

In [10]:
def exact_match(predictions, references):
    N = len(predictions)
    n = 0
    for p, r in zip(predictions, references):
        if p in r:
            n += 1
    return {'exact_match': float(n)/N}

In [11]:
# Levenshtein Ratio
# 1 - [Levenshtein Dist] / [Sum of lengths]
# Higher is better
# https://rapidfuzz.github.io/Levenshtein/levenshtein.html#ratio
def levenshtein_ratio(predictions, references):
    ratios = [
        np.max([ratio(p, r) for r in refs])
        for p, refs in zip(predictions, references)
    ]
    return {'levenshtein_ratio': np.mean(ratios)}

In [12]:
def evaluate_translations(row_or_df, target_column, ignore_case):
    if isinstance(row_or_df, pd.DataFrame):
        assert target_column in row_or_df.columns    
        candidates = list(row_or_df.to_dict()[target_column].values())
        references = row_or_df.reference_translations.tolist()
    else:
        candidates = [getattr(row_or_df, target_column)]
        references = [row_or_df.reference_translations]
    if ignore_case:
        candidates = [c.lower() for c in candidates]
        references = [r.lower() for r in references]
    results = [
        exact_match(predictions=candidates, references=references),
        levenshtein_ratio(predictions=candidates, references=references),
        google_bleu.compute(predictions=candidates, references=references),
        character.compute(predictions=candidates, references=references),
    ]
    results = reduce(__or__, results, dict())
    return results

# 3. Generate baseline translations with DeepL

In [None]:
translator = deepl.Translator(DEEPL_AUTH_KEY)

def translate_with_deepl(df, G):
    
    with open(PATH_TO_DEEPL_TRANSLATION_RESULTS, "r") as f:
        deepl_results = json.load(f)
    
    for it, row in enumerate(tqdm(df.itertuples(), total=df.shape[0])):
        sctid, language = row.Index
        langcode = langcodes[language]
        source_concept = G.get_concept_details(sctid)
        source_preferred_term = source_concept.fsn.replace(f"({source_concept.hierarchy})", "").strip()
        key = str(sctid) + "_" + language
        try:
            yield deepl_results[key]
        except KeyError:
            deepl_result = translator.translate_text(source_preferred_term, target_lang=langcode)
            deepl_results[key] = deepl_result.text
            yield deepl_result.text
        if it % 100 == 0:
            with open(PATH_TO_DEEPL_TRANSLATION_RESULTS, "w") as f:
                json.dump(deepl_results, f)

    with open(PATH_TO_DEEPL_TRANSLATION_RESULTS, "w") as f:
        json.dump(deepl_results, f)

In [None]:
all_df["deepl_translation"] = list(translate_with_deepl(all_df, G))

In [17]:
evaluate_translations(all_df, "deepl_translation")

{'exact_match': 0.08662974683544304,
 'levenshtein_ratio': 0.7370298864393613,
 'google_bleu': 0.2459922409495806,
 'cer_score': 0.381689058931287}

# 4. Generate translations using "vanilla" Aya model.

In [14]:
tokenizer = AutoTokenizer.from_pretrained(AYA_CHECKPOINT)

## 4.1 Load Aya on a local machine

N.B. T5-derivative models (like Aya) do not yet support FlashAttention

In [15]:
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

aya_model = AutoModelForSeq2SeqLM.from_pretrained(
    AYA_CHECKPOINT, 
    device_map="cuda", 
    quantization_config=bnb_config,
)

Loading checkpoint shards:   0%|          | 0/11 [00:00<?, ?it/s]

## 4.2 Load Aya on a multi-GPU set-up

In [16]:
aya_model = AutoModelForSeq2SeqLM.from_pretrained(AYA_CHECKPOINT, device_map="auto")

## 4.3 Standard wrapper functions for processing with Aya

In [16]:
def aya_postprocessor(result):
    return (
        result
        .replace(tokenizer.eos_token, "")
        .replace(tokenizer.pad_token, "")
        .replace(".", "")
        .strip()
    )

In [17]:
def translate_with_aya(df, G, prompt_assembler, ref_df=None, results_filepath=None, rebuild=False, save=False):

    if rebuild:
        results = dict()
    else:
        with open(results_filepath, "r") as f:
            results = json.load(f)
    
    for row in tqdm(df.itertuples(), total=df.shape[0]):
        sctid, language = row.Index
        key = str(sctid) + "_" + language
        try:
            yield results[key]
        except KeyError:
            prompt = prompt_assembler(row, G, ref_df)
            input = tokenizer.encode(prompt, return_tensors="pt").to("cuda")
            output = aya_model.generate(input, max_new_tokens=256)
            result = tokenizer.decode(output[0])
            result = aya_postprocessor(result)
            results[key] = result
            yield result

    if save:
        with open(results_filepath, "w") as f:
            json.dump(results, f)

## 4.4 Test Aya with a few translations into English.

In [18]:
def test_translate_with_aya(df, G):
    for row in tqdm(df.itertuples(), total=df.shape[0]):
        sctid, language = row.Index
        preferred_term = row.reference_translations[0]
        reference_translations = G.get_concept_details(sctid).synonyms
        # ICL
        try:
            icl_row = next(df[(df.index.get_level_values(0) != sctid) & (df.index.get_level_values(1) == language)].sample(1).itertuples())
        except ValueError:
            pass
        else:
            icl_sctid = icl_row.Index[0]
            icl_preferred_term = icl_row.reference_translations[0]
            icl_reference_translations = G.get_concept_details(icl_sctid).synonyms
            # construct prompt
            prompt_template = 'Translate the following clinical concept into English: "{{PREFERRED_TERM}}". {{TRANSLATED_TERM}}.\n'
            prompt = (
                prompt_template.replace("{{PREFERRED_TERM}}", icl_preferred_term).replace("{{TRANSLATED_TERM}}", icl_reference_translations[0]) +
                prompt_template.replace("{{PREFERRED_TERM}}", preferred_term).replace("{{TRANSLATED_TERM}}.\n", "")
            )
            print(prompt)
            input = tokenizer.encode(prompt, return_tensors="pt").to("cuda")
            output = aya_model.generate(input, max_new_tokens=256)
            result = tokenizer.decode(output[0])
            result = aya_postprocessor(result)
            
            print(
                colored("\nSCTID: ", "red", attrs=['bold']),
                sctid,
                colored("\nSource Language: ", "red", attrs=['bold']),
                language,
                colored("\nPreferred Term: ", "red", attrs=['bold']),
                preferred_term,
                colored("\nReference Translations: ", "red", attrs=['bold']),
                reference_translations,
                colored("\nAya Translation: ", "red", attrs=['bold']),            
                result,
                "\n\n",
                # colored("\nAya Scores: ", "red", attrs=['bold']),
                # ", ".join([k+": "+str(v) for k,v in scores.items()]),
            )

In [19]:
test_translate_with_aya(all_df.sample(10), G)

  0%|          | 0/10 [00:00<?, ?it/s]

Translate the following clinical concept into English: "Kartsinomatoos". Carcinomatosis.
Translate the following clinical concept into English: "Briljantroheline". 
[1m[31m
SCTID: [0m 396057002 [1m[31m
Source Language: [0m Estonian [1m[31m
Preferred Term: [0m Briljantroheline [1m[31m
Reference Translations: [0m ['Brilliant green'] [1m[31m
Aya Translation: [0m Brillant green 


Translate the following clinical concept into English: "beriberi bij zuigeling". Infantile beriberi.
Translate the following clinical concept into English: "structuur van iliacale arterie en/of femorale arterie". 
[1m[31m
SCTID: [0m 299716001 [1m[31m
Source Language: [0m Dutch [1m[31m
Preferred Term: [0m structuur van iliacale arterie en/of femorale arterie [1m[31m
Reference Translations: [0m ['Iliac and femoral artery structures', 'Iliac and/or femoral artery structures'] [1m[31m
Aya Translation: [0m Iliac and/or femoral arteries 


Translate the following clinical concept into Engl

## 4.5 Translate from English into our target languages

In [22]:
def prepare_aya_vanilla_prompt(row, G, df):
    sctid, language = row.Index
    concept = G.get_concept_details(sctid)
    preferred_term = concept.fsn.replace(f"({concept.hierarchy})", "").strip()
    if language == "Swedish":
        return f'Translate the following clinical concept into Swedish: "Pain disorder with psychological factor". smärtsyndrom med psykologisk faktor.\nTranslate the following clinical concept into Swedish: "{preferred_term}". '
    elif language == "Estonian":
        return f'Translate the following clinical concept into Estonian: "Osseous choristoma". Luuline koristoom. \nTranslate the following clinical concept into Estonian: "{preferred_term}". '
    elif language == "Korean":
        return f'Translate the following clinical concept into Korean: "Endoscopic excision of lesion of esophagus". 식도 병변 내시경 절제. \nTranslate the following clinical concept into Korean: "{preferred_term}". '
    elif language == "Dutch":
        return f'Translate the following clinical concept into Dutch: "Open repair of lumbar hernia using biological mesh".  open hernioplastiek van hernia lumbalis met biologisch matje.\nTranslate the following clinical concept into Dutch: "{preferred_term}". '
    else:
        raise ValueError()

In [24]:
all_df["aya_vanilla_translation"] = list(translate_with_aya(
    all_df, G, prepare_aya_vanilla_prompt, PATH_TO_AYA_VANILLA_TRANSLATION_RESULTS, rebuild=True, save=False
))

  0%|          | 0/12640 [00:00<?, ?it/s]

In [25]:
evaluate_translations(all_df, "aya_vanilla_translation")

{'exact_match': 0.03235759493670886,
 'levenshtein_ratio': 0.608165798751417,
 'google_bleu': 0.1350622406639004,
 'cer_score': 0.5526435448522717}

In [33]:
all_df.to_csv(ALL_OUTPUT_PATH)

# 5. Evaluate Aya with enriched prompt

In [None]:
enriched_prompt_template = """
You are a medical translation expert.
Your job is to translate formal clinical terms found within the SNOMED Concept Terminology into {{TARGET_LANGUAGE}}.
The concept you need to translate is “{{PREFERRED_TERM}}”.
Here is some information about the concept which may help you:
{{SYNONYMS_FRAGMENT}}
{{HIERARCHY_FRAGMENT}}
{{PARENTS_FRAGMENT}}
{{RELATIONSHIPS_FRAGMENT}}
Now, the translation of “{{PREFERRED_TERM}}” into {{TARGET_LANGUAGE}} is:"""

In [None]:
def generate_prompt_synonyms_fragment(preferred_term, synonyms):
    if len(synonyms) == 0:
        return ""
    else:
        syn_str = '"' + '" and "'.join(synonyms) + '"'
        return f'In English, synonyms for "{preferred_term}" include: {syn_str}.'

In [None]:
def generate_prompt_hierarchy_fragment(preferred_term, hierarchy):
    return f'"{preferred_term}" is a {hierarchy}.'

In [None]:
def generate_prompt_parents_fragment(preferred_term, parents):
    fragment = ""
    for p in parents:
        fragment += f'"{preferred_term}" is a kind of {p.synonyms[0]}.\n'
    return fragment

In [None]:
def generate_prompt_relationships_fragment(preferred_term, relationship_groups):
    fragment = ""
    for g in relationship_groups:
        for r in g.relationships:
            type = r.type.replace(" (attribute)", "").lower()
            tgt = r.tgt.synonyms[0]
            fragment += f'"{preferred_term}" has {type} {tgt}\n'
    return fragment

In [None]:
def prepare_aya_enriched_prompt(row, G, df):
    sctid, language = row.Index
    concept = G.get_full_concept(sctid)
    preferred_term = concept.synonyms[0]
    return (
        enriched_prompt_template
        .replace("{{TARGET_LANGUAGE}}", language)
        .replace("{{PREFERRED_TERM}}", preferred_term)
        .replace("{{SYNONYMS_FRAGMENT}}", generate_prompt_synonyms_fragment(preferred_term, concept.synonyms[1:]))
        .replace("{{HIERARCHY_FRAGMENT}}", generate_prompt_hierarchy_fragment(preferred_term, concept.hierarchy))
        .replace("{{PARENTS_FRAGMENT}}", generate_prompt_parents_fragment(preferred_term, concept.parents))
        .replace("{{RELATIONSHIPS_FRAGMENT}}", generate_prompt_relationships_fragment(preferred_term, concept.inferred_relationship_groups))
        .replace("\n\n", "\n")
    )

In [None]:
print(prepare_aya_enriched_prompt(next(all_df.itertuples()), G, df))

In [None]:
df["aya_enriched_translation"] = list(translate_with_aya(all_df, G, prepare_aya_enriched_prompt, PATH_TO_AYA_ENRICHED_TRANSLATION_RESULTS, rebuild=True))

In [None]:
evaluate_translations(all_df, "aya_enriched_translation")

# 6. Translate Context Tier 1 Concepts

In [29]:
def prepare_aya_ct1_prompt(row, G, ref_df):
    sctid, language = row.Index
    concept = G.get_full_concept(sctid)
    preferred_term = concept.fsn.replace(f"({concept.hierarchy})", "").strip()
    parent_concepts = [
        G.get_full_concept(p.sctid) for p in concept.parents
    ]
    parent_data = [
        (
            c.fsn.replace(f"({c.hierarchy})", "").strip(),
            ref_df.loc[(c.sctid, language)].reference_translations[0],
        )
        for c in parent_concepts
    ]
    prompt_fragments = [
        f'Translate the following clinical concept into {language}: "{pt}". {rt}.'
        for pt, rt in parent_data
    ]    
    prompt = '\n'.join(prompt_fragments)
    prompt += f'\nTranslate the following clinical concept into {language}: "{preferred_term}". '
    
    return prompt

In [30]:
ct1_df = all_df[all_df.context_tier.isin(["Tier 1", "Tier 2"])]
ct1_df.shape[0]

8657

In [31]:
ct1_df["aya_ct1_translation"] = list(translate_with_aya(ct1_df, G, prepare_aya_ct1_prompt, ref_df, None, rebuild=True, save=False))

  0%|          | 0/8657 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [29]:
evaluate_translations(ct1_df, "aya_ct1_translation")

{'exact_match': 0.13734550075083748,
 'levenshtein_ratio': 0.728187395794738,
 'google_bleu': 0.319951960733121,
 'cer_score': 0.4275629513922285}

In [45]:
ct1_df.to_csv(CT1_OUTPUT_PATH)

# 7. Translate all Context Tier 2 Concepts

In [32]:
ct2_df = ct1_df[ct1_df.context_tier == "Tier 2"]
ct2_df.shape[0]

3385

In [33]:
def prepare_aya_ct2_prompt(row, G, ref_df):
    sctid, language = row.Index
    concept = G.get_full_concept(sctid)
    preferred_term = concept.fsn.replace(f"({concept.hierarchy})", "").strip()
    parent_concepts = [
        G.get_full_concept(p.sctid) for p in concept.parents
    ]
    parent_data = [
        (
            c.fsn.replace(f"({c.hierarchy})", "").strip(),
            ref_df.loc[(c.sctid, language)].reference_translations[0],
        )
        for c in parent_concepts
    ] 
    related_concepts = [
        G.get_full_concept(r.tgt.sctid)
        for g in concept.inferred_relationship_groups
        for r in g.relationships
        if r.type in important_attributes        
    ]    
    relationship_data = [
        (
            c.fsn.replace(f"({c.hierarchy})", "").strip(),
            ref_df.loc[(c.sctid, language)].reference_translations[0],
        )
        for c in related_concepts
    ]
    prompt_fragments = [
        f'Translate the following clinical concept into {language}: "{pt}". {rt}.'
        for pt, rt in chain(parent_data, relationship_data)
    ]
    prompt = '\n'.join(prompt_fragments)
    prompt += f'\nTranslate the following clinical concept into {language}: "{preferred_term}". '
    
    return prompt

In [35]:
print(prepare_aya_ct2_prompt(next(ct2_df.sample(1).itertuples()), G, ref_df))

Translate the following clinical concept into Swedish: "Finding related to ability to use language". fynd relaterat till förmågan att använda språk.
Translate the following clinical concept into Swedish: "Does use verbal communication". kommunicerar verbalt.
Translate the following clinical concept into Swedish: "Ability to use language". förmåga att använda språket.
Translate the following clinical concept into Swedish: "Does use language". 


In [36]:
ct2_df["aya_ct2_translation"] = list(translate_with_aya(
    ct2_df, G, prepare_aya_ct2_prompt, ref_df, None, rebuild=True, save=False
))

  0%|          | 0/3385 [00:00<?, ?it/s]

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  ct2_df["aya_ct2_translation"] = list(translate_with_aya(ct2_df, G, prepare_aya_ct2_prompt, ref_df, None, rebuild=True, save=False))


In [None]:
evaluate_translations(ct2_df, "aya_ct2_translation", ignore_case)

In [37]:
ct2_df.to_csv(CT2_OUTPUT_PATH)

# 8. Enhancing RAG with ngram lookups

In [22]:
from sklearn.feature_extraction.text import CountVectorizer

def generate_similarity_search_keys(df):
    languages = df.index.get_level_values(1).unique()
    keys = dict()
    for l in languages:
        docs = [
            (row.Index[0], row.fsn.replace(f"({row.hierarchy})", "").strip())
            for row in df[df.index.get_level_values(1) == l].itertuples()
        ]
        values, terms = list(zip(*docs))
        vectorizer = CountVectorizer(lowercase=True, stop_words=None, ngram_range=(2,10), binary=True)
        key_matrix = vectorizer.fit_transform(terms)
        keys[l] = (vectorizer, key_matrix, values)
    return keys

In [23]:
keys = generate_similarity_search_keys(ref_df)

In [24]:
def find_similar(keys, row, G, ref_df, k=3, min_score=2, remove_children=True):
    sctid = row.Index[0]
    language = row.Index[1]
    vectorizer, key_matrix, values = keys[language]
    term = row.fsn.replace(f"({row.hierarchy})", "").strip()
    query = vectorizer.transform([term])
    search = key_matrix.dot(query.T).A.ravel()
    top_k = np.argpartition(-search, k+1)[0:k+1]
    scores = search[top_k]
    top_k = top_k[scores >= min_score]
    results = set(np.array(values)[top_k])
    if remove_children:
        children = {c.sctid for c in G.get_descendants(sctid)}
    else:
        children = set()
    results = results - {sctid} - children
    if results != set():
        concepts = [G.get_concept_details(r) for r in results]
        preferred_terms = [c.fsn.replace(f"({c.hierarchy})", "").strip() for c in concepts]
        reference_translations = [ref_df.loc[r, language].reference_translations[0] for r in results]
        return list(zip(preferred_terms, reference_translations))
    else:
        return list()

In [25]:
# Match function signature for the other prompt compiling functions
from functools import partial

find_similar_ = partial(find_similar, keys=keys)

In [26]:
for row in all_df[all_df.index.get_level_values(1) == "Dutch"].sample(10).itertuples():
    print(row.fsn)
    print(find_similar_(row=row, G=G, ref_df=ref_df))
    print("\n")

Bromide salt (substance)
[]


Finding of food and drink intake (finding)
[('Food and drink intake', 'intake van voedsel en drinken'), ('Recommendation to change food and drink intake', 'aanbevelen om inname van eten en drinken aan te passen'), ('Positioning subject for food and drink intake', 'zorgafnemer in houding voor intake van voedsel en drinken plaatsen')]


Prothrombin time within target range (finding)
[]


Mixed nerve conduction study (procedure)
[('Sensory nerve conduction study', 'sensibel geleidingsonderzoek'), ('Nerve conduction study', 'zenuwgeleidingsonderzoek'), ('Finding of mixed nerve conduction pattern', 'bevinding betreffende geleidingspatroon van gemengde zenuw')]


Intercondylar T/Y fracture (morphologic abnormality)
[]


Malignant tumor, small cell type (morphologic abnormality)
[('Malignant tumor, giant cell type', 'grootcellige maligne tumor'), ('Malignant tumor, fusiform cell type', 'spoelcellige maligne tumor'), ('Malignant tumor, clear cell type', "maligne n

In [55]:
from copy import deepcopy

# sim_df = deepcopy(all_df)

# sim_df["similarity_tier"] = [
#     "Tier 1" if find_similar(row, G, keys, ref_df) != [] else "All Translations" 
#     for row in tqdm(sim_df.itertuples(), total=sim_df.shape[0])
# ]

sim_df = sim_df[sim_df.similarity_tier == "Tier 1"]

(
    sim_df
    .reset_index()
    .groupby(["language"])
    .size()
    .rename("cnt")
    .reset_index()
)

Unnamed: 0,language,cnt
0,Dutch,1901
1,Estonian,878
2,Korean,1085
3,Swedish,1925


## Translations using similarity results only (no context tier augmentation)

In [27]:
def prepare_aya_similarity_prompt(row, G, ref_df):
    sctid, language = row.Index
    concept = G.get_full_concept(sctid)
    preferred_term = concept.fsn.replace(f"({concept.hierarchy})", "").strip()
    results = find_similar_(row=row, G=G, ref_df=ref_df)
    prompt_fragments = [
        f'Translate the following clinical concept into {language}: "{pt}". {rt}.'
        for pt, rt in results
    ]
    prompt = '\n'.join(prompt_fragments)
    prompt += f'\nTranslate the following clinical concept into {language}: "{preferred_term}". '
    
    return prompt

In [29]:
print(prepare_aya_similarity_prompt(next(sim_df.sample(1).itertuples()), G, ref_df))

Translate the following clinical concept into Swedish: "Biopsy of joint structure of shoulder". biopsi från ledstruktur i axel.
Translate the following clinical concept into Swedish: "Excisional biopsy of joint structure of shoulder". excisionsbiopsi från ledstruktur i axel.
Translate the following clinical concept into Swedish: "Percutaneous fine needle aspiration biopsy of joint structure of shoulder using imaging guidance". bildvägledd perkutan finnålsbiopsi från axelled.
Translate the following clinical concept into Swedish: "Joint structure of shoulder region". 


In [30]:
sim_df["aya_similarity_translation"] = list(translate_with_aya(
    sim_df, G, prepare_aya_similarity_prompt, ref_df, None, rebuild=True, save=False
))

  0%|          | 0/5789 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [41]:
evaluate_translations(ct2_df, "deepl_translation", ignore_case)

{'exact_match': 0.1967503692762186,
 'levenshtein_ratio': 0.08223037944555804,
 'google_bleu': 0.06679764243614932,
 'cer_score': 0.6876011922594623}

In [42]:
evaluate_translations(ct2_df, "aya_vanilla_translation", ignore_case)

{'exact_match': 0.08980797636632201,
 'levenshtein_ratio': 0.08760223122820525,
 'google_bleu': 0.04098139660285791,
 'cer_score': 0.7679424916650236}

In [45]:
# evaluate_translations(ct2_df, "aya_ct1_translation", ignore_case)

In [46]:
evaluate_translations(ct2_df, "aya_ct2_translation", ignore_case)

{'exact_match': 0.2838995568685377,
 'levenshtein_ratio': 0.08972250166950493,
 'google_bleu': 0.08574728824877723,
 'cer_score': 0.6900741462809425}

In [76]:
sim_df.to_csv(SIM_OUTPUT_PATH)

# 9. SFT

## 9.1 Generate fine-tuning dataset

In [19]:
from datasets import Dataset

def generate_sft_prompt(example):
    preferred_term = example["fsn"].replace(f"({example['hierarchy']})", "").strip()
    prompt = f"Translate into {example['language']}: {preferred_term}."
    return {"prompt": prompt}
    
dataset = Dataset.from_pandas(all_df)
dataset = dataset.map(generate_sft_prompt)
dataset = dataset.train_test_split(test_size=0.5)

Map:   0%|          | 0/12640 [00:00<?, ? examples/s]

In [None]:
wandb.login()

In [None]:
from transformers import DataCollatorForSeq2Seq

data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=aya_model)

In [69]:
# LoraConfig "task_type" parameter seems to have been deprecated
# LoraConfig "target_modules" parameter currently not specified
# Currently performing full precision training (no use of prepare_model_for_kbit_training)
# LoRA r-rate, alpha and dropout taken from: https://github.com/georgian-io/LLM-Finetuning-Hub (see script params in README)

from peft import LoraConfig, PeftModel

peft_config = LoraConfig(
    lora_alpha=32,
    lora_dropout=0.1,
    r=16,
    bias="none",
)

In [None]:
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer

training_args = Seq2SeqTrainingArguments(
    output_dir="./checkpoints",
    eval_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=2,
    predict_with_generate=True,
    fp16=True,
    report_to="wandb",
    push_to_hub=True,
)

trainer = Seq2SeqTrainer(
    model=aya_model,
    args=training_args,
    peft_config=peft_config,
    # train_dataset=tokenized_books["train"],
    # eval_dataset=tokenized_books["test"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.train()

# 10. Final Evaluations

In [20]:
all_df = pd.read_csv(ALL_OUTPUT_PATH).set_index(["sctid", "language"])
ct1_df = pd.read_csv(CT1_OUTPUT_PATH).set_index(["sctid", "language"])
ct2_df = pd.read_csv(CT2_OUTPUT_PATH).set_index(["sctid", "language"])
sim_df = pd.read_csv(SIM_OUTPUT_PATH).set_index(["sctid", "language"])

## 10.1 Sampling Translations

In [47]:
for row in ct2_df.sample(10).itertuples():
    sctid, language = row.Index
    preferred_term = G.get_concept_details(sctid).synonyms[0]
    deepl_results = evaluate_translations(row, "deepl_translation", ignore_case=ignore_case)
    vanilla_aya_results = evaluate_translations(row, "aya_vanilla_translation", ignore_case=ignore_case)
    ct1_aya_results = evaluate_translations(row, "aya_ct1_translation", ignore_case=ignore_case)
    ct2_aya_results = evaluate_translations(row, "aya_ct2_translation", ignore_case=ignore_case)
    print(
        colored("\nSCTID: ", "red", attrs=['bold']),
        sctid,
        colored("\nTarget Language: ", "red", attrs=['bold']),
        language,
        colored("\nEnglish Preferred Term: ", "red", attrs=['bold']),
        preferred_term,
        colored("\nReference Translations: ", "red", attrs=['bold']),
        row.reference_translations,
        colored("\nDeepL Translation: ", "red", attrs=['bold']),
        row.deepl_translation,
        colored("\nDeepL Scores: ", "red", attrs=['bold']),
        ", ".join([k+": "+str(v) for k,v in deepl_results.items()]),
        colored("\nVanilla Aya Translation: ", "red", attrs=['bold']),
        row.aya_vanilla_translation,
        colored("\nVanilla Aya Scores: ", "red", attrs=['bold']),
        ", ".join([k+": "+str(v) for k,v in vanilla_aya_results.items()]),        
        colored("\nAya CT1 Translation: ", "red", attrs=['bold']),
        row.aya_ct1_translation,
        colored("\nAya CT1 Scores: ", "red", attrs=['bold']),
        ", ".join([k+": "+str(v) for k,v in ct1_aya_results.items()]),          
        colored("\nAya CT2 Translation: ", "red", attrs=['bold']),
        row.aya_ct2_translation,
        colored("\nAya CT2 Scores: ", "red", attrs=['bold']),
        ", ".join([k+": "+str(v) for k,v in ct2_aya_results.items()]),                  
    )

[1m[31m
SCTID: [0m 298777007 [1m[31m
Target Language: [0m Swedish [1m[31m
English Preferred Term: [0m Increased active range of shoulder flexion [1m[31m
Reference Translations: [0m ['ökat rörelseomfång i skuldra vid aktiv flexion'] [1m[31m
DeepL Translation: [0m Ökat aktivt omfång av axelflexion [1m[31m
DeepL Scores: [0m exact_match: 0.0, levenshtein_ratio: 0.05882352941176472, google_bleu: 0.0, cer_score: 0.8181818181818182 [1m[31m
Vanilla Aya Translation: [0m Ökat aktivt skulderflexionsintervall [1m[31m
Vanilla Aya Scores: [0m exact_match: 0.0, levenshtein_ratio: 0.05405405405405406, google_bleu: 0.0, cer_score: 1.0 [1m[31m
Aya CT1 Translation: [0m Ökat rörelseomfång vid aktiv flexion i axelled [1m[31m
Aya CT1 Scores: [0m exact_match: 0.0, levenshtein_ratio: 0.04255319148936165, google_bleu: 0.16666666666666666, cer_score: 0.45652173913043476 [1m[31m
Aya CT2 Translation: [0m Ökat rörelseomfång vid aktiv flexion i axelled [1m[31m
Aya CT2 Scores: [0m

In [77]:
row = next(ct2_df[ct2_df.index.get_level_values(0) == 449723004].itertuples())
print(prepare_aya_ct2_prompt(row, G, ref_df))

NameError: name 'prepare_aya_ct2_prompt' is not defined