In [2]:
from keras import backend as K
from keras.engine.topology import Layer
import numpy as np
import tensorflow as tf
import keras
from keras import activations


class single_channel_interp(Layer):
    """Per-channel interpolation layer (Interpolation-Prediction Networks).

    Each channel is interpolated independently with a kernel
    ``exp(-alpha * (t - r)**2)`` whose bandwidth ``alpha = softplus(kernel)``
    is learned per channel (``kernel`` has shape ``(d_dim,)``). Kernel weights
    are normalized over observed points only, via ``log(mask)``.

    Input layout, shape ``(batch, 4 * d_dim, time_stamp)`` with ``d = d_dim``:
      * ``x[:, 0:d]``   -- values
      * ``x[:, d:2d]``  -- mask used in interpolation mode
      * ``x[:, 2d:3d]`` -- observation time stamps
      * ``x[:, 3d:4d]`` -- mask used in reconstruction mode

    Outputs:
      * interpolation mode: ``(batch, 3 * d_dim, ref_points)`` -- smooth
        interpolant, ``w`` (log of the summed kernel weight), and a sharper
        (kappa = 10) interpolant.
      * reconstruction mode: ``(batch, 2 * d_dim, time_stamp)`` -- the
        interpolant evaluated at the observation time stamps, and ``w``.

    Args:
        ref_points: number of evenly spaced reference points on
            ``[0, hours_look_ahead]`` produced in interpolation mode.
        hours_look_ahead: end of the reference grid, in hours.
    """

    def __init__(self, ref_points, hours_look_ahead, **kwargs):
        self.ref_points = ref_points
        self.hours_look_ahead = hours_look_ahead  # in hours
        super(single_channel_interp, self).__init__(**kwargs)
        # Mode of the most recent call(); read by compute_output_shape.
        # Initialized so compute_output_shape works before the first call.
        self.reconstruction = False

    def build(self, input_shape):
        """Create the per-channel bandwidth kernel.

        input_shape is ``(batch, 4 * d_dim, time_stamp)``.
        """
        self.time_stamp = input_shape[2]
        self.d_dim = input_shape[1] // 4
        self.activation = activations.get('sigmoid')  # unused in call(); kept for compatibility
        self.kernel = self.add_weight(
            name='kernel',
            shape=(self.d_dim, ),
            initializer=keras.initializers.Constant(value=0.0),
            trainable=True)
        super(single_channel_interp, self).build(input_shape)

    def call(self, x, reconstruction=False):
        """Interpolate ``x`` onto the reference grid or back onto its own time stamps.

        Args:
            x: tensor laid out as described in the class docstring.
            reconstruction: if true, evaluate the interpolant at the observation
                time stamps using the reconstruction-mode mask ``x[:, 3d:4d]``;
                otherwise evaluate it on the ``ref_points`` reference grid using
                the mask ``x[:, d:2d]``.
        """
        self.reconstruction = reconstruction
        x_t = x[:, :self.d_dim, :]
        d = x[:, 2*self.d_dim:3*self.d_dim, :]
        if reconstruction:
            output_dim = self.time_stamp
            m = x[:, 3*self.d_dim:, :]
            # ref_t[b, c, i, j] = t_j: the targets are the observation times.
            ref_t = K.tile(d[:, :, None, :], (1, 1, output_dim, 1))
        else:
            output_dim = self.ref_points
            m = x[:, self.d_dim:2*self.d_dim, :]
            # Reference grid of shape (1, ref_points), in the backend float
            # dtype. Built by indexing rather than assigning ``.shape`` in place.
            ref_t = np.linspace(
                0, self.hours_look_ahead, self.ref_points
            )[np.newaxis, :].astype(K.floatx())
        # Lift to (batch, d_dim, time_stamp, output_dim): axis 2 indexes the
        # observations, axis 3 the interpolation targets.
        d = K.tile(d[:, :, :, None], (1, 1, 1, output_dim))
        mask = K.tile(m[:, :, :, None], (1, 1, 1, output_dim))
        x_t = K.tile(x_t[:, :, :, None], (1, 1, 1, output_dim))
        norm = (d - ref_t)*(d - ref_t)
        # Positive per-channel bandwidth, shape (d_dim, 1, 1). softplus is the
        # numerically stable form of log(1 + exp(kernel)); broadcasting against
        # norm avoids materializing a (d_dim, time_stamp, output_dim) ones tensor.
        alpha = K.softplus(self.kernel)[:, np.newaxis, np.newaxis]
        # For a 0/1 mask: 0 where observed, -inf where missing, so missing
        # points get zero weight in the log-domain normalization below.
        log_mask = K.log(mask)
        # NOTE(review): if a channel has no observed point, w is presumably -inf
        # and the normalized weights become exp(-inf - (-inf)) = NaN. Confirm
        # inputs always carry at least one observation per channel.
        logits = -alpha*norm + log_mask
        w = K.logsumexp(logits, axis=2)
        # Softmax over the observation axis (broadcast instead of tiling w).
        w1 = K.exp(logits - w[:, :, None, :])
        y = K.sum(w1*x_t, axis=2)
        if reconstruction:
            rep1 = tf.concat([y, w], 1)
        else:
            # Sharper kernel (kappa = 10) for the transient component.
            logits_t = -10.0*alpha*norm + log_mask
            w_t = K.exp(logits_t - K.logsumexp(logits_t, axis=2, keepdims=True))
            y_trans = K.sum(w_t*x_t, axis=2)
            rep1 = tf.concat([y, w, y_trans], 1)
        return rep1

    def compute_output_shape(self, input_shape):
        """Output shape for the mode used in the most recent call()."""
        d_dim = input_shape[1] // 4
        if self.reconstruction:
            return (input_shape[0], 2*d_dim, input_shape[2])
        return (input_shape[0], 3*d_dim, self.ref_points)


