In [1]:
import sys
# NOTE(review): absolute, machine-specific path to the project sources; consider making it configurable.
sys.path.append("/data1/andrew/meng/mixehr/meng/VAE-EHR/src")

import numpy as np
import pandas as pd

import torch
from torch import nn, optim

In [2]:
import vae

from vae import VAE, VAETrainer
from icd_analysis_helper import ICDAnalysisHelper
from visualizer_helper import Visualizer

from config_reader import Config
from vae import PatientICDDataset

In [None]:
# Load configuration: read experiment settings (data paths, layer sizes, KLD beta,
# experiment name) from config.ini and show them as the cell's output.
config = Config('./config.ini')
vars(config)

In [None]:
# Load data: the patient/ICD table is space-separated. SUBJECT_ID is an identifier
# column, so it is dropped before the remaining columns become the model input.
# (Presumably one row per patient and one column per ICD code — confirm against the file.)
patient_icd_df = pd.read_csv(config.patient_icd_path, sep=' ')
patient_icd_data = patient_icd_df.drop(columns='SUBJECT_ID')
data = torch.tensor(patient_icd_data.values, dtype=torch.float32)
print(data.shape)

# ICD-9 code descriptions, used by the analysis helper for keyword lookups.
icd9codes = pd.read_csv(config.icd9codes_path)
icd_analyzer = ICDAnalysisHelper(icd9codes_df=icd9codes, patient_icd_df=patient_icd_df)
# Example lookup: icd_analyzer.lookup_icds(icd9codes, ["4019", "41401"])

visualizer = Visualizer()

In [None]:
# Build the model: the VAE's input width equals the number of feature columns in `data`;
# the hidden/latent sizes and activation choice come from the config.
feature_dim = data.shape[1]
print(f"Feature_dim: {feature_dim}")

model = VAE(
    feature_dim=feature_dim,
    encoder_dim=config.encoder_dim,
    latent_dim=config.latent_dim,
    decoder_dim=config.decoder_dim,
    use_relu=config.use_relu,
)

LEARNING_RATE = 0.001  # Adam step size
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
print(type(optimizer))

In [None]:
# Set up the trainer: it owns the model/optimizer pair. experiment_name is reused
# below to locate checkpoints and to name output figures; kld_beta presumably weights
# the KL-divergence term of the loss (see VAETrainer).
experiment_name = config.experiment_name
trainer = VAETrainer(model=model, optimizer=optimizer,
                     experiment_name=experiment_name, kld_beta=config.kld_beta)

In [None]:
# Train the model: the hyper-parameters below are passed straight to VAETrainer.train.
# NOTE(review): this retrains from scratch on every Run All; the next cell reloads a
# saved checkpoint, so consider guarding this cell when only analysis is needed.
NUM_EPOCHS = 80
BATCH_SIZE = 40
SAVE_MODEL_INTERVAL = 5  # presumably: save a checkpoint every N epochs
CLIP_GRADIENTS = False

trainer.train(data=data, epochs=NUM_EPOCHS, batch_size=BATCH_SIZE,
              save_model_interval=SAVE_MODEL_INTERVAL, clip_gradients=CLIP_GRADIENTS)

In [None]:
# Load pre-trained model: reload a checkpoint written during training. `epoch` is
# reused below in the output figure filenames.
# NOTE(review): assumes the trainer saves to ./VAE_exp_<experiment_name>_epoch_<n>.pkl and
# that `epoch` is one of the saved epochs (a multiple of the save interval) — confirm in VAETrainer.
epoch = 40
checkpoint_path = "./VAE_exp_{}_epoch_{}.pkl".format(experiment_name, epoch)

# map_location="cpu" lets a checkpoint saved from a GPU run load on a CPU-only machine;
# load_state_dict then copies the weights onto whatever device the model's parameters live on.
# torch.load unpickles the file, so only load checkpoints you produced yourself.
state_dict = torch.load(checkpoint_path, map_location="cpu")
model.load_state_dict(state_dict)

