In [1]:
from iol.constants import SIM_COV_DIM, SIM_ACT_DIM, SIM_OUT_DIM
from iol.models import (
    AdaptiveLinearModel,
    BehaviouralCloning,
    BehaviouralCloningLSTM,
)  # noqa: F401
from iol.data_loading import generate_linear_dataset, get_centre_data

import matplotlib.pyplot as plt
import numpy as np
import pickle
import torch

In [2]:
hyperparams = {
    "covariate_size": 63,
    "action_size": 2,
    "outcome_size": 1,
    "memory_hidden_size": 32,
    "memory_layers": 1,
    "memory_dropout": 0,
    "memory_size": 16,
    "outcome_hidden_size": 32,
    "outcome_layers": 1,
    "inf_hidden_size": 16,
    "inf_layers": 1,
    "inf_dropout": 0.5,
    "inf_fc_size": 32,
}

model = AdaptiveLinearModel
# model = BehaviouralCloningLSTM
model = model(**hyperparams)

In [3]:
training_centre = "CTR23901"

validation_data = get_centre_data(training_centre, seq_length=200).get_whole_batch()
model.load_model("analysis")

In [4]:
print(model.validation(validation_data))

{'ACC': tensor(0.9805), 'AUC': 0.9528183025912178, 'APR': 0.7227637024260962, 'NLL': tensor(0.0585, dtype=torch.float64)}


In [5]:
info = model.inspection(validation_data)

In [6]:
info.keys()

dict_keys(['prior_params', 'posterior_params', 'mu_1_prior', 'mu_0_prior', 'mu_1_posterior', 'mu_0_posterior', 'omega_1_prior', 'omega_0_prior', 'omega_1_posterior', 'omega_0_posterior'])

In [7]:
omega1 = info["omega_1_posterior"]
omega1 = omega1.detach().numpy()

omega0 = info["omega_0_posterior"]
omega0 = omega0.detach().numpy()

omega = omega1 - omega0

In [16]:
omega.shape

(571, 200, 63)

In [19]:
omega = torch.tensor(omega.sum(2))

In [21]:
omega[(outcomes > 0)]

tensor([-1.9363, -2.0536, -2.5935,  ..., -1.9652, -2.0109, -2.4495],
       dtype=torch.float64)

In [22]:
omega[(outcomes < 0)]

tensor([-2.2046, -1.6495, -1.6841,  ..., -2.8139, -2.9398, -2.6751],
       dtype=torch.float64)

In [23]:
plt.hist(omega[(outcomes < 0)])

KeyboardInterrupt: 

In [4]:
outcomes = validation_data[2]

In [15]:
outcomes[(outcomes > 0)]

tensor([0.6050, 0.6050, 0.0403,  ..., 0.6050, 0.6050, 0.6050],
       dtype=torch.float64)

In [8]:
outcomes.numpy().shape

(571, 200)

In [9]:
from scipy import stats

In [12]:
mu = info['omega_0_prior']

In [16]:
mu = mu.detach().numpy()

In [20]:
mu.shape

(421, 50, 1)

In [27]:
plt.plot(mu[0,:,0])
plt.plot(mu[0,:,1])
plt.plot(mu[0,:,2])
plt.plot(mu[0,:,3])
plt.show()

In [29]:
with open('data_loading/centres_cleaned/info_dict.pkl','rb') as file:
    info_dict = pickle.load(file)

In [30]:
info_dict.keys()

dict_keys(['patient_columns', 'patient_mean', 'patient_std', 'donor_columns', 'donor_mean', 'donor_std', 'listing_centers'])

In [33]:
info_dict['patient_columns']

['AGE',
 'GENDER',
 'HGT_CM_CALC',
 'WGT_KG_CALC',
 'abo',
 'BMI_CALC',
 'CREAT_TX',
 'INR_TX',
 'TBILI_TX',
 'MELD_PELD_LAB_SCORE',
 'FINAL_SERUM_SODIUM',
 'DIAL_TX',
 'meldstat',
 'status1',
 'ALBUMIN_TX',
 'ASCITES_TX',
 'ETHCAT',
 'FUNC_STAT_TRR',
 'HCV_SEROSTATUS',
 'LIFE_SUP_TRR',
 'ON_VENT_TRR',
 'MED_COND_TRR',
 'PORTAL_VEIN_TRR',
 'PREV_AB_SURG_TRR',
 'NUM_PREV_TX',
 'PREV_TX',
 'diag1',
 'statushcc',
 'HBV_CORE',
 'INIT_AGE',
 'HGT_CM_TCR',
 'INIT_WGT_KG',
 'INIT_BMI_CALC',
 'INIT_ALBUMIN',
 'INIT_ASCITES',
 'INIT_SERUM_CREAT',
 'INIT_DIALYSIS_PRIOR_WEEK',
 'INIT_INR',
 'INIT_MELD_PELD_LAB_SCORE',
 'INIT_SERUM_SODIUM',
 'FUNC_STAT_TCR',
 'LIFE_SUP_TCR',
 'PORTAL_VEIN_TCR',
 'PREV_AB_SURG_TCR']