class cross_channel_interp(Layer):
    """Cross-channel interpolation layer (Interpolation-Prediction Networks).

    Mixes per-channel interpolants across channels with a learnable
    ``(d_dim, d_dim)`` matrix (identity-initialized). ``d_dim`` is inferred in
    ``build`` as ``rows // 3``, so the layer must first be built on an input
    with ``3 * d_dim`` rows; later reconstruction-mode calls only read the
    first ``2 * d_dim`` rows.

    Input rows (``d = d_dim``):
      * ``x[:, 0:d]``   -- per-channel interpolant ``y``
      * ``x[:, d:2d]``  -- per-channel log-weights ``w``
      * ``x[:, 2d:3d]`` -- transient interpolant (read only outside
        reconstruction mode)

    Outputs:
      * reconstruction mode: ``(batch, d_dim, output_dim)`` -- mixed interpolant.
      * otherwise: ``(batch, 3 * d_dim, output_dim)`` -- mixed interpolant,
        ``exp(w)``, and transient interpolant minus the mixed interpolant.
    """

    def __init__(self, **kwargs):
        super(cross_channel_interp, self).__init__(**kwargs)
        # Mode of the most recent call(); read by compute_output_shape.
        # Initialized so compute_output_shape works before the first call.
        self.reconstruction = False

    def build(self, input_shape):
        """Create the channel-mixing matrix; input_shape is (batch, 3*d_dim, output_dim)."""
        self.d_dim = input_shape[1] // 3
        self.activation = activations.get('sigmoid')  # unused in call(); kept for compatibility
        self.cross_channel_interp = self.add_weight(
            name='cross_channel_interp',
            shape=(self.d_dim, self.d_dim),
            initializer=keras.initializers.Identity(gain=1.0),
            trainable=True)

        super(cross_channel_interp, self).build(input_shape)

    def call(self, x, reconstruction=False):
        """Mix channels of ``x``; see the class docstring for layouts."""
        self.reconstruction = reconstruction
        self.output_dim = K.int_shape(x)[-1]
        y = x[:, :self.d_dim, :]
        w = x[:, self.d_dim:2*self.d_dim, :]
        intensity = K.exp(w)
        # Channels last, (batch, output_dim, d_dim), so K.dot mixes channels.
        y = tf.transpose(y, perm=[0, 2, 1])
        w = tf.transpose(w, perm=[0, 2, 1])
        # Softmax of w across channels at each time point. keepdims
        # broadcasting gives the same result as tiling w to (..., d, d).
        w = K.exp(w - K.logsumexp(w, axis=2, keepdims=True))
        # Per-sample, per-channel mean over time, shape (batch, 1, d_dim).
        mean = K.mean(y, axis=1, keepdims=True)
        mixed = K.dot(w*(y - mean), self.cross_channel_interp) + mean
        rep1 = tf.transpose(mixed, perm=[0, 2, 1])
        # Truthiness test (not `is False`) so this branch always agrees with
        # compute_output_shape, which also tests truthiness.
        if not reconstruction:
            y_trans = x[:, 2*self.d_dim:3*self.d_dim, :]
            y_trans = y_trans - rep1  # subtracting smooth from transient part
            rep1 = tf.concat([rep1, intensity, y_trans], 1)
        return rep1

    def compute_output_shape(self, input_shape):
        """Output shape for the mode used in the most recent call()."""
        output_dim = input_shape[-1]
        if self.reconstruction:
            return (input_shape[0], self.d_dim, output_dim)
        return (input_shape[0], 3*self.d_dim, output_dim)