# The trainer was constructed with this same model object; re-assigning keeps the link explicit.
trainer.model = model

In [None]:
# Encode data: run the encoder over every patient in `data`. By the names used here the
# outputs are latent samples, posterior means and variances.
# NOTE(review): confirm whether `var` is a variance or a log-variance in VAETrainer.encode_data.
(latent, means, var) = trainer.encode_data(data)

In [None]:
# Get UMAP representations: embed the latent samples, the means and the variances
# separately (presumably 2-D by default, since they feed plot2d below).
# NOTE(review): no random seed is passed, so embeddings may differ between runs unless
# Visualizer.umap_embedding fixes one internally.
def _tensor_to_numpy(tensor):
    """Return `tensor` as a NumPy array, moved to CPU and detached from the autograd graph."""
    return tensor.cpu().detach().numpy()

X_umap = visualizer.umap_embedding(_tensor_to_numpy(latent))
X_umap_means = visualizer.umap_embedding(_tensor_to_numpy(means))
X_umap_vars = visualizer.umap_embedding(_tensor_to_numpy(var))

In [None]:
# Plot UMAP representations: draw each 2-D embedding, highlighting patients selected by a
# case-insensitive keyword match (presumably against their diagnosis descriptions — see
# ICDAnalysisHelper.get_patients_idxs_with_disease_keywords).
heart_keywords = ['heart', 'atrial', 'coronary', 'hypertension', 'vascular']
heart_patient_idxs = icd_analyzer.get_patients_idxs_with_disease_keywords(substrings=heart_keywords, case_sensitive=False)

# Color code: 0 = other patients, 100 = heart-keyword match.
heart_colors = np.zeros(X_umap.shape[0], dtype=int)
heart_colors[heart_patient_idxs] = 100

# One figure per embedding; only the filename suffix differs between them.
heart_plots = [
    (X_umap, "_heart_umap"),
    (X_umap_means, "_heart_umap_means"),
    (X_umap_vars, "_heart_umap_vars"),
]
for embedding, suffix in heart_plots:
    visualizer.plot2d(
        X=embedding,
        filename="Patient_Clusters_exp_{}_epoch_{}{}".format(experiment_name, epoch, suffix),
        colors=heart_colors,
    )

In [None]:
# Additional visualizations: overlay a second keyword group (newborn/congenital terms) on the
# heart coloring, using the 2-D UMAP of the latent samples.
baby_keywords = ['congenital', 'infant', 'newborn', 'neonatal', 'born', 'birth']
baby_patient_idxs = icd_analyzer.get_patients_idxs_with_disease_keywords(substrings=baby_keywords, case_sensitive=False)

# Color code: 0 = other, 100 = heart-keyword match, 50 = baby-keyword match.
# Baby indices are written last, so a patient matching both groups is shown as 50.
baby_heart_colors = np.zeros(X_umap.shape[0], dtype=int)
baby_heart_colors[heart_patient_idxs] = 100
baby_heart_colors[baby_patient_idxs] = 50

visualizer.plot2d(X_umap, "Patient_Clusters_exp_{}_epoch_{}{}".format(experiment_name, epoch, "_baby_heart_umap"), colors=baby_heart_colors)

In [None]:
# 3-D UMAP of the latent samples, colored with the combined baby/heart scheme.
# The filename carries the "_baby_heart" prefix to match the 2-D plot that uses the same
# coloring (it previously read "_heart_umap_3D", mislabeling the color scheme).
X_umap_3d = visualizer.umap_embedding(latent.cpu().detach().numpy(), n_components=3)

visualizer.plot3d(
    X=X_umap_3d,
    filename="Patient_Clusters_exp_{}_epoch_{}{}".format(experiment_name, epoch, "_baby_heart_umap_3D"),
    colors=baby_heart_colors,
)