In [1]:
import os
import pickle
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt


from tqdm.auto import tqdm
from itertools import product
from edist.sed import standard_sed
from collections import defaultdict

ModuleNotFoundError: No module named 'edist'

In [None]:
def get_metrics(results):
    """Compute validation metrics for one set of model predictions.

    Parameters
    ----------
    results : mapping
        Must provide ``results["validation"]["targets"]`` and
        ``results["validation"]["predictions"]``, two parallel sequences of
        symbol sequences. The first and last symbol of every sequence are
        dropped before edit distances are computed (presumably start/end
        markers -- confirm against the code that writes the predictions).

    Returns
    -------
    tuple of float
        ``(wer, edit_distance, normalised_edit_distance)``:

        * ``wer`` -- percentage of predictions that are not equal to their
          target (compared on the full, unstripped sequences),
        * ``edit_distance`` -- mean ``standard_sed`` distance between the
          stripped prediction and the stripped target,
        * ``normalised_edit_distance`` -- mean of the per-item distance
          divided by the stripped target length, as a percentage.
    """
    correct = []
    edit_distances = []
    normalised_edit_distances = []
    for true_target, prediction in zip(
        results["validation"]["targets"], results["validation"]["predictions"]
    ):
        correct.append(true_target == prediction)

        stripped_target = true_target[1:-1]
        # The edit distance is the expensive part: compute it once and reuse
        # it for both the raw and the normalised metric.
        distance = standard_sed(prediction[1:-1], stripped_target)
        edit_distances.append(distance)
        # A target that is empty after stripping would otherwise raise
        # ZeroDivisionError; normalise by 1 in that case so the item
        # contributes its raw distance.
        normalised_edit_distances.append(distance / max(len(stripped_target), 1))

    wer = 100 * (1 - np.mean(correct).item())
    edit_distance = np.mean(edit_distances).item()
    normalised_edit_distance = 100 * np.mean(normalised_edit_distances).item()

    return wer, edit_distance, normalised_edit_distance

In [None]:
# For every language, keep the test-set predictions of the "interpretable"
# run that achieved the lowest normalised edit distance on validation data.
language_predictions = defaultdict(list)
# The normalised edit distance is a percentage of the target length and is
# not capped at 100 (a prediction can need more edits than the target has
# symbols). The "nothing selected yet" sentinel therefore has to be +inf:
# with a sentinel of 100 a language whose best run scores >= 100 would never
# receive any predictions.
best_validation_ned = defaultdict(lambda: float("inf"))
trials_per_lang = defaultdict(int)

for prediction_file in tqdm(sorted(os.listdir("./predictions/"))):
    # File names are expected to look like "key=value-key=value.<ext>".
    # Parts without "=" are ignored so that stray directory entries
    # (e.g. ".ipynb_checkpoints") are skipped instead of raising ValueError.
    parameters = dict(
        entry.split("=", 1)
        for entry in prediction_file.split(".")[0].split("-")
        if "=" in entry
    )

    if parameters.get("model") != "interpretable" or "language" not in parameters:
        continue

    language = parameters["language"]
    trials_per_lang[language] += 1

    # NOTE(security): pickle.load can execute arbitrary code. Only point this
    # at prediction files produced by trusted runs.
    with open(os.path.join("./predictions", prediction_file), 'rb') as pf:
        predictions = pickle.load(pf)

    # The file is no longer needed once unpickled, so score outside the `with`.
    wer, edit_distance, normalised_edit_distance = get_metrics(predictions)

    if normalised_edit_distance < best_validation_ned[language]:
        best_validation_ned[language] = normalised_edit_distance
        language_predictions[language] = predictions["test"]["predictions"]

  0%|          | 0/265 [00:00<?, ?it/s]

In [None]:
import regex

# Compiled pattern used by rekanjify to pick out the kanji of a lemma. In the
# third-party `regex` package, \p{IsHan} selects characters of the Han script.
kanji_regex = regex.compile(r'\p{IsHan}', regex.UNICODE)

def rekanjify(lemma, form):
    """Put the kanji of ``lemma`` back into the ``"K"`` slots of ``form``.

    The characters of ``lemma`` that match ``kanji_regex`` are taken in order
    and substituted, left to right, for the ``"K"`` characters of ``form``.
    If ``form`` has more ``"K"`` characters than ``lemma`` has kanji, the
    surplus ``"K"`` characters are left untouched; surplus kanji are ignored.

    Returns the resulting string.
    """
    # Kanji of the lemma, consumed one at a time as placeholders are met.
    kanji_supply = iter([symbol for symbol in lemma if kanji_regex.match(symbol)])

    restored = []
    for symbol in form:
        if symbol == "K":
            # Once the supply is exhausted the placeholder is kept as it is.
            symbol = next(kanji_supply, symbol)
        restored.append(symbol)

    return "".join(restored)

In [None]:
# Write one "<language>.pred" file per language into ./submission, pairing
# each line of the covered test file with the selected model's prediction.
os.makedirs("./submission", exist_ok=True)

for language, predictions in language_predictions.items():
    test_file = os.path.join("./data", f"{language}.covered.tst")
    prediction_file = os.path.join("./submission", f"{language}.pred")
    # Be explicit about UTF-8: without `encoding`, open() falls back to the
    # platform's locale encoding, which can fail on or corrupt non-ASCII text
    # (e.g. the kanji written for "jap") on systems that do not default to UTF-8.
    with open(test_file, encoding="utf-8") as tf, open(
        prediction_file, "w", encoding="utf-8"
    ) as pf:
        # zip stops at the shorter input, so surplus lines or predictions
        # are silently ignored.
        for line, prediction in zip(tf, predictions):
            lemma, tags = line.strip().split("\t")
            # Drop the first and last symbol of the prediction and join the
            # rest into the surface form.
            predicted_form = "".join(prediction[1:-1])
            if language == "jap":
                predicted_form = rekanjify(lemma, predicted_form)

            pf.write(f"{lemma}\t{tags}\t{predicted_form}\n")