Using TensorFlow backend.


In [17]:
from keras.layers import Input, Dense, GRU, Lambda, Permute
from keras.models import Model


def interp_net(num_features, timestamp, ref_points, hours_look_ahead, units, recurrent_dropout):
    """Build the Interpolation-Prediction Network.

    The input, shape ``(batch, 4 * num_features, timestamp)``, passes twice
    through the same pair of interpolation layers: once onto ``ref_points``
    reference points (fed to a GRU classifier), and once in reconstruction
    mode (exposed as ``aux_output``).

    Args:
        num_features: number of channels ``d``.
        timestamp: number of time steps in the input.
        ref_points: number of reference points used for interpolation.
        hours_look_ahead: end of the reference grid, in hours.
        units: GRU hidden size.
        recurrent_dropout: used for both ``dropout`` and ``recurrent_dropout``
            of the GRU.

    Returns:
        An uncompiled ``Model`` with outputs ``[main_output, aux_output]``.
    """
    main_input = Input(shape=(4*num_features, timestamp), name='input')
    # Each layer instance is reused by both paths, so their weights are shared.
    sci = single_channel_interp(ref_points, hours_look_ahead)
    cci = cross_channel_interp()
    interp = cci(sci(main_input))
    reconst = cci(sci(main_input, reconstruction=True),
                  reconstruction=True)
    # Identity Lambda that only gives the reconstruction output a stable name.
    aux_output = Lambda(lambda x: x, name='aux_output')(reconst)
    # (batch, channels, ref_points) -> (batch, ref_points, channels) for the GRU.
    z = Permute((2, 1))(interp)
    z = GRU(units, activation='tanh', recurrent_dropout=recurrent_dropout, dropout=recurrent_dropout)(z)
    main_output = Dense(1, activation='sigmoid', name='main_output')(z)
    model = Model([main_input], [main_output, aux_output])

    # summary() prints itself and returns None; wrapping it in print()
    # emitted a stray "None".
    model.summary()
    return model

