# Inference example for a model predicting creatinine values

## Imports and utility functions

In [None]:
import json

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader

from strats_pytorch.datasets.dataset_regression import (MIMIC_Reg,
                                                        padded_collate_fn)
from strats_pytorch.models.strats import STraTS
from strats_pytorch.utils import denorm

## Model Initialization

In [None]:
exp_n = 52
exp_path = f"exp_creat_reg/exp_{exp_n}/"

# The architecture arguments must match the training run of experiment
# `exp_n`: load_state_dict (strict by default) raises on missing/unexpected
# keys or parameter shape mismatches.
model = STraTS(
    n_var_embs=206,
    dim_demog=2,
    dropout=0.0,
    n_layers=2,
    dim_embed=102,
    n_heads=3,
    forecasting=False,
    regression=True,
)
# The model above is built on the CPU, so load the checkpoint tensors there
# directly; this also lets a checkpoint saved from a GPU load on a CPU-only
# machine.
model.load_state_dict(
    torch.load(exp_path + "STraTS.pth", map_location="cpu")
)
# Inference only: disable training-time behavior (dropout, etc.).
model.eval()

## Dataset and Dataloader Initialization

In [None]:
# Build the regression dataset, then keep only the stays whose indexes were
# saved as this experiment's test split.
test_ds = MIMIC_Reg(data_path="generated/top_206_culled_reg.csv")
test_indexes = np.load(exp_path + "test_idx.npy")
test_ds.restrict_to_indexes(test_indexes)

# Normalize both the observation values and the timestamps.
# NOTE(review): confirm normalize() uses the same statistics as training; if
# it recomputes mean/std on this test subset, the inputs differ slightly
# from what the model saw during training.
test_ds.normalize(verbose=True, normalize_times=True, normalize_vars=True)


In [None]:
# One stay per batch, in dataset order (no shuffling), so the n-th batch is
# the n-th stay of test_ds; padded_collate_fn assembles each batch.
test_dl = DataLoader(
    test_ds,
    collate_fn=padded_collate_fn,
    shuffle=False,
    batch_size=1,
)

## Example time series

In [None]:
# Get a single time series (the batch at position `stay_ind`).
# Note: the dataloader is used instead of indexing test_ds directly because
# its collate function generates the mask.
stay_ind = 2
for ind, batch in enumerate(test_dl):
    if ind == stay_ind:
        demog, values, times, variables, tgt_val, tgt_time, mask = batch
        break
else:
    # Reached only when no batch matched: fail loudly instead of silently
    # keeping whatever batch happened to be last (or none at all).
    raise IndexError(
        f"stay_ind={stay_ind} is out of range for a test loader of "
        f"{len(test_dl)} batches"
    )

In [None]:
# Drop the batch dimension: demog_unbatched holds [gender, normalized age].
demog_unbatched = demog[0]

# Gender is encoded as -1 for male; any other value is reported as female.
gender = "Male" if demog_unbatched[0] == -1 else "Female"
# float() turns the (tensor) result into a plain number so the message reads
# e.g. "64.2 y.o." instead of "tensor(64.2000) y.o.".
age = float(denorm(demog_unbatched[1], test_ds.age_mean, test_ds.age_std))
print(f"Demog Info: {gender}, {age:.1f} y.o.")

### Plotting each variable

In [None]:
# Strip the batch dimension (the loader yields one stay per batch).
values_unbatched = values[0]
times_unbatched = times[0]
variables_unbatched = variables[0]

# Group the denormalized observations and their timestamps by variable id.
# Keys appear in the order each variable is first observed.
vals_per_var = {}
times_per_var = {}
for pos, var_tensor in enumerate(variables_unbatched):
    var_id = var_tensor.item()
    var_vals = vals_per_var.setdefault(var_id, [])
    var_times = times_per_var.setdefault(var_id, [])
    var_vals.append(
        denorm(
            values_unbatched[pos],
            test_ds.means[var_id],
            test_ds.stds[var_id],
        )
    )
    var_times.append(
        denorm(times_unbatched[pos], test_ds.time_mean, test_ds.time_std)
    )

