In [None]:
import copy
import math
import pickle
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
%matplotlib inline

In [None]:
# Select the compute device. The notebook was written for the second GPU
# (index 1); fall back to the first GPU, or to the CPU, when that device
# does not exist so the cell does not crash on smaller machines.
if torch.cuda.is_available():
    device = torch.device("cuda", 1 if torch.cuda.device_count() > 1 else 0)
    torch.cuda.set_device(device)
else:
    device = torch.device("cpu")

torch.set_default_dtype(torch.float32)

In [None]:
def pad_rawdata(T, Y, ind_kf, ind_kt, X, meds_on_grid):
    """Zero-pad a batch of ragged per-patient sequences into dense arrays.

    Args:
        T: list of 1-D sequences of observation times, one per patient.
        Y: list of 1-D sequences of observed values, one per patient.
        ind_kf: list of 1-D index sequences, same lengths as the entries of Y.
        ind_kt: list of 1-D index sequences, same lengths as the entries of Y.
        X: list of 1-D sequences of grid times, one per patient.
        meds_on_grid: list of 2-D arrays shaped (grid_len, num_meds).

    Returns:
        Tuple (T_pad, Y_pad, ind_kf_pad, ind_kt_pad, X_pad, meds_pad) of
        float numpy arrays, each padded with zeros up to the longest entry
        in the batch along the sequence axis.
    """
    N = np.shape(T)[0]  # number of patients in the batch

    num_meds = np.shape(meds_on_grid[0])[1]

    T_lens = np.array([len(t) for t in T])
    T_pad = np.zeros((N, np.max(T_lens)))

    # Y, ind_kf and ind_kt share one length per patient.
    Y_lens = np.array([len(y) for y in Y])
    Y_maxlen = np.max(Y_lens)
    Y_pad = np.zeros((N, Y_maxlen))
    ind_kf_pad = np.zeros((N, Y_maxlen))
    ind_kt_pad = np.zeros((N, Y_maxlen))

    # X and meds_on_grid share the grid length per patient.
    grid_lens = np.array([np.shape(m)[0] for m in meds_on_grid])
    grid_maxlen = np.max(grid_lens)
    meds_pad = np.zeros((N, grid_maxlen, num_meds))
    X_pad = np.zeros((N, grid_maxlen))

    for i in range(N):
        T_pad[i, :T_lens[i]] = T[i]
        Y_pad[i, :Y_lens[i]] = Y[i]
        ind_kf_pad[i, :Y_lens[i]] = ind_kf[i]
        ind_kt_pad[i, :Y_lens[i]] = ind_kt[i]
        X_pad[i, :grid_lens[i]] = X[i]
        meds_pad[i, :grid_lens[i], :] = meds_on_grid[i]

    return T_pad, Y_pad, \
           ind_kf_pad, ind_kt_pad, \
           X_pad, meds_pad

In [None]:
def OU_kernel(length,x1,x2):
    """Ornstein-Uhlenbeck (exponential) kernel matrix exp(-|x1_i - x2_j| / length).

    x1 and x2 are flattened; the result has shape (x1.numel(), x2.numel()).
    """
    col = x1.reshape(-1, 1)   # x1 varies down the rows
    row = x2.reshape(1, -1)   # x2 varies across the columns
    abs_diff = torch.abs(col - row)
    return torch.exp(-abs_diff / length)

    
def SE_kernel(length,x1,x2):
    """Squared-exponential kernel matrix exp(-(x1_i - x2_j)^2 / length).

    x1 and x2 are flattened; the result has shape (x1.numel(), x2.numel()).
    """
    col = x1.reshape(-1, 1)   # x1 varies down the rows
    row = x2.reshape(1, -1)   # x2 varies across the columns
    sq_diff = torch.pow(col - row, 2.0)
    return torch.exp(-sq_diff / length)

def gather_nd(K, ind):
    """Gather entries of K addressed by the last axis of ind.

    For an index tensor of shape (..., d), returns a tensor whose leading
    shape is ind.shape[:-1] and whose entry at position p is
    K[ind[p, 0], ..., ind[p, d-1]].

    The components are passed as a tuple: indexing with a list of tensors
    is deprecated non-tuple sequence indexing, and `.T` on tensors that are
    not 2-D is deprecated as well.
    """
    ind = ind.type(torch.long)
    return K[tuple(ind.unbind(-1))]

def log_sum_exp(value, dim=None, keepdim=False):
    """Numerically stable log(sum(exp(value))).

    With dim=None the reduction is over every element and a scalar tensor
    is returned; otherwise the reduction runs along `dim`, and `keepdim`
    controls whether that axis is retained.
    """
    if dim is None:
        peak = torch.max(value)
        return peak + torch.log(torch.sum(torch.exp(value - peak)))

    # Subtract the per-slice maximum before exponentiating to avoid overflow.
    peak, _ = torch.max(value, dim=dim, keepdim=True)
    shifted = value - peak
    if keepdim is False:
        peak = peak.squeeze(dim)
    summed = torch.sum(torch.exp(shifted), dim=dim, keepdim=keepdim)
    return peak + torch.log(summed)


