In [3]:
from collections import defaultdict
from pathlib import Path

import pandas as pd
import seaborn as sns
import torch
from omegaconf import OmegaConf
import numpy as np
import matplotlib.pyplot as plt
from rich.progress import track

import automatic_medical_coding.metrics as metrics
from automatic_medical_coding.settings import TARGET_COLUMN, ID_COLUMN
from data_analysis.wandb_utils import get_runs, get_run


# Load the cleaned MIMIC-III dataset once.
# NOTE(review): `mimic_clean_splits` holds the same data as `mimic_clean` (both
# come from mimiciii_clean.feather); the name suggests a dedicated splits file
# was intended — confirm the correct path. Copying the already-loaded frame
# avoids reading the same (large) file from disk a second time.
mimic_clean = pd.read_feather("/data/je/mimiciii/pre-processed/clean/mimiciii_clean.feather")
mimic_clean_splits = mimic_clean.copy()
code_columns = ["icd9_diag", "icd9_proc"]


ModuleNotFoundError: No module named 'automatic_medical_coding'

Defaulting to user installation because normal site-packages is not writeable
Collecting rich
  Downloading rich-12.6.0-py3-none-any.whl (237 kB)
     |████████████████████████████████| 237 kB 5.7 MB/s eta 0:00:01
Collecting commonmark<0.10.0,>=0.9.0
  Downloading commonmark-0.9.1-py2.py3-none-any.whl (51 kB)
     |████████████████████████████████| 51 kB 15.2 MB/s eta 0:00:01
Installing collected packages: commonmark, rich
Successfully installed commonmark-0.9.1 rich-12.6.0


In [None]:
# Experiment-tracking identifiers — presumably a Weights & Biases project and
# sweep (the notebook imports `data_analysis.wandb_utils`); confirm.
PROJECT = "joakim_edin/automatic-medical-coding"
SWEEP_ID = "rzzkucfv"
DATASET = "mimiciii_clean"

# Root folder of per-run outputs; `load_predictions` reads
# EXPERIMENT_DIR / <run_id> / "predictions_val.feather".
EXPERIMENT_DIR = Path("/data/je/experiments/")

# Model identifier -> human-readable label (presumably for plots/tables).
MODEL_NAMES = {
    "PLMICD": "PLM-ICD",
    "VanillaCNN": "CNN",
    "VanillaRNN": "Bi-GRU",
}





def one_hot(
    targets: list[list[str]], number_of_classes: int, target2index: dict[str, int]
) -> torch.Tensor:
    """Multi-hot encode per-case label lists into a ``long`` tensor.

    Args:
        targets: One list of labels per case (one output row per case).
        number_of_classes: Number of output columns.
        target2index: Column index for each known label; labels missing from
            this mapping are silently skipped.

    Returns:
        Tensor of shape ``(len(targets), number_of_classes)`` and dtype
        ``torch.long``: 1 where case ``i`` carries the label mapped to column
        ``j``, 0 everywhere else.
    """
    # Collect the (row, column) coordinate of every recognised label ...
    hits = [
        (row, target2index[label])
        for row, labels in enumerate(targets)
        for label in labels
        if label in target2index
    ]
    encoded = torch.zeros((len(targets), number_of_classes))
    # ... then switch them all on with a single advanced-indexing assignment.
    if hits:
        rows, cols = zip(*hits)
        encoded[list(rows), list(cols)] = 1
    return encoded.long()


def load_predictions(run_id: str) -> tuple[dict[str, torch.Tensor], int]:
    """Load a run's validation predictions as aligned logit/target tensors.

    Reads ``EXPERIMENT_DIR / run_id / "predictions_val.feather"``. The frame is
    expected to contain ``TARGET_COLUMN`` (one array of codes per case) plus one
    logit column named after each code — presumably written by the training
    pipeline; confirm against the writer.

    Args:
        run_id: Name of the run sub-directory under ``EXPERIMENT_DIR``.

    Returns:
        ``(cases, number_of_classes)`` where ``cases["logits"]`` holds the logit
        columns of every code seen as a target and ``cases["targets"]`` is the
        matching multi-hot ``long`` tensor of shape
        ``(n_cases, number_of_classes)``. Column ``i`` of both tensors refers to
        the same code; codes are in sorted order.

    Raises:
        ValueError: If no case in the file has any target code.
    """
    predictions = pd.read_feather(EXPERIMENT_DIR / run_id / "predictions_val.feather")

    # Feather hands list cells back as numpy arrays; convert them to lists.
    target_lists = predictions[TARGET_COLUMN].apply(lambda codes: codes.tolist())

    # Every code that is a target of at least one case. Sorted so the
    # code -> column mapping is reproducible: iteration order of a set of
    # strings varies between Python processes (hash randomisation).
    unique_targets = sorted(set().union(*target_lists))
    if not unique_targets:
        raise ValueError(f"No target codes found in predictions for run {run_id!r}")

    # Mapping from code to column index, shared by logits and targets.
    target2index = {target: idx for idx, target in enumerate(unique_targets)}
    number_of_classes = len(unique_targets)

    # Select logit columns in the same order as target2index so that
    # logits[:, i] and targets[:, i] describe the same code.
    logits_torch = torch.tensor(predictions[unique_targets].values)
    targets_torch = one_hot(target_lists.to_list(), number_of_classes, target2index)

    cases = {"logits": logits_torch, "targets": targets_torch}
    return cases, number_of_classes