In [None]:
# Hand-written index -> name mapping for a 29-variable setup (indices 0-28),
# kept for reference only. With the top-206 variables used here, names are
# looked up from the files loaded below instead.
# ind_to_var = {
#     0: "Creatinine (serum)",
#     1: "Heart rate",
#     2: "BP systolic",
#     3: "BP diastolic",
#     4: "BP mean",
#     5: "Temp F",
#     6: "weight Daily",
#     7: "weight Admission",
#     8: "White Blood Cell Count",
#     9: "Sodium (serum)",
#     10: "Potassium (serum)",
#     11: "Arterial pH",
#     12: "Respiratory rate",
#     13: "Apnea interval",
#     14: "Minute volume",
#     15: "Central Venous Pressure",
#     16: "O2 fraction",
#     17: "Blood Flow (dialysis)",
#     18: "Blood Urea Nitrogen",
#     19: "Platelet Count",
#     20: "Lactic acid",
#     21: "SPO2",
#     22: "Hemoglobin",
#     23: "Albumin",
#     24: "Anion gap",
#     25: "Prothrombin time",
#     26: "Arterial O2 pressure",
#     27: "Height (cm)",
#     28: "Glucose (serum)",
# }

# Maps each itemid (stored as a string key) to its variable index.
# JSON is UTF-8 by specification, so don't rely on the platform default.
with open("generated/keys_206.json", encoding="utf-8") as f:
    itemid_to_ind: dict[str, int] = json.load(f)

# Table with "id" (itemid) and "label" columns, used to turn an itemid into
# a human-readable name.
df_labels = pd.read_csv("generated/mimic_stats.csv")

def get_name_from_ind(ind):
    """Return the human-readable label of the variable with index ``ind``.

    The index is mapped back to its itemid through the module-level
    ``itemid_to_ind`` dict, and the label is read from the row of
    ``df_labels`` whose ``id`` equals that itemid.

    Raises:
        KeyError: if no itemid maps to ``ind``.
        ValueError: if ``df_labels`` does not contain exactly one row for
            the itemid (raised by ``Series.item``).
    """
    itemid = next(
        (key for key, idx in itemid_to_ind.items() if idx == ind), None
    )
    if itemid is None:
        # Without this guard an unknown index would fall through to an
        # arbitrary itemid and return the wrong label.
        raise KeyError(f"No itemid maps to variable index {ind!r}")
    return df_labels.loc[df_labels["id"] == int(itemid), "label"].item()

# One subplot per variable observed in this stay. The grid is sized from the
# number of variables actually present rather than a fixed 15x14 layout, so
# stays with few variables still get readable panels.
# Note: with many variables (up to 206) the figure remains hard to read; this
# is mostly an example of how to get the original name from a variable id
# when using top-down variables.
n_plots = len(times_per_var)
if n_plots:
    n_cols = int(np.ceil(np.sqrt(n_plots)))
    n_rows = int(np.ceil(n_plots / n_cols))
    plt.figure(figsize=(2.6 * n_cols, 2 * n_rows))
    for i, var_id in enumerate(times_per_var):
        plt.subplot(n_rows, n_cols, i + 1)
        plt.plot(times_per_var[var_id], vals_per_var[var_id], "x-")
        # plt.ylabel(ind_to_var[var_id])
        plt.ylabel(get_name_from_ind(var_id))
    # Keep the axis labels of neighboring panels from overlapping.
    plt.tight_layout()
    plt.show()
else:
    print("No observations to plot for this stay.")

## Model prediction

In [None]:
# Inference only: eval() disables training-time behavior (idempotent if it
# was already set) and no_grad() avoids building an autograd graph.
model.eval()
with torch.no_grad():
    pred = model(demog, values, times, variables, mask)

# The prediction and target are normalized with the statistics of the
# target variable (test_ds.var_id); map both back to original units.
tgt_mean = test_ds.means[test_ds.var_id]
tgt_std = test_ds.stds[test_ds.var_id]
pred_value = denorm(pred.item(), tgt_mean, tgt_std)
true_value = denorm(tgt_val.item(), tgt_mean, tgt_std)
print(f"Pred value: {pred_value}; Ground Truth:{true_value}")