def get_probs_and_accuracy(preds, O):
    """Average positive-class probabilities over Monte Carlo draws and score them.

    preds holds two-class scores with consecutive groups of `n_mc_smps`
    rows belonging to the same observation (n_mc_smps and device are
    notebook globals). Returns (probs, accuracy): the per-observation mean
    positive-class probability and the mean agreement of (probs > 0.5)
    with the labels O.
    """
    # Softmax probability of class 1, computed in log space.
    pos_probs = torch.exp(preds[:, 1] - log_sum_exp(preds, dim=1))
    n_obs = preds.shape[0] / n_mc_smps  # observations after collapsing MC samples

    # Gather one averaged probability per observation, then concatenate once.
    pieces = [torch.zeros([0], device=device)]
    idx = 0
    while idx < n_obs:
        lo = idx * n_mc_smps
        pieces.append(torch.tensor([torch.mean(pos_probs[lo : lo + n_mc_smps])], device=device))
        idx += 1
    probs = torch.cat(pieces, 0)

    predicted = torch.gt(probs, 0.5).type(torch.uint8)
    correct_pred = torch.eq(predicted, O)
    accuracy = torch.mean(correct_pred.type(torch.float32))
    return probs, accuracy

def CG(A, b):
    """Conjugate gradient solve of A x = b for a symmetric positive-definite A.

    Can be faster than a Cholesky solve for large problems. The iteration
    stops once the residual norm drops to n / 1000 or after roughly
    n / 250 + 3 steps, so the result is approximate.

    Args:
        A: (n, n) tensor.
        b: tensor with n elements (any shape; it is flattened).

    Returns:
        (n, 1) tensor approximating A^-1 b, on the same device as b.
    """
    b = torch.reshape(b, (-1,))
    n = A.shape[0]
    # Allocate next to the inputs rather than on a notebook-global device.
    x = torch.zeros((n,), device=b.device, dtype=b.dtype)
    r = b
    p = r

    CG_EPS = n / 1000.0
    MAX_ITER = n / 250 + 3

    i = 0

    while i < MAX_ITER and torch.norm(r) > CG_EPS:
        p_vec = torch.reshape(p, (-1, 1))
        Ap = torch.reshape(torch.mm(A, p_vec), (-1,))
        alpha = torch.dot(r, r) / torch.dot(p, Ap)
        x = x + alpha * p
        r2 = r - alpha * Ap
        beta = torch.dot(r2, r2) / torch.dot(r, r)
        r = r2
        p = r + beta * p
        i += 1

    return torch.reshape(x, (-1, 1))

def block_CG(A_,B_):
    """Block conjugate gradient solve of A X = B for a symmetric positive-definite A.

    Can be much faster than a Cholesky solve for large problems. The
    iteration stops once the residual norm drops to n / 1000 or after
    roughly n / 250 + 3 steps, so the result is approximate.

    Args:
        A_: (n, n) tensor.
        B_: (n, m) tensor of right-hand sides.

    Returns:
        (n, m) tensor approximating A^-1 B, on the same device as B_.
    """
    def _solve(lhs, rhs):
        # Solve lhs @ Z = rhs. torch.solve(rhs, lhs) is gone from current
        # PyTorch; torch.linalg.solve takes the arguments in (lhs, rhs) order.
        if hasattr(torch, "linalg") and hasattr(torch.linalg, "solve"):
            return torch.linalg.solve(lhs, rhs)
        return torch.solve(rhs, lhs)[0]

    n = B_.shape[0]
    m = B_.shape[1]

    # Allocate next to the inputs rather than on a notebook-global device.
    X = torch.zeros((n, m), device=B_.device, dtype=B_.dtype)
    V_ = torch.zeros((n, m), device=B_.device, dtype=B_.dtype)
    R = B_
    R_ = B_

    CG_EPS = n / 1000.0
    MAX_ITER = n / 250 + 3

    i = 0
    while i < MAX_ITER and torch.norm(R) > CG_EPS:
        RtR = torch.mm(torch.transpose(R, 0, 1), R)
        # S = (R_^T R_)^-1 (R^T R): block analogue of the scalar CG beta.
        S = _solve(torch.mm(torch.transpose(R_, 0, 1), R_), RtR)
        V = R + torch.mm(V_, S)
        AV = torch.mm(A_, V)
        # T = (V^T A V)^-1 (R^T R): block analogue of the scalar CG alpha.
        T = _solve(torch.mm(torch.transpose(V, 0, 1), AV), RtR)
        X = X + torch.mm(V, T)
        V_ = V
        R_ = R
        R = R - torch.mm(AV, T)
        i += 1

    return X