In [35]:
def create_customloss(feature_std, num_features):
    """Build the masked, variance-normalized reconstruction loss.

    Args:
        feature_std: array-like of ``num_features`` per-feature standard
            deviations (standard deviation of each feature, as in the paper for
            MIMIC-III). Each channel's squared error is divided by its
            variance. The caller's object is not modified.
        num_features: number of channels ``d``.

    Returns:
        A loss function ``customloss(ytrue, ypred)`` returning a scalar tensor.

    Raises:
        ValueError: if ``feature_std`` does not hold ``num_features`` values.
    """
    # Reshape once into a new array/view. Assigning ``.shape`` on the argument
    # (as before) silently changed the shape of the caller's array.
    wc = np.reshape(np.asarray(feature_std), (1, num_features))
    wc_sq = wc**2

    def customloss(ytrue, ypred):
        """Autoencoder loss.

        ``ytrue`` is expected in the 4-block layout of the model input
        (values, observation mask, times, reconstruction-mode mask); presumably
        the input itself is passed as the target -- confirm at fit time.
        """
        y = ytrue[:, :num_features, :]
        # Score only points that are observed (ytrue[:, d:2d]) and NOT set in
        # the reconstruction-mode mask (ytrue[:, 3d:4d]), i.e. presumably the
        # points held out from the reconstruction's input.
        m1 = ytrue[:, num_features:2*num_features, :]
        m2 = 1 - ytrue[:, 3*num_features:4*num_features, :]
        m = m1*m2
        ypred = ypred[:, :num_features, :]
        x = (y - ypred)*(y - ypred)*m
        # Mean over scored time steps per channel; channels with no scored
        # point divide by 1 instead of 0.
        count = tf.reduce_sum(m, axis=2)
        count = tf.where(count > 0, count, tf.ones_like(count))
        x = tf.reduce_sum(x, axis=2)/count
        x = x/wc_sq  # dividing by variance (squared standard deviation)
        x = tf.reduce_sum(x, axis=1)/num_features
        return tf.reduce_mean(x)
    return customloss


In [113]:
from interp_net import load_data

# Load the pre-split data. Presumably *_ip / *_op are model inputs / targets
# per split, and feature_std holds one standard deviation per input feature
# (it is the kind of value create_customloss expects).
(train_ip, train_op,
 valid_ip, valid_op,
 test_ip, test_op,
 feature_std) = load_data(data_dir='./data_interp_net')

In [114]:
# Sanity check: the model input is (n_samples, 4 * num_features, timestamp).
# NOTE(review): the saved output below is 2-D, (516, 239), while the model
# summary shows a (None, 516, 239) input -- the output looks stale; re-run.
test_ip.shape

(516, 239)

In [None]:
# Smoke test: run the model on the first 8 test samples.
# NOTE(review): `model` is only created by the later `model = interp_net(...)`
# cell, so this cell fails under Restart & Run All; move it below that cell.
model.predict(test_ip[:8])

In [31]:
# Model dimensions come from the training tensor, laid out as
# (n_samples, 4 * num_features, timestamp).
num_features = train_ip.shape[1] // 4
timestamp = train_ip.shape[2]

