In [None]:
from pathlib import Path

import pandas as pd

# Root of all inputs (eval spreadsheets) and outputs (model results).
# NOTE(review): resolved relative to the kernel's working directory, so this
# assumes the notebook runs from a folder that is a sibling of data/.
base_path = Path.cwd().parent.joinpath("data")

In [None]:
from medminer.utils.data import Document

# Evaluation subset of free-text prior-medication ("Vormedikation") notes.
# Every spreadsheet row becomes one Document built from its "PID" and
# "textvalue" columns.
df = pd.read_excel(base_path / "eval" / "freetext_vormedikation_subset.xlsx")
docs = [
    Document(pid, text)
    for pid, text in zip(df["PID"], df["textvalue"])
]

In [None]:
from medminer.utils.models import DefaultModel
from medminer.task.medication import MedicationTask

# Medication-extraction task; its results are read back in the next cell from
# base_path/examples/results/<session_id>/medication.csv.
llm = DefaultModel().model
task = MedicationTask(llm, base_dir=base_path / "examples" / "results")

# One extraction run per document. Only the free text is passed in.
# NOTE(review): the PID given to Document is not handed to task.run — confirm
# how the rows in medication.csv obtain their patient_id.
for content in (document.content for document in docs):
    task.run(content)

In [None]:
# Load the predictions written by the task run above.
# NOTE(review): relies on the private `task._settings`; if "session_id" is
# missing, .get() returns None and building the path raises a TypeError.
session_id = task._settings.get("session_id")
pred_df = pd.read_csv(base_path.joinpath("examples", "results", session_id, "medication.csv"))
pred_df

In [None]:
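# Alternative to the cell above: evaluate an earlier session's results by its
# session id instead of re-running the extraction.
# pred_df = pd.read_csv(base_path / "examples" / "results" / "1d66d073f8f345f693d48eda5bc2485c" / "medication.csv")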

In [None]:
# Ground-truth medication rows, restricted (inner merge) to the patients in the
# evaluation subset `df`. After the merge, "PID" duplicates "patient_id" and is dropped.
true_df = (
    df[["PID"]]
    .merge(pd.read_excel(base_path / "eval" / "medication.xlsx"), left_on="PID", right_on="patient_id")
    .drop(columns=["PID"])
    # Header-less spreadsheet columns; errors="ignore" keeps the cell working
    # if the sheet layout changes and they are no longer present.
    .drop(columns=["Unnamed: 15", "Unnamed: 16"], errors="ignore")
)

# Normalise ATC codes to strings. Empty cells (NaN becomes "nan" after the str cast)
# and "0" (presumably a "no code" placeholder) are set to None so that later
# .dropna() calls treat them as missing.
true_df["atc_id"] = true_df["atc_id"].astype(str)
true_df.loc[true_df["atc_id"].isin(["nan", "0"]), "atc_id"] = None

In [None]:
def _unmapped(true_rows: pd.DataFrame, pred_rows: pd.DataFrame, row_mapping: dict) -> tuple:
    """Return (true indices, pred indices) that are not yet paired in ``row_mapping``."""
    mapped_true = set(row_mapping.values())
    mapped_pred = {pred_idx for pred_idx, true_idx in row_mapping.items() if true_idx is not None}
    return (
        list(set(true_rows.index.values.tolist()) - mapped_true),
        list(set(pred_rows.index.values.tolist()) - mapped_pred),
    )


def match_patient_rows(true_rows: pd.DataFrame, pred_rows: pd.DataFrame) -> dict:
    """Pair one patient's predicted medication rows with ground-truth rows.

    Args:
        true_rows: ground-truth rows of a single patient ("atc_id", "medication_name").
        pred_rows: predicted rows of the same patient (same columns).

    Returns:
        Dict mapping every index of ``pred_rows`` to the matched index of
        ``true_rows``, or to None when no unambiguous match was found.
    """
    row_mapping = {pred_idx: None for pred_idx in pred_rows.index}

    # Pass 1: match on the first five ATC characters. keep=False discards
    # prefixes that occur more than once on a side, since those are ambiguous.
    pred_atc = pred_rows["atc_id"].str[:5].dropna().drop_duplicates(keep=False)
    true_atc = true_rows["atc_id"].str[:5].dropna().drop_duplicates(keep=False)
    for true_idx, atc_prefix in true_atc.items():
        hits = pred_atc[pred_atc.str.startswith(atc_prefix)].index.values
        if len(hits) == 1:
            row_mapping[hits[0]] = true_idx

    # Pass 2: for rows still unmatched, match on the stripped, lower-cased name.
    not_mapped_true, not_mapped_pred = _unmapped(true_rows, pred_rows, row_mapping)
    true_names = true_rows.loc[not_mapped_true, "medication_name"].str.strip().str.lower()
    for true_idx, true_name in true_names.items():
        candidates = pred_rows.loc[not_mapped_pred]
        candidate_names = candidates["medication_name"].str.strip().str.lower()
        hits = candidates[candidate_names == true_name].index.values
        if len(hits) == 1:
            row_mapping[hits[0]] = true_idx
            # Refresh so an already-paired prediction is not offered again.
            not_mapped_true, not_mapped_pred = _unmapped(true_rows, pred_rows, row_mapping)

    return row_mapping


# Global pred_df index -> true_df index (None when unmatched), built per patient.
mapping = {}
for patient_id, _true_df in true_df.groupby("patient_id"):
    _pred_df = pred_df[pred_df["patient_id"] == patient_id]
    _mapping = match_patient_rows(_true_df, _pred_df)

    # Show rows that remain unmatched on both sides, for manual inspection.
    _not_mapped_true, _not_mapped_pred = _unmapped(_true_df, _pred_df, _mapping)
    if _not_mapped_true and _not_mapped_pred:
        print(f"patient_id={patient_id}: unmatched rows")
        print("true\n", _true_df.loc[_not_mapped_true])
        print("pred\n", _pred_df.loc[_not_mapped_pred])

    mapping.update(_mapping)

In [None]:
# Re-label each prediction with the index of its matched ground-truth row.
# Predictions mapped to None, or missing from `mapping`, get a NaN label.
pred_df_copy = pred_df.set_axis(pred_df.index.map(mapping), axis=0)
pred_df_copy

In [None]:
# Side-by-side comparison: the outer join on the (re-labelled) index keeps
# unmatched true rows and unmatched predictions. patient_id is taken from
# whichever side is present.
res_df = (
    true_df
    .join(pred_df_copy, lsuffix="_true", rsuffix="_pred", how="outer")
    .assign(patient_id=lambda joined: joined["patient_id_true"].fillna(joined["patient_id_pred"]))
    .sort_values(["patient_id", "atc_id_true", "medication_name_true"])
    .reset_index(drop=True)
)
res_df