def Lanczos(Sigma_func,b):
    """Lanczos approximation of Sigma^(1/2) b using only products with Sigma.

    Runs k = int(n / 500 + 3) Lanczos steps, builds the k x k tridiagonal
    matrix H, and returns ||b|| * D * sqrt(H) * e1.

    Args:
        Sigma_func: callable mapping an (n, 1) tensor v to Sigma @ v, shape (n, 1).
        b: tensor with n elements (flattened to a column internally).

    Returns:
        (n, 1) tensor approximating Sigma^(1/2) b.

    Note:
        No safeguard against Lanczos breakdown (a zero beta before step k)
        is included; that case yields NaNs.
    """
    b = torch.reshape(b, (-1, 1))
    n = b.shape[0]
    k = int(n / 500 + 3)  # must be an int: it is used as a slice bound

    dev, dt = b.device, b.dtype
    betas = torch.zeros(1, device=dev, dtype=dt)   # betas[0] = 0 is a sentinel
    alphas = torch.zeros(0, device=dev, dtype=dt)

    b_norm = torch.norm(b)
    # Column 0 is a zero placeholder so that D[:, j-1] exists for j = 1.
    D = torch.cat((torch.zeros((n, 1), device=dev, dtype=dt), b / b_norm), 1)

    for j in range(1, k + 1):
        d_j = D[:, j:j+1]
        d = Sigma_func(d_j) - betas[-1] * D[:, j-1:j]
        # Inner product of two column vectors (torch.dot only accepts 1-D input).
        alphas = torch.cat((alphas, torch.sum(d_j * d).reshape(1)), 0)
        d = d - alphas[-1] * d_j
        betas = torch.cat((betas, torch.norm(d).reshape(1)), 0)
        D = torch.cat((D, d / betas[j:j+1]), 1)

    betas_ = torch.diag(betas[1:k])
    D_ = D[:, 1:k+1]

    # Tridiagonal H: alphas on the diagonal, betas on the sub- and super-diagonal.
    H = torch.diag(alphas) + F.pad(betas_, (0,1,1,0)) + F.pad(betas_, (1,0,0,1))

    # torch.symeig was removed from PyTorch; torch.linalg.eigh replaces it.
    if hasattr(torch, "linalg") and hasattr(torch.linalg, "eigh"):
        e, v = torch.linalg.eigh(H)
    else:
        e, v = torch.symeig(H, eigenvectors=True)
    e_pos = torch.clamp(e, min=0.0) + 1e-6   # guard against tiny negative eigenvalues
    e_sqrt = torch.diag(torch.sqrt(e_pos))
    sq_H = torch.mm(v, torch.mm(e_sqrt, torch.transpose(v, 0, 1)))

    out = b_norm * torch.mm(D_, sq_H)
    return out[:, 0:1]


def block_Lanczos(Sigma_func,B_,n_mc_smps):
    """Batched Lanczos approximation of Sigma^(1/2) B, one column at a time.

    Used to generate several approximate large multivariate-normal draws
    when B holds N(0, 1) entries. Each column of B_ gets its own Lanczos
    recurrence with k = int(n / 500 + 3) steps.

    Args:
        Sigma_func: callable mapping an (n, s) tensor V to Sigma @ V, shape (n, s).
        B_: (n, s) tensor.
        n_mc_smps: number of columns s of B_.

    Returns:
        (n, s) tensor approximating Sigma^(1/2) B_, on the same device as B_.

    Note:
        No safeguard against Lanczos breakdown (a zero beta before step k)
        is included; that case yields NaNs.
    """
    n = B_.shape[0]
    s = n_mc_smps
    k = int(n / 500 + 3)
    dev, dt = B_.device, B_.dtype

    betas = torch.zeros((1, s), device=dev, dtype=dt)    # row 0 is a zero sentinel
    alphas = torch.zeros((0, s), device=dev, dtype=dt)

    B_norms = torch.norm(B_, dim=0)
    # D has shape (s, n, j+1); slot 0 is a zero placeholder, slot 1 the normalised B.
    D = torch.cat((torch.zeros((s, n, 1), device=dev, dtype=dt),
                   torch.unsqueeze(torch.transpose(B_ / B_norms, 0, 1), 2)), 2)

    for j in range(1, k + 1):
        # Explicit indexing instead of a bare squeeze(), which would also
        # drop the sample axis when s == 1.
        d_j = torch.transpose(D[:, :, j], 0, 1)          # (n, s)
        d_prev = torch.transpose(D[:, :, j - 1], 0, 1)   # (n, s)
        d = Sigma_func(d_j) - betas[j-1:j, :] * d_prev
        # Column-wise inner products; avoids forming an s x s matrix just
        # to read its diagonal.
        alpha_j = torch.sum(d_j * d, dim=0, keepdim=True)   # (1, s)
        alphas = torch.cat((alphas, alpha_j), 0)
        d = d - alpha_j * d_j
        beta_j = torch.norm(d, dim=0, keepdim=True)         # (1, s)
        betas = torch.cat((betas, beta_j), 0)
        D = torch.cat((D, torch.transpose(d / beta_j, 0, 1).unsqueeze(2)), 2)

    D_ = D[:, :, 1:1+k]   # (s, n, k)

    # Batched tridiagonal matrices, shape (s, k, k).
    off_diag = torch.transpose(betas[1:k, :], 0, 1)   # (s, k-1)
    H = (torch.diag_embed(torch.transpose(alphas, 0, 1)) +
         torch.diag_embed(off_diag, offset=-1) +
         torch.diag_embed(off_diag, offset=1))

    # torch.symeig was removed from PyTorch; torch.linalg.eigh replaces it.
    if hasattr(torch, "linalg") and hasattr(torch.linalg, "eigh"):
        E, V = torch.linalg.eigh(H)
    else:
        E, V = torch.symeig(H, eigenvectors=True)

    # Floor the eigenvalues at 1e-6 before the square root.
    E_sqrt = torch.diag_embed(torch.sqrt(torch.clamp(E, min=1e-6)))
    sq_H = torch.matmul(V, torch.matmul(E_sqrt, V.permute(0, 2, 1)))

    # sq_H @ e1 is the first column of sq_H.
    first_col = sq_H[:, :, 0:1]   # (s, k, 1)

    out = B_norms * torch.transpose(torch.matmul(D_, first_col).squeeze(2), 0, 1)
    return out

