# Sammenlikning av flere modeller

In [1]:
import base64
from pathlib import Path

import matplotlib.pyplot as plt
import mlflow
import pandas as pd
from dotenv import load_dotenv
from ipywidgets import interact
from IPython.display import display, Markdown, HTML

import samisk_ocr.trocr
from samisk_ocr.metrics import compute_cer, compute_wer, SpecialCharacterF1

## Last miljøvariabler

In [2]:
# Load environment variables from the .env file located two directories above
# the current working directory. The return value (True in the recorded output)
# is displayed by the notebook as the cell result.
load_dotenv("../../.env")

True

## Last prediksjoner fra MLflow

In [3]:
def load_results(run_name: str, iteration: int) -> pd.DataFrame:
    """Load the validation predictions of one MLflow run and add per-line error rates.

    The predictions artifact ``predictions/{iteration:08d}.json`` of the run
    named ``run_name`` (in the "TrOCR trocr-base-printed finetuning"
    experiment) is joined one-to-one with the rows of ``metadata.csv`` whose
    ``file_name`` starts with ``"val"``, using ``urn``, ``page`` and ``line``
    as keys. The join is an inner join, so prediction rows without a matching
    metadata row are not part of the result.

    Args:
        run_name: MLflow run name to look up.
        iteration: Training iteration whose prediction artifact is loaded.

    Returns:
        The merged frame without the layout/text bookkeeping columns and with
        four added columns: ``cer``, ``wer``, ``casefolded_cer`` and
        ``casefolded_wer``, each computed per row from the ``true`` and
        ``predictions`` columns.

    Raises:
        ValueError: If the run name does not match exactly one MLflow run.
        pandas.errors.MergeError: If the join keys are not unique on both
            sides (enforced by ``validate="1:1"``).
    """
    config = samisk_ocr.trocr.config.Config()

    mlflow.set_tracking_uri(config.mlflow_url)
    mlflow.set_experiment("TrOCR trocr-base-printed finetuning")

    # Specify what model we want to load. The lookup must be unambiguous:
    # with zero or several matches there is no single run id to build the
    # artifact URI from, so fail early with a clear message.
    matching_runs = mlflow.search_runs(filter_string=f"run_name = '{run_name}'")
    if len(matching_runs) != 1:
        raise ValueError(
            f"Expected exactly one MLflow run named {run_name!r}, "
            f"found {len(matching_runs)}"
        )
    run_id = matching_runs["run_id"].iloc[0]

    artifact_path = f"predictions/{iteration:08d}.json"
    predictions = mlflow.artifacts.load_dict(f"runs:/{run_id}/{artifact_path}")

    metadata = pd.read_csv(config.DATA_PATH / "metadata.csv")

    # NOTE(review): nothing in this function writes to the folder; it is only
    # created here, presumably for use by later cells -- confirm.
    output_folder = Path("output") / f"{run_name}_{Path(artifact_path).stem}"
    output_folder.mkdir(parents=True, exist_ok=True)

    results = pd.merge(
        pd.DataFrame(predictions),
        metadata.query("file_name.str.startswith('val')"),
        on=["urn", "page", "line"],
        validate="1:1",
    )

    dropped_columns = [
        "text",
        "xmin",
        "xmax",
        "ymin",
        "ymax",
        "width",
        "height",
        "line",
        "page",
        "text_len",
    ]
    return results.drop(columns=dropped_columns).assign(
        cer=results.apply(
            lambda row: compute_cer(row["true"], row["predictions"]),
            axis=1,
        ),
        wer=results.apply(
            lambda row: compute_wer(row["true"], row["predictions"]),
            axis=1,
        ),
        # Case-insensitive variants: both strings are casefolded before scoring.
        casefolded_cer=results.apply(
            lambda row: compute_cer(row["true"].casefold(), row["predictions"].casefold()),
            axis=1,
        ),
        casefolded_wer=results.apply(
            lambda row: compute_wer(row["true"].casefold(), row["predictions"].casefold()),
            axis=1,
        ),
    )

In [4]:
# Each entry: (MLflow run name, iteration of the prediction artifact, description
# printed in the report header).
runs = [
    ("nebulous-sponge-430", 149370, "GTSamisk"),
    ("marvelous-fish-697", 153495, "GTSamisk og GTNorsk"),
]
for run_name, run_iteration, description in runs:
    results = load_results(run_name, run_iteration)
    predicted, expected = results["predictions"], results["true"]

    # Corpus-level scores: all lines are glued together before scoring --
    # without a separator for CER, with single spaces for WER.
    concat_pred_cer, concat_true_cer = "".join(predicted), "".join(expected)
    concat_pred_wer, concat_true_wer = " ".join(predicted), " ".join(expected)

    concat_cer = compute_cer(concat_true_cer, concat_pred_cer)
    concat_wer = compute_wer(concat_true_wer, concat_pred_wer)

    # The trailing empty string yields the blank line between reports.
    print(
        f"{description}: {run_name} - {run_iteration}",
        f"Concat CER: {concat_cer:.2%}",
        f"Concat WER: {concat_wer:.2%}",
        "",
        sep="\n",
    )
    

Downloading artifacts:   0%|          | 0/1 [00:00<?, ?it/s]

GTSamisk: nebulous-sponge-430 - 149370
Concat CER: 2.14%
Concat WER: 9.87%



Downloading artifacts:   0%|          | 0/1 [00:00<?, ?it/s]

GTSamisk og GTNorsk: marvelous-fish-697 - 153495
Concat CER: 2.10%
Concat WER: 9.65%