# 96 reference points spanning [0, 24] hours.
model = interp_net(
    num_features, timestamp,
    ref_points=96, hours_look_ahead=24,
    units=100, recurrent_dropout=0.2,
)

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input (InputLayer)              (None, 516, 239)     0                                            
__________________________________________________________________________________________________
single_channel_interp_3 (single multiple             129         input[0][0]                      
                                                                 input[0][0]                      
__________________________________________________________________________________________________
cross_channel_interp_3 (cross_c multiple             16641       single_channel_interp_3[0][0]    
                                                                 single_channel_interp_3[1][0]    
__________________________________________________________________________________________________
permute_3 

In [91]:
from interp_net import generate_data
import pickle
from tqdm import tqdm


data_path = './interp_net_mimic_iii_preprocessed.pkl'
start_hour=0
end_hour=24
input_dropout=0.2


def inv_list(l, start=0):
    """Return a dict mapping each item of ``l`` to its position plus ``start``."""
    return {item: i for i, item in enumerate(l, start)}


# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open(data_path, 'rb') as pkl_file:
    data, oc, train_ind, valid_ind, test_ind = pickle.load(pkl_file)
# Keep labeled stays only, observed within [start_hour, end_hour].
labeled_ind = np.concatenate((train_ind, valid_ind, test_ind), axis=-1)
data = data.loc[data.ts_ind.isin(labeled_ind)]
data = data.loc[(data.hour>=start_hour)&(data.hour<=end_hour)]

oc = oc.loc[oc.ts_ind.isin(labeled_ind)]
# Fix age: replace implausible ages (> 200) with 91.4.
# (Presumably MIMIC-III's shifted ">89" ages; confirm against the dataset docs.)
data.loc[(data.variable=='Age')&(data.value>200), 'value'] = 91.4
# Get y (labels ordered by ts_ind) and N.
y = np.array(oc.sort_values(by='ts_ind')['in_hospital_mortality']).astype('float32')
N = data.ts_ind.max() + 1
# Split static variables off the time series; fill a dense (N, D) matrix.
static_varis = ['Age', 'Gender']
ii = data.variable.isin(static_varis)
static_data = data.loc[ii]
data = data.loc[~ii]
static_var_to_ind = inv_list(static_varis)
D = len(static_varis)
demo = np.zeros((N, D))
for row in tqdm(static_data.itertuples()):
    demo[row.ts_ind, static_var_to_ind[row.variable]] = row.value
# Normalize static data; a zero std is replaced by 1 to avoid division by zero.
# NOTE(review): means/stds are computed over train, valid and test stays
# together (plus all-zero rows for indices with no static data), which leaks
# held-out statistics into the inputs.
means = demo.mean(axis=0, keepdims=True)
stds = demo.std(axis=0, keepdims=True)
stds = (stds==0)*1 + (stds!=0)*stds
demo = (demo-means)/stds
# Shuffle rows with a fixed seed so the row order -- and hence tie-breaking in
# the sort below, which determines obs_ind -- is reproducible across runs.
data = data.sample(frac=1, random_state=0)
# Distinct observed hours per stay; the 99th percentile is max_timestep.
hours_per_stay = data.groupby('ts_ind')['hour'].nunique()
print(hours_per_stay.quantile([0.25, 0.5, 0.75, 0.9, 0.99]))

max_timestep = int(hours_per_stay.quantile(0.99))

# Get N, V, var_to_ind for the time-series variables (vind starts at 1).
N = data.ts_ind.max() + 1
varis = sorted(list(set(data.variable)))
V = len(varis)
var_to_ind = inv_list(varis, start=1)
data['vind'] = data.variable.map(var_to_ind)
data = data[['ts_ind', 'vind', 'hour', 'value']]
# Add obs index: running position of each observation within its stay,
# in (hour, vind) order.
data = data.sort_values(by=['ts_ind', 'hour', 'vind']).reset_index(drop=True)
data = data.reset_index().rename(columns={'index':'obs_ind'})
first_obs = (data.groupby('ts_ind').agg({'obs_ind':'min'}).reset_index()
             .rename(columns={'obs_ind':'first_obs_ind'}))
data = data.merge(first_obs, on='ts_ind')
data['obs_ind'] = data['obs_ind'] - data['first_obs_ind']
print('max_timestep', max_timestep)


# Per-variable standard deviation of raw values, ordered by vind.
# NOTE(review): this overwrites the feature_std returned by load_data.
feature_std_series = data.groupby(by=['vind'])['value'].std().sort_index()
print(feature_std_series)
feature_std = np.array(feature_std_series)

89624it [00:00, 692447.39it/s]


0.25     36.0
0.50     56.0
0.75     94.0
0.90    143.0
0.99    239.0
Name: hour, dtype: float64
max_timestep 239


17951791it [00:55, 322053.93it/s]


vind
1       149.247326
2       858.097448
3      1293.136791
4         0.661482
5        81.894825
          ...     
125       3.619471
126      10.609918
127      23.654603
128       0.083971
129       0.889520
Name: value, Length: 129, dtype: float64


In [92]:
# Standard deviation of raw values for each variable, indexed by vind.
data.groupby('vind')['value'].agg('std').sort_index()

vind
1       149.247326
2       858.097448
3      1293.136791
4         0.661482
5        81.894825
          ...     
125       3.619471
126      10.609918
127      23.654603
128       0.083971
129       0.889520
Name: value, Length: 129, dtype: float64