def draw_GP(Yi, Ti, Xi, ind_kfi, ind_kti, length, noises, Kf, n_mc_smps):
    """ 
    Given GP hyperparameters and data values at observation times, draw
    samples from the conditional GP at the grid times.

    inputs:
        Yi: observation values
        Ti: observation times
        Xi: grid points (new times for rnn)
        ind_kfi,ind_kti: indices into Y; ind_kfi indexes rows/columns of Kf
            and entries of noises, ind_kti indexes entries of Ti
        length: length-scale passed to OU_kernel
        noises: vector placed on the diagonal of the noise matrix
        Kf: covariance matrix indexed by ind_kfi
        n_mc_smps: number of Monte Carlo draws
    returns:
        tensor of shape (n_mc_smps, nx, M) with draws from the GP at the
        grid times Xi, where nx = Xi.shape[0]

    Reads the notebook globals `M` and `device`, which are not parameters.
    NOTE(review): in this notebook the global M is presumably smaller than
    Kf.shape[0], so only the leading M x M part of Kf is used here -- confirm.
    """  
    ny = Yi.shape[0]
    # Time covariance between all observation times.
    K_tt = OU_kernel(length, Ti, Ti)
    
    D = torch.diag(noises)
        
    # Expand Kf to one row/column per observation via its ind_kfi index.
    grid_f = torch.meshgrid(ind_kfi, ind_kfi) 
    grid_f = (grid_f[0].T, grid_f[1].T)
    
    Kf_big = gather_nd(Kf, torch.stack((grid_f[0],grid_f[1]),-1))

    # Expand K_tt the same way via each observation's time index.
    grid_t = torch.meshgrid(ind_kti, ind_kti) 
    grid_t = (grid_t[0].T, grid_t[1].T)
    Kt_big = gather_nd(K_tt, torch.stack((grid_t[0],grid_t[1]),-1))

    # Element-wise product of the two expanded covariances.
    Kf_Ktt = torch.mul(Kf_big,Kt_big)

    # Per-observation noise: gather from D, then keep only the diagonal.
    DI_big = gather_nd(D,torch.stack((grid_f[0],grid_f[1]),-1))
    DI = torch.diag(torch.diagonal(DI_big, dim1=-2, dim2=-1)) 

    
    # Data covariance, with a small jitter on the diagonal.
    # Either need to take Cholesky of this or use CG / block CG for matrix-vector products
    Ky = Kf_Ktt + DI + 1e-6 * torch.eye(ny, device = device)

    ### build out cross-covariances and covariance at grid
    nx = Xi.shape[0]

    K_xx = OU_kernel(length,Xi,Xi)
    K_xt = OU_kernel(length,Xi,Ti)

    # ind repeats each index 0..M-1 nx times (M is the notebook global).
    ind = torch.cat([ torch.tensor([i], device=device).repeat([nx]) for i in range(M)], 0)
    grid = torch.meshgrid(ind, ind) # meshgrid output is transposed on the next line
    grid = (grid[0].T, grid[1].T)
    Kf_big = gather_nd(Kf, torch.stack((grid[0],grid[1]), -1))
    # ind2 tiles the grid-time indices 0..nx-1 M times.
    ind2 = torch.arange(0, nx, device=device).repeat([M])
    grid2 = torch.meshgrid(ind2, ind2) # meshgrid output is transposed on the next line
    grid2 = (grid2[0].T, grid2[1].T)
    Kxx_big = gather_nd(K_xx, torch.stack((grid2[0], grid2[1]),-1))

    # Prior covariance over all (index, grid time) pairs, (nx*M, nx*M).
    K_ff = torch.mul(Kf_big,Kxx_big)       

    full_f = torch.cat([ torch.tensor([i], device=device).repeat([nx]) for i in range(M)],0)
    grid_1 = torch.meshgrid(full_f,ind_kfi)  # used without transposing
    Kf_big = gather_nd(Kf, torch.stack((grid_1[0],grid_1[1]),-1))
    full_x = torch.arange(0, nx, device=device).repeat([M]).type(torch.long)

    grid_2 = torch.meshgrid(full_x,ind_kti) # used without transposing
    Kxt_big = gather_nd(K_xt, torch.stack((grid_2[0],grid_2[1]),-1))

    # Cross-covariance between grid values and observations, (nx*M, ny).
    K_fy = torch.mul(Kf_big, Kxt_big)
    
    
    #now get draws!
    y_ = torch.reshape(Yi, (-1,1))


    # Posterior mean: K_fy Ky^-1 y, with the solve done by conjugate gradient.
    Mu = torch.matmul(K_fy,CG(Ky,y_)) #May be faster with CG for large problems
    Ly = torch.cholesky(Ky) # Computing Ly takes far longer than computing Mu
