This demo shows how to predict latent patient disorders, or phenotypes, from the embeddings produced by models trained on MIMIC-III.

In [None]:
import numpy as np
from torch import nn
import torch
from utils.vit import TransformerEncoder
from utils.data_loaders import *
from utils.utils import *
from easydict import EasyDict as edict

In [None]:
# get_loaders_pheno() (from the star-imports above) returns three loaders,
# unpacked here as the train / validation / test splits.
train_loader, val_loader, test_loader = get_loaders_pheno()

In [None]:
## Load a stored model (could be LSTM, TF or TCN).
# NOTE: torch.load unpickles arbitrary Python objects -- only load checkpoints
# that come from a trusted source.
device = get_device()
print(device)
# map_location remaps the checkpoint's storages onto the current device, so a
# model that was saved on a GPU can still be loaded on a CPU-only machine.
model = torch.load('./RNN', map_location=device)
model.to(device)


In [None]:
# Binary cross-entropy criterion, moved to the same device as the model.
loss_fn = nn.BCELoss()
loss_fn = loss_fn.to(device)

In [None]:
from tqdm import tqdm

Train = []
Train_Labels = []

# Inference only: switch to eval mode once, and disable autograd so the stored
# embeddings do not keep a computation graph alive for every batch.
model.eval()
with torch.no_grad():
    # Iterating the loader directly lets tqdm show the total batch count.
    for inputs, label in tqdm(train_loader):
        inputs = inputs.to(torch.float32).to(device)
        label = label.to(torch.float32).to(device)

        # model.emb maps a batch of inputs to its embedding.
        Train.append(model.emb(inputs))
        Train_Labels.append(label)

# Stack the per-batch tensors into one tensor of embeddings and one of labels.
Train = torch.vstack(Train)
Train_Labels = torch.vstack(Train_Labels)

In [None]:
# Display the shape of the stacked training embeddings.
Train.shape

In [None]:
from tqdm import tqdm

Val = []
Val_Labels = []

# Inference only: switch to eval mode once, and disable autograd so the stored
# embeddings do not keep a computation graph alive for every batch.
model.eval()
with torch.no_grad():
    # Iterating the loader directly lets tqdm show the total batch count.
    for inputs, label in tqdm(val_loader):
        inputs = inputs.to(torch.float32).to(device)
        label = label.to(torch.float32).to(device)

        # model.emb maps a batch of inputs to its embedding.
        Val.append(model.emb(inputs))
        Val_Labels.append(label)

# Stack the per-batch tensors into one tensor of embeddings and one of labels.
Val = torch.vstack(Val)
Val_Labels = torch.vstack(Val_Labels)

In [None]:
from tqdm import tqdm

Test = []
Test_Labels = []

# Inference only: switch to eval mode once, and disable autograd so the stored
# embeddings do not keep a computation graph alive for every batch.
model.eval()
with torch.no_grad():
    # Iterating the loader directly lets tqdm show the total batch count.
    for inputs, label in tqdm(test_loader):
        inputs = inputs.to(torch.float32).to(device)
        label = label.to(torch.float32).to(device)

        # model.emb maps a batch of inputs to its embedding.
        Test.append(model.emb(inputs))
        Test_Labels.append(label)

# Stack the per-batch tensors into one tensor of embeddings and one of labels.
Test = torch.vstack(Test)
Test_Labels = torch.vstack(Test_Labels)

In [None]:
from utils.Recurrent_Models import Emb_pheno
# Model that is trained on the extracted embeddings in the cells below.
# NOTE(review): 256 and 25 are presumably the embedding width and the number
# of phenotype labels -- confirm against Emb_pheno's signature.
pheno_model = Emb_pheno(256, 25, device)
pheno_model.to(device)
print(pheno_model)

# Adam over the new model's parameters, with a binary cross-entropy loss.
opt = torch.optim.Adam(pheno_model.parameters(), lr=1e-3)
loss_fn = nn.BCELoss()
loss_fn = loss_fn.to(device)