#     Mu = torch.mm(K_fy, torch.cholesky_solve(y_, Ly))  # in tensorflow: tf.cholesky_solve(Ly, y_)
    
    # Standard-normal noise, one column per Monte Carlo draw.
    xi = torch.normal(mean=0, std=1.0, size=(nx*M, n_mc_smps), device=device)
    # Posterior covariance K_ff - K_fy Ky^-1 K_fy^T, plus jitter.
    Sigma = K_ff - torch.mm(K_fy, torch.cholesky_solve(torch.transpose(K_fy, 0, 1), Ly)) + 1e-6 * torch.eye(K_ff.shape[0], device=device)

    # The string literal below is disabled code (an alternative draw path
    # based on block_Lanczos); it is never executed.
    '''
    #Never need to explicitly compute Sigma! Just need matrix products with Sigma in Lanczos algorithm
    def Sigma_mul(vec):
        # vec must be a 2d tensor, shape (?,?) 
        return torch.mm(K_ff, vec) - torch.mm(K_fy,block_CG(Ky,torch.mm(torch.transpose(K_fy, 0, 1),vec))) 

    def small_draw():   
        return Mu + torch.mm(torch.cholesky(Sigma),xi)
    def large_draw():             
        return Mu + block_Lanczos(Sigma_mul,xi,n_mc_smps) #no need to explicitly reshape Mu

    BLOCK_LANC_THRESH = 10000
    draw = small_draw() if nx * M < BLOCK_LANC_THRESH else large_draw() 
    '''

    # Draw = mean + chol(Sigma) @ noise; Mu (nx*M, 1) broadcasts across draws.
    draw = Mu + torch.mm(torch.cholesky(Sigma), xi)
    # (nx*M, n_mc_smps) -> (n_mc_smps, M, nx) -> (n_mc_smps, nx, M)
    draw_reshape = (torch.reshape(torch.transpose(draw, 0, 1), (n_mc_smps, M, nx))).permute(0,2,1)
    return draw_reshape   

def get_GP_samples(Y, T, X, ind_kf, ind_kt, num_obs_times, num_obs_values,
                   num_rnn_grid_times, med_grid,
                   length, noises, Kf, 
                   n_mc_smps, M): 
    """
    Draw GP samples on the grid for every patient in a padded batch.

    For each patient the padding is stripped, draw_GP is called, the draws
    are zero-padded along time up to the global `sequence_len`, and the
    patient's medication grid (tiled across Monte Carlo samples) is
    appended on the feature axis. Results for all patients are stacked on
    axis 0, giving a tensor whose last two axes are (sequence_len, M).
    """
    # Index tensors must be int64 before they are used for indexing.
    ind_kf = ind_kf.type(torch.long)
    ind_kt = ind_kt.type(torch.long)

    # Start from an empty (0, sequence_len, M) tensor, as the original
    # accumulator did, and concatenate everything once at the end.
    per_patient = [torch.zeros((0, sequence_len, M), device=device)]

    for i in range(T.shape[0]):
        n_values = num_obs_values[i]
        n_times = num_obs_times[i]
        X_len = num_rnn_grid_times[i]

        # Strip the batch padding for this patient.
        Yi = torch.reshape(Y[i:i+1, 0:n_values], (-1,))
        Ti = torch.reshape(T[i:i+1, 0:n_times], (-1,))
        ind_kfi = torch.reshape(ind_kf[i:i+1, 0:n_values], (-1,))
        ind_kti = torch.reshape(ind_kt[i:i+1, 0:n_values], (-1,))
        Xi = torch.reshape(X[i:i+1, 0:X_len], (-1,))

        GP_draws = draw_GP(Yi, Ti, Xi, ind_kfi, ind_kti, length, noises, Kf, n_mc_smps)

        # Zero-pad the draws along the time axis.
        time_pad = torch.zeros((n_mc_smps, sequence_len - X_len, GP_draws.shape[2]),
                               dtype=torch.float32, device=device)
        padded_GP_draws = torch.cat((GP_draws, time_pad), 1)

        # Zero-pad the medication grid the same way and tile it across draws.
        meds = torch.tensor(med_grid[i:i+1], device=device)
        meds_pad = torch.zeros((1, sequence_len - meds.shape[1], meds.shape[2])).cuda(device=device)
        tiled_meds = torch.cat([meds, meds_pad], 1).repeat(n_mc_smps, 1, 1)

        per_patient.append(torch.cat((padded_GP_draws, tiled_meds), 2))

    return torch.cat(per_patient, 0)

In [None]:
class TransformerModel(nn.Module):
    """GP-sampling front end followed by a Transformer encoder and a linear readout.

    forward() draws Monte Carlo samples from a GP fitted to the irregular
    observations (via get_GP_samples), embeds them, runs them through a
    Transformer encoder, appends the static covariates at every time step,
    applies a learned readout and averages the result over the Monte Carlo
    samples.
    """

    def __init__(self, ninput, n_covs, sequence_len, emsize, nhead, nhid, nlayers, n_mc_smps, dropout=0.5): # ninput = M(original) + n_meds
        super(TransformerModel, self).__init__()
        self.model_type = 'Transformer'
        self.src_mask = None
        self.encoder = nn.Linear(ninput, emsize)
        self.pos_encoder = PositionalEncoding(emsize, dropout)
        encoder_layers = nn.TransformerEncoderLayer(emsize, nhead, nhid, dropout)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, nlayers)

        # Drawn before the GP parameters so the random-number stream is
        # consumed in the same order as before.
        final_init = torch.rand(size=(sequence_len*(emsize+n_covs),))

        # GP / shape bookkeeping
        self.M = ninput # self.M = ninput = M + n_meds
        self.n_mc_smps = n_mc_smps
        self.sequence_len = sequence_len
        self.n_covs = n_covs
        self.emsize = emsize

        # In the fully separable case all labs share the same time-covariance.
        self.log_length = nn.Parameter(torch.normal(size=[1], mean=1, std=0.1))
        self.log_noises = nn.Parameter(torch.normal(size=[self.M], mean=-2, std=0.1))
        self.L_f_init = nn.Parameter(torch.eye(self.M))
        # Flat readout weights, one per (time step, feature) position.
        self.final = nn.Parameter(final_init)

        self.init_weights()


    def init_weights(self):
        """Re-initialise the embedding weights uniformly in [-0.1, 0.1]."""
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, Y, T, X, ind_kf, ind_kt, num_obs_times, num_obs_values,
                   num_rnn_grid_times, med_grid, covs):
        """Return a (batch_size, sequence_len) tensor of outputs averaged over MC samples.

        Y, T, X, ind_kf, ind_kt, med_grid are the padded batch arrays; the
        num_* arguments hold per-patient lengths; covs is reshaped below as
        (batch_size, n_covs).
        """
        length = torch.exp(self.log_length)
        noises = torch.exp(self.log_noises)
        Lf = torch.tril(self.L_f_init)
        Kf = torch.mm(Lf, torch.transpose(Lf, 0, 1))   # Kf = Lf Lf^T

        # Build the readout matrix on the module's own device rather than
        # on a notebook-global one.
        n_flat = self.sequence_len*(self.emsize+self.n_covs)
        final_reshape = torch.zeros(size=(n_flat, self.sequence_len),
                                    device=self.final.device, dtype=self.final.dtype)
        # NOTE(review): column i only receives the first 2*(i+1) weights.
        # The hard-coded 2 looks like a leftover; (emsize + n_covs) may have
        # been intended -- confirm before changing.
        for i in range(self.sequence_len):
            final_reshape[:2*(i+1), i] = self.final[:2*(i+1)]

        ### Draw samples
        Z = get_GP_samples(Y=Y,
                           T=T,
                           X=X,
                           ind_kf=ind_kf,
                           ind_kt=ind_kt,
                           num_obs_times=num_obs_times,
                           num_obs_values=num_obs_values,
                           num_rnn_grid_times=num_rnn_grid_times,
                           med_grid=med_grid,
                           length=length,
                           noises=noises,
                           Kf=Kf,
                           n_mc_smps=self.n_mc_smps,
                           M=self.M)

        batch_size_MonteCarlo, _, _ = Z.shape
        batch_size = batch_size_MonteCarlo // self.n_mc_smps

        # Embedding
        src = self.encoder(Z)        # [batch_size_MC, sequence_len, embed_size]
        src = src.permute(1,0,2)     # [sequence_len, batch_size_MC, embed_size]

        # Positional encoding
        src = self.pos_encoder(src)  # [sequence_len, batch_size_MC, embed_size]

        # Transformer
        output = self.transformer_encoder(src, self.src_mask)   # [sequence_len, batch_size_MC, embed_size]

        # Append the covariates at every time step.
        output = output.permute(1,0,2)   # [batch_size_MC, sequence_len, embed_size]
        covs = covs.repeat((1,self.n_mc_smps)).repeat((1,self.sequence_len)).view((batch_size*self.n_mc_smps, self.sequence_len, self.n_covs)) # [batch_size_MC, sequence_len, ncovs]

        output = torch.cat((output, covs), dim=2) # [batch_size_MC, sequence_len, embed_size+ncovs]

        # Flatten time and features row by row, then apply the readout.
        output = torch.reshape(output, (batch_size*self.n_mc_smps, n_flat))  # [batch_size_MC, sequence_len*(embed_size+ncovs)]
        output = torch.mm(output, final_reshape)   # [batch_size_MC, sequence_len]

        # Average over the Monte Carlo samples of each patient.
        output = torch.mean(output.reshape(-1, self.n_mc_smps, self.sequence_len), dim=1)  # [batch_size, sequence_len]
        return output