# Best validation AUROC seen so far; the training loop compares against it.
best = 0

In [None]:
from torch.utils.data import TensorDataset, DataLoader
# Train/Val/Test and their labels are already tensors (built with
# torch.vstack), so wrapping them in torch.tensor() would copy them and emit a
# UserWarning. detach() gives graph-free views of the same data instead.
# NOTE: this rebinds train_loader / val_loader / test_loader to loaders over
# embeddings, replacing the raw-data loaders created earlier in the notebook.
train_dataset = TensorDataset(Train.detach(), Train_Labels.detach())
train_loader = DataLoader(train_dataset, batch_size=64)

val_dataset = TensorDataset(Val.detach(), Val_Labels.detach())
val_loader = DataLoader(val_dataset, batch_size=64)

test_dataset = TensorDataset(Test.detach(), Test_Labels.detach())
test_loader = DataLoader(test_dataset, batch_size=64)

In [None]:
import sklearn.metrics

def prediction(model, loader, loss_fn, device):
    """Evaluate ``model`` on ``loader``; return mean batch loss and macro AUROC.

    Args:
        model: module called as ``model(data)``; it is put in eval mode here
            and left in eval mode on return.
        loader: iterable of ``(data, labels)`` batches that supports ``len()``.
        loss_fn: callable ``loss_fn(pred, labels)`` returning a scalar tensor.
        device: device the batches are moved to before the forward pass.

    Returns:
        Tuple ``(val_loss, roc_test)``: the loss averaged over the number of
        batches, and the macro-averaged ROC AUC computed over all samples.
    """
    preds = []
    targets = []
    model.eval()
    total_loss = 0

    # Evaluation never back-propagates, so skip building autograd graphs.
    with torch.no_grad():
        for data, labels in loader:
            data = data.to(torch.float32).to(device)
            labels = labels.to(torch.float32).to(device)

            pred = model(data)
            total_loss += loss_fn(pred, labels).item()
            # Extending with an array appends one row per sample.
            preds.extend(pred.cpu().numpy())
            targets.extend(labels.cpu().numpy())

    val_loss = total_loss / len(loader)
    y_true = np.vstack(targets)
    y_score = np.vstack(preds)
    roc_test = sklearn.metrics.roc_auc_score(y_true, y_score, average='macro')
    return val_loss, roc_test

In [None]:
from tqdm import tqdm

# NOTE(review): these lists are never appended to in this cell; they are kept
# only in case other cells reference them.
TL = []
VL = []
VA = []

for epoch in range(0, 500):
    # prediction() leaves pheno_model in eval mode at the end of every epoch,
    # so put the model being optimised (pheno_model, not the embedding model
    # `model`) back into train mode before each pass over the data.
    pheno_model.train()
    train_loss = 0.0

    for inputs, label in train_loader:
        opt.zero_grad()

        inputs = inputs.to(torch.float32).to(device)
        label = label.to(torch.float32).to(device)

        pred = pheno_model(inputs)
        loss = loss_fn(pred, label)

        loss.backward()
        opt.step()
        # .item() keeps a plain float rather than a tensor in the running sum.
        train_loss += loss.item()

    val_loss, auc = prediction(pheno_model, val_loader, loss_fn, device)

    if auc > best:
        best = auc
        # Checkpoint the model that was just validated, i.e. pheno_model.
        torch.save(pheno_model, './pheno')

    print('Epoch : {:.1f} Train Loss {:.4f} Val Loss {:.4f} Val AUROC {:.4f}'.format(epoch, train_loss / len(train_loader), val_loss, auc))

In [None]:

# Reload the checkpoint written during training and score on the test loader.
model = torch.load('./pheno')
# Evaluate the reloaded checkpoint rather than the last-epoch weights still in
# memory -- but only when the checkpoint is the same class as pheno_model;
# otherwise keep evaluating the in-memory pheno_model.
if isinstance(model, type(pheno_model)):
    pheno_model = model
    pheno_model.to(device)
loss, auc = prediction(pheno_model, test_loader, loss_fn, device)
print(auc)