In [None]:
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding followed by dropout.

    Adds sin/cos position signals to an input shaped
    (sequence_len, batch, d_model). Even feature indices receive sines and
    odd indices cosines.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term   # (max_len, ceil(d_model / 2))
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even column,
        # so trim the cosine block to fit instead of raising a shape error.
        pe[:, 1::2] = torch.cos(angles)[:, :d_model // 2]
        pe = pe.unsqueeze(1)           # (max_len, 1, d_model), broadcasts over batch
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add the encoding for the first x.size(0) positions, then apply dropout."""
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

### Data loading and train/test split

In [None]:
# Load the preprocessed inputs. The context manager closes the file even
# if unpickling fails.
# NOTE: pickle.load can execute arbitrary code; only open trusted files.
with open("input_for_GPRNN-240minutes-168h_2020-10-27.pickle", 'rb') as f:
    input_for_GPRNN = pickle.load(f, encoding="latin1")

# Unpack the per-patient collections used throughout the notebook.
num_obs_times = input_for_GPRNN['num_obs_times']
num_obs_values = input_for_GPRNN['num_obs_values']
num_rnn_grid_times = input_for_GPRNN['num_rnn_grid_times']
rnn_grid_times = input_for_GPRNN['rnn_grid_times']
labels = input_for_GPRNN['labels']
times = input_for_GPRNN['times']
values = input_for_GPRNN['values']
ind_lvs = input_for_GPRNN['ind_lvs']
ind_times = input_for_GPRNN['ind_times']
meds_on_grid = input_for_GPRNN['meds_on_grid']
covs = input_for_GPRNN['covs']

print("That's all!")

# Total number of patients.
N_tot = len(labels)

# Fixed-seed NumPy generator so the split and the shuffles are repeatable.
seed = 8675309
rs = np.random.RandomState(seed)

# Shuffle once; the first val_frac of the permutation is held out.
train_test_perm = rs.permutation(N_tot)
val_frac = 0.1
n_held_out = int(val_frac*N_tot)

te_ind, tr_ind = train_test_perm[:n_held_out], train_test_perm[n_held_out:]
Nte, Ntr = len(te_ind), len(tr_ind)

batch_size = 50
eval_batch_size = 10

# Batch boundaries. Only full batches are kept: a start without a matching
# end (the trailing partial batch) is dropped.
starts_tr = np.arange(0, Ntr, batch_size)
ends_tr = np.arange(batch_size, Ntr + 1, batch_size)
starts_tr = starts_tr[:len(ends_tr)]

starts_te = np.arange(0, Nte, eval_batch_size)
ends_te = np.arange(eval_batch_size, Nte + 1, eval_batch_size)
starts_te = starts_te[:len(ends_te)]


# Break everything out into train/test: for every collection `name` create
# `name_tr` and `name_te` holding the entries selected by tr_ind / te_ind.
# The variables are assigned directly in the notebook namespace instead of
# running exec() on assembled source strings.
for suffix, split_ind in (('tr', tr_ind), ('te', te_ind)):
    for varname in ['covs', 'labels', 'times', 'values', 'ind_lvs', 'ind_times', 'meds_on_grid',
                    'num_obs_times', 'num_obs_values', 'rnn_grid_times', 'num_rnn_grid_times']:
        print(varname + '_' + suffix + ' = [' + varname + '[i] for i in ' + suffix + '_ind]')
        globals()[varname + '_' + suffix] = [globals()[varname][i] for i in split_ind]

print("data fully setup!") 

### Parameter Setting

In [None]:
# --- Data dimensions ---
M = 25        # number of time-varying features (labs + vitals)
n_covs = 33   # number of covariates
n_meds = 21   # number of medications

# --- Transformer hyperparameters ---
ninput = M + n_meds   # features per time step fed to the embedding layer
emsize = 512          # embedding dimension
nhid = 2048           # feed-forward width inside nn.TransformerEncoder
nlayers = 6           # number of nn.TransformerEncoderLayer layers
nhead = 8             # attention heads per layer
dropout = 0.3         # dropout probability

# --- Sequence / sampling settings ---
sequence_len = 42     # maximum number of grid time points over all patients
n_mc_smps = 20        # Monte Carlo samples per patient

model = TransformerModel(
    ninput=ninput, n_covs=n_covs, sequence_len=sequence_len,
    emsize=emsize, nhead=nhead, nhid=nhid, nlayers=nlayers,
    n_mc_smps=n_mc_smps, dropout=dropout,
).to(device)


In [None]:
# Summed binary cross-entropy on raw logits.
criterion = nn.BCEWithLogitsLoss(reduction='sum')

# Plain SGD; the learning rate decays by 5% on every scheduler step.
lr = 0.03
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1.0, gamma=0.95)


def _prepare_batch(inds, data):
    """Pad one batch and move it to `device`.

    Args:
        inds: indices of the patients in this batch.
        data: dict mapping 'times', 'values', 'ind_lvs', 'ind_times',
            'rnn_grid_times', 'meds_on_grid', 'covs', 'num_obs_times',
            'num_obs_values', 'num_rnn_grid_times' and 'labels' to the
            per-patient collections of one split.

    Returns:
        (model_inputs, O, O_dup): the positional arguments for the model's
        forward pass, the labels, and the labels repeated `sequence_len`
        times per patient and flattened.
    """
    def pick(name):
        return [data[name][i] for i in inds]

    def to_tensor(array, dtype):
        # Going through one ndarray avoids the slow list-of-arrays path.
        return torch.tensor(np.asarray(array), dtype=dtype, device=device)

    T_pad, Y_pad, ind_kf_pad, ind_kt_pad, X_pad, meds_pad = pad_rawdata(
        T=pick('times'),
        Y=pick('values'),
        ind_kf=pick('ind_lvs'),
        ind_kt=pick('ind_times'),
        X=pick('rnn_grid_times'),
        meds_on_grid=pick('meds_on_grid'))

    model_inputs = (
        to_tensor(Y_pad, torch.float32),
        to_tensor(T_pad, torch.float32),
        to_tensor(X_pad, torch.float32),
        to_tensor(ind_kf_pad, torch.int32),
        to_tensor(ind_kt_pad, torch.int32),
        to_tensor(pick('num_obs_times'), torch.int32),
        to_tensor(pick('num_obs_values'), torch.int32),
        to_tensor(pick('num_rnn_grid_times'), torch.int32),
        to_tensor(meds_pad, torch.float32),
        to_tensor(pick('covs'), torch.float32),
    )

    O = to_tensor(pick('labels'), torch.float32)
    # One copy of the label per time step, to match the flattened model output.
    O_dup = torch.reshape(O.unsqueeze(1).repeat(1, sequence_len), (-1,))
    return model_inputs, O, O_dup


def train():
    """Run one training epoch over freshly shuffled full batches.

    Uses the notebook globals model, optimizer, scheduler, criterion, rs,
    the *_tr collections and `epoch` (for logging only).
    """
    model.train()
    total_loss = 0.
    start_time = time.time()

    data = dict(times=times_tr, values=values_tr, ind_lvs=ind_lvs_tr,
                ind_times=ind_times_tr, rnn_grid_times=rnn_grid_times_tr,
                meds_on_grid=meds_on_grid_tr, covs=covs_tr,
                num_obs_times=num_obs_times_tr, num_obs_values=num_obs_values_tr,
                num_rnn_grid_times=num_rnn_grid_times_tr, labels=labels_tr)

    perm = rs.permutation(Ntr)
    log_interval = 1

    for batch, (s, e) in enumerate(zip(starts_tr, ends_tr)):
        model_inputs, _, O_dup = _prepare_batch(perm[s:e], data)

        optimizer.zero_grad()
        output = model(*model_inputs)   # [batch_size, sequence_len]
        loss = criterion(output.view(-1), O_dup)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

        total_loss += loss.item()
        if batch % log_interval == 0 and batch > 0:
            cur_loss = total_loss / log_interval
            elapsed = time.time() - start_time
            # get_last_lr() reports the rate currently in use; get_lr() is
            # only meaningful inside scheduler.step().
            print('| epoch {:3d} | {:5d}/{:5d} batches | '
                  'lr {:02.2f} | ms/batch {:5.2f} | '
                  'loss {:5.2f}'.format(
                    epoch, batch,  len(starts_tr), scheduler.get_last_lr()[0],
                    elapsed * 1000 / log_interval,
                    cur_loss))
            total_loss = 0
            start_time = time.time()


def evaluate(eval_model):
    """Evaluate `eval_model` on the held-out split.

    Returns:
        (total_loss, output_all, target_all): the summed loss over all
        evaluated batches, the per-patient outputs (one array of length
        sequence_len each) and the matching labels.
    """
    eval_model.eval()
    total_loss = 0.

    data = dict(times=times_te, values=values_te, ind_lvs=ind_lvs_te,
                ind_times=ind_times_te, rnn_grid_times=rnn_grid_times_te,
                meds_on_grid=meds_on_grid_te, covs=covs_te,
                num_obs_times=num_obs_times_te, num_obs_values=num_obs_values_te,
                num_rnn_grid_times=num_rnn_grid_times_te, labels=labels_te)

    # NOTE(review): shuffling here advances the shared RNG, and, because
    # only full batches are evaluated, it changes which held-out patients
    # are scored on each call when Nte is not a multiple of eval_batch_size.
    perm = rs.permutation(Nte)

    output_all = []
    target_all = []

    with torch.no_grad():
        for s, e in tqdm(zip(starts_te, ends_te)):
            model_inputs, O, O_dup = _prepare_batch(perm[s:e], data)

            # Use the model that was passed in, not the global `model`.
            output = eval_model(*model_inputs)

            output_all.extend(list(output.cpu().numpy()))
            target_all.extend(list(O.cpu().numpy()))

            total_loss += criterion(output.view(-1), O_dup).item()

    return total_loss, output_all, target_all


In [None]:
best_val_loss = float("inf")
epochs = 100
best_model = None

for epoch in range(1, epochs + 1):
    epoch_start_time = time.time()
    train()
    val_loss, _, _ = evaluate(model)
    print('-' * 89)
    print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} '
          .format(epoch, (time.time() - epoch_start_time),
                                     val_loss))
    print('-' * 89)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        # Snapshot the weights. A plain `best_model = model` only aliases
        # the live model, which later epochs keep updating.
        best_model = copy.deepcopy(model)

    scheduler.step()

In [None]:
# Re-score the selected model on the held-out split and keep its outputs.
# `epoch` and `epoch_start_time` are whatever the training loop left behind.
val_loss, output_all, target_all = evaluate(best_model)
separator = '-' * 89
print(separator)
elapsed_since_epoch_start = time.time() - epoch_start_time
print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} '
      .format(epoch, elapsed_since_epoch_start, val_loss))