In [2]:
import tensorflow as tf
import math
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
import cPickle as pickle
import h5py
import random

In [74]:
class MTGRU_AE(object):

    def init_weights(self, input_dim, output_dim, name=None, std=1.0):
        """Create a trainable [input_dim x output_dim] weight matrix.

        The initial values come from a truncated normal distribution whose
        standard deviation is ``std / sqrt(input_dim)``.
        """
        scaled_std = std / math.sqrt(input_dim)
        initial_value = tf.truncated_normal([input_dim, output_dim], stddev=scaled_std)
        return tf.Variable(initial_value, name=name)

    def init_bias(self, output_dim, name=None):
        """Create a trainable bias vector of length ``output_dim`` initialised to zeros."""
        zero_vector = tf.zeros([output_dim])
        return tf.Variable(zero_vector, name=name)

    def __init__(self, visit_num, visit_length, one_hot_input_dim, input_dim, info_dim, output_dim, output_dim1,
                 output_dim2, output_dim3, voutput_dim, vhidden_dim, hidden_dim1, hidden_dim2, hidden_dim3,
                 hidden_dim4,patient_visit_num,patient_visit_length,visit_time,visit_info,patient_num,code_inputindex,code_time):
        """Create every variable and placeholder of the multi-level time-aware GRU auto-encoder.

        The network is: one-hot projection -> two encoder GRU layers -> visit GRU
        layer -> two decoder GRU layers -> output projection.

        Args:
            visit_num: number of visits per patient in a training batch.
            visit_length: number of codes per visit in a training batch.
            one_hot_input_dim: depth of the one-hot code encoding.
            input_dim: size of the code embedding fed to the first encoder layer.
            info_dim: size of the patient-information vector.
            output_dim: size of the final (decoder layer 2) output.
            output_dim1, output_dim2, output_dim3: output sizes of encoder layer 1,
                encoder layer 2 and decoder layer 1.
            voutput_dim: output size of the visit layer.
            vhidden_dim: hidden size of the visit layer.
            hidden_dim1 .. hidden_dim4: hidden sizes of the four GRU layers.
            patient_visit_num, patient_visit_length, visit_time, visit_info,
            patient_num, code_inputindex, code_time: the full data set, stored
                as-is and indexed per patient by the ``*_forVector`` methods and
                ``get_one_visit_vector``.
        """

        # Number of visits per patient
        self.visit_num = visit_num
        # Number of codes in one visit
        self.visit_length = visit_length

        # Dimension of the one-hot input
        self.one_hot_input_dim = one_hot_input_dim
        # Dimension of the network input obtained after projecting the one-hot vector
        self.input_dim = input_dim
        # Patient information
        self.info_dim = info_dim
        # Final output of the last GRU layer
        self.output_dim = output_dim
        # Intermediate outputs: one per GRU layer, plus one for the visit layer
        self.output_dim1 = output_dim1
        self.output_dim2 = output_dim2
        self.output_dim3 = output_dim3
        self.voutput_dim = voutput_dim
        # Hidden sizes: four GRU layers and one visit layer
        self.hidden_dim1 = hidden_dim1
        self.hidden_dim2 = hidden_dim2
        self.hidden_dim3 = hidden_dim3
        self.hidden_dim4 = hidden_dim4
        self.vhidden_dim = vhidden_dim
        # Projection applied to the one-hot input
        # NOTE(review): the bias below reuses the name 'OneHot_w'; TensorFlow will
        # uniquify it, but renaming would change checkpoint keys.
        self.Wi = self.init_weights(one_hot_input_dim, input_dim, name='OneHot_w')
        self.bi = self.init_bias(input_dim, name='OneHot_w')

        # GRU parameters of the encoder
        self.Wz_enc = self.init_weights(input_dim, hidden_dim1, name='Update_wx_enc')
        self.Uz_enc = self.init_weights(hidden_dim1, hidden_dim1, name='Update_wh_enc')
        self.bz_enc = self.init_bias(hidden_dim1, name='Update_bias_enc')

        self.Wz_enc2 = self.init_weights(output_dim1, hidden_dim2, name='Update_wx_enc2')
        self.Uz_enc2 = self.init_weights(hidden_dim2, hidden_dim2, name='Update_wh_enc2')
        self.bz_enc2 = self.init_bias(hidden_dim2, name='Update_bias_enc2')

        self.Wr_enc = self.init_weights(input_dim, hidden_dim1, name='Reset_wx_enc')
        self.Ur_enc = self.init_weights(hidden_dim1, hidden_dim1, name='Reset_wh_enc')
        self.br_enc = self.init_bias(hidden_dim1, name='Reset_bias_enc')

        self.Wr_enc2 = self.init_weights(output_dim1, hidden_dim2, name='Reset_wx_enc2')
        self.Ur_enc2 = self.init_weights(hidden_dim2, hidden_dim2, name='Reset_wh_enc2')
        self.br_enc2 = self.init_bias(hidden_dim2, name='Reset_bias_enc2')

        # The decay weights are constant all-ones tensors (tf.ones, not tf.Variable),
        # so they are not trained.
        self.Wd_enc = tf.ones([1, self.hidden_dim1], dtype=tf.float32, name='Decay_w_enc')
        self.Wd_enc2 = tf.ones([1, self.hidden_dim2], dtype=tf.float32, name='Decay_w_enc2')

        self.Wh_enc = self.init_weights(self.input_dim, self.hidden_dim1, name='Canditateh_wx_enc')
        self.Uh_enc = self.init_weights(self.hidden_dim1, self.hidden_dim1, name='Canditateh_wh_enc')
        self.bh_enc = self.init_bias(self.hidden_dim1, name='Canditateh_bias_enc')

        self.Wh_enc2 = self.init_weights(self.output_dim1, self.hidden_dim2, name='Canditateh_wx_enc2')
        self.Uh_enc2 = self.init_weights(self.hidden_dim2, self.hidden_dim2, name='Canditateh_wh_enc2')
        self.bh_enc2 = self.init_bias(self.hidden_dim2, name='Canditateh_bias_enc2')

        # GRU parameters of the decoder
        self.Wz_dec = self.init_weights(voutput_dim, hidden_dim3, name='Update_wx_dec')
        self.Uz_dec = self.init_weights(hidden_dim3, hidden_dim3, name='Update_wh_dec')
        self.bz_dec = self.init_bias(hidden_dim3, name='Update_bias_dec')

        self.Wz_dec2 = self.init_weights(output_dim3, hidden_dim4, name='Update_wx_dec2')
        self.Uz_dec2 = self.init_weights(hidden_dim4, hidden_dim4, name='Update_wh_dec2')
        self.bz_dec2 = self.init_bias(hidden_dim4, name='Update_bias_dec2')

        self.Wr_dec = self.init_weights(voutput_dim, hidden_dim3, name='Reset_wx_dec')
        self.Ur_dec = self.init_weights(hidden_dim3, hidden_dim3, name='Reset_wh_dec')
        self.br_dec = self.init_bias(hidden_dim3, name='Reset_bias_dec')

        self.Wr_dec2 = self.init_weights(output_dim3, hidden_dim4, name='Reset_wx_dec2')
        self.Ur_dec2 = self.init_weights(hidden_dim4, hidden_dim4, name='Reset_wh_dec2')
        self.br_dec2 = self.init_bias(hidden_dim4, name='Reset_bias_dec2')

        self.Wd_dec = tf.ones([1, self.hidden_dim3], dtype=tf.float32, name='Decay_w_dec')
        self.Wd_dec2 = tf.ones([1, self.hidden_dim4], dtype=tf.float32, name='Decay_w_dec2')

        self.Wh_dec = self.init_weights(self.voutput_dim, self.hidden_dim3, name='Canditateh_wx_dec')
        self.Uh_dec = self.init_weights(self.hidden_dim3, self.hidden_dim3, name='Canditateh_wh_dec')
        self.bh_dec = self.init_bias(self.hidden_dim3, name='Canditateh_bias_dec')

        # NOTE(review): the three decoder-layer-2 candidate variables below reuse the
        # '..._dec' names of decoder layer 1 (no '2' suffix).
        self.Wh_dec2 = self.init_weights(self.output_dim3, self.hidden_dim4, name='Canditateh_wx_dec')
        self.Uh_dec2 = self.init_weights(self.hidden_dim4, self.hidden_dim4, name='Canditateh_wh_dec')
        self.bh_dec2 = self.init_bias(self.hidden_dim4, name='Canditateh_bias_dec')

        # Parameters of the visit layer
        # Its input is the previous layer's output concatenated with the patient info
        self.Wz_v = self.init_weights(output_dim2 + info_dim, vhidden_dim, name='Update_wx_v')
        self.Uz_v = self.init_weights(vhidden_dim, vhidden_dim, name='Update_wh_v')
        # NOTE(review): this visit-layer bias is named 'Update_bias_dec', the same
        # name as the decoder's update bias.
        self.bz_v = self.init_bias(vhidden_dim, name='Update_bias_dec')

        self.Wr_v = self.init_weights(output_dim2 + info_dim, vhidden_dim, name='Reset_wx_v')
        self.Ur_v = self.init_weights(vhidden_dim, vhidden_dim, name='Reset_wh_v')
        self.br_v = self.init_bias(vhidden_dim, name='Reset_bias_v')

        self.Wd_v = tf.ones([1, self.vhidden_dim], dtype=tf.float32, name='Decay_w_v')

        self.Wh_v = self.init_weights(self.output_dim2 + info_dim, self.vhidden_dim, name='Canditateh_wx_v')
        self.Uh_v = self.init_weights(self.vhidden_dim, self.vhidden_dim, name='Canditateh_wh_v')
        self.bh_v = self.init_bias(self.vhidden_dim, name='Canditateh_bias_v')

        # Output layers
        # Output of the visit layer
        self.Wov = self.init_weights(vhidden_dim, voutput_dim, name='Visit_weight')
        self.bov = self.init_bias(voutput_dim, name='Visit_bias')

        # encoder1
        self.Wo1 = self.init_weights(hidden_dim1, output_dim1, name='Output1_weight')
        self.bo1 = self.init_bias(output_dim1, name='Output1_bias')

        # encoder2
        self.Wo2 = self.init_weights(hidden_dim2, output_dim2, name='Output2_weight')
        self.bo2 = self.init_bias(output_dim2, name='Output2_bias')

        # decoder1
        self.Wo3 = self.init_weights(hidden_dim3, output_dim3, name='Output3_weight')
        self.bo3 = self.init_bias(output_dim3, name='Output3_bias')

        # Final output (decoder2)
        self.Wo = self.init_weights(hidden_dim4, output_dim, name='Output_w')
        self.bo = self.init_bias(output_dim, name='Output_bias')

        # Input placeholders
        # [batch size x seq length x input dim]
        self.inputindex=tf.placeholder(dtype=tf.int32,shape=[None,None])
        self.one_hot_input=tf.one_hot(indices=self.inputindex, depth=self.one_hot_input_dim, axis=2)
        #self.one_hot_input = tf.placeholder('float', shape=[None, None, self.one_hot_input_dim])
        self.time = tf.placeholder('float', [None, None])
        # NOTE(review): keep_prob is created here but not read by any method of this class.
        self.keep_prob = tf.placeholder(tf.float32)
        # Patient demographic information
        self.info = tf.placeholder('float', shape=[None, None, self.info_dim])

        # The whole data set is stored so that visit and patient vectors can be
        # computed for every patient
        self.patient_visit_num = patient_visit_num # number of visits of each patient
        self.patient_visit_length = patient_visit_length # number of codes in each visit of each patient
        self.visit_time = visit_time
        self.visit_info = visit_info
        self.patient_num=patient_num
        self.code_inputindex=code_inputindex # indices of all codes of each patient
        self.code_time=code_time



    # 对输入的one-hot使用矩阵处理，code的嵌入向量1
    def get_input(self, one_hot_input):
        """Project one-hot codes through ``Wi`` and add ``bi``; the result is the code embedding."""
        return tf.matmul(one_hot_input, self.Wi) + self.bi

    # encoder第一层TGRU cell，code的嵌入向量2
    def TGRU_encoder_cell1(self, prev_h, concat_input):
        """One step of the first encoder time-aware GRU layer (the ``tf.scan`` step function).

        Args:
            prev_h: hidden state of the previous step.
            concat_input: [batch_size x 1+input_dim]; column 0 is the elapsed time,
                the remaining ``input_dim`` columns are the code embedding.

        Returns:
            The hidden state after this step.
        """
        n_rows = tf.shape(concat_input)[0]
        features = tf.slice(concat_input, [0, 1], [n_rows, self.input_dim])
        elapsed = tf.slice(concat_input, [0, 0], [n_rows, 1])

        time_factor = self.map_elapse_time(elapsed, self.hidden_dim1)

        update_gate = tf.sigmoid(
            tf.matmul(features, self.Wz_enc) + tf.matmul(prev_h, self.Uz_enc) + self.bz_enc)
        reset_gate = tf.sigmoid(
            tf.matmul(features, self.Wr_enc) + tf.matmul(prev_h, self.Ur_enc) + self.br_enc)
        # Shrink the previous state according to the elapsed time before it is gated.
        decayed_h = tf.multiply(tf.matmul(time_factor, self.Wd_enc), prev_h)
        candidate_h = tf.sigmoid(
            tf.matmul(features, self.Wh_enc) + tf.matmul(reset_gate * decayed_h, self.Uh_enc) + self.bh_enc)

        # Equivalent to (1 - z) * decayed_h + z * candidate_h.
        return decayed_h - update_gate * decayed_h + update_gate * candidate_h

    # encoder第二层TGRU cell
    def TGRU_encoder_cell2(self, prev_h, concat_input):
        """One step of the second encoder time-aware GRU layer (the ``tf.scan`` step function).

        Args:
            prev_h: hidden state of the previous step.
            concat_input: [batch_size x 1+output_dim1]; column 0 is the elapsed time,
                the remaining ``output_dim1`` columns are the first layer's output.

        Returns:
            The hidden state after this step.
        """
        n_rows = tf.shape(concat_input)[0]
        features = tf.slice(concat_input, [0, 1], [n_rows, self.output_dim1])
        elapsed = tf.slice(concat_input, [0, 0], [n_rows, 1])

        time_factor = self.map_elapse_time(elapsed, self.hidden_dim2)

        update_gate = tf.sigmoid(
            tf.matmul(features, self.Wz_enc2) + tf.matmul(prev_h, self.Uz_enc2) + self.bz_enc2)
        reset_gate = tf.sigmoid(
            tf.matmul(features, self.Wr_enc2) + tf.matmul(prev_h, self.Ur_enc2) + self.br_enc2)
        # Shrink the previous state according to the elapsed time before it is gated.
        decayed_h = tf.multiply(tf.matmul(time_factor, self.Wd_enc2), prev_h)
        candidate_h = tf.sigmoid(
            tf.matmul(features, self.Wh_enc2) + tf.matmul(reset_gate * decayed_h, self.Uh_enc2) + self.bh_enc2)

        # Equivalent to (1 - z) * decayed_h + z * candidate_h.
        return decayed_h - update_gate * decayed_h + update_gate * candidate_h

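    # TGRU cell of the visit layer.
    # Unlike the other cells, whose input is the code of a single time step, concat_input here holds the
    # codes of several time steps, which is why the output of get_encoder2_h is regrouped per visit first.
    # concat_input is therefore three-dimensional: one entry per code of the visit, each with its time, info and code rows for the batch.
    # [visit_length x batch_size x 1+info_dim+output_dim2]
    # The visit length is stored on the instance so that the per-code inputs x can be extracted.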
    def TGRU_visit_cell(self, prev_h, concat_input):
        """One step of the visit-level time-aware GRU (the ``tf.scan`` step function).

        Args:
            prev_h: visit hidden state of the previous visit, [batch_size x vhidden_dim].
            concat_input: one visit, [visit_length x batch_size x 1+info_dim+output_dim2].
                Along the last axis: the elapsed time, then the patient info, then
                the second encoder layer's output for that code.

        Returns:
            The visit hidden state after this visit.

        The time and info of the visit's first code stand for the whole visit. The
        code outputs are averaged over all ``visit_length`` positions (padded ones
        included) and concatenated with the info to form the GRU input.
        """
        batch_size = tf.shape(concat_input)[1]

        # Only the first code's time and info are used for the visit.
        first_code = concat_input[0]
        t = tf.slice(first_code, [0, 0], [batch_size, 1])
        info = tf.slice(first_code, [0, 1], [batch_size, self.info_dim])

        # Average the code outputs of the visit. A single slice + reduce_sum replaces
        # the former trick of storing per-code tensors in locals(), which Python does
        # not guarantee to work.
        code_outputs = tf.slice(concat_input, [0, 0, 1 + self.info_dim],
                                [self.visit_length, batch_size, self.output_dim2])
        visit_x = tf.reduce_sum(code_outputs, axis=0) / self.visit_length

        # Put the averaged codes and the patient info into one input vector.
        visit_x = tf.concat([visit_x, info], 1)

        # Time-aware GRU update on the visit embedding.
        ft = self.map_elapse_time(t, self.vhidden_dim)
        z = tf.sigmoid(tf.matmul(visit_x, self.Wz_v) + tf.matmul(prev_h, self.Uz_v) + self.bz_v)
        r = tf.sigmoid(tf.matmul(visit_x, self.Wr_v) + tf.matmul(prev_h, self.Ur_v) + self.br_v)
        d = tf.matmul(ft, self.Wd_v)
        h_ = tf.multiply(d, prev_h)
        h_canditate = tf.sigmoid(tf.matmul(visit_x, self.Wh_v) + tf.matmul(r * h_, self.Uh_v) + self.bh_v)
        current_h = h_ - z * h_ + z * h_canditate

        return current_h

    # decoder第一层TGRU cell
    def TGRU_decoder_cell1(self, prev_h, concat_input):
        """One step of the first decoder time-aware GRU layer (the ``tf.scan`` step function).

        Args:
            prev_h: hidden state of the previous step.
            concat_input: [batch_size x 1+voutput_dim]; column 0 is the elapsed time,
                the remaining ``voutput_dim`` columns are the visit-layer output.

        Returns:
            The hidden state after this step.
        """
        n_rows = tf.shape(concat_input)[0]
        features = tf.slice(concat_input, [0, 1], [n_rows, self.voutput_dim])
        elapsed = tf.slice(concat_input, [0, 0], [n_rows, 1])

        time_factor = self.map_elapse_time(elapsed, self.hidden_dim3)

        update_gate = tf.sigmoid(
            tf.matmul(features, self.Wz_dec) + tf.matmul(prev_h, self.Uz_dec) + self.bz_dec)
        reset_gate = tf.sigmoid(
            tf.matmul(features, self.Wr_dec) + tf.matmul(prev_h, self.Ur_dec) + self.br_dec)
        # Shrink the previous state according to the elapsed time before it is gated.
        decayed_h = tf.multiply(tf.matmul(time_factor, self.Wd_dec), prev_h)
        candidate_h = tf.sigmoid(
            tf.matmul(features, self.Wh_dec) + tf.matmul(reset_gate * decayed_h, self.Uh_dec) + self.bh_dec)

        # Equivalent to (1 - z) * decayed_h + z * candidate_h.
        return decayed_h - update_gate * decayed_h + update_gate * candidate_h

    # decoder第二层TGRU cell
    def TGRU_decoder_cell2(self, prev_h, concat_input):
        """One step of the second decoder time-aware GRU layer (the ``tf.scan`` step function).

        Args:
            prev_h: hidden state of the previous step.
            concat_input: [batch_size x 1+output_dim3]; column 0 is the elapsed time,
                the remaining ``output_dim3`` columns are the first decoder layer's output.

        Returns:
            The hidden state after this step.
        """
        n_rows = tf.shape(concat_input)[0]
        features = tf.slice(concat_input, [0, 1], [n_rows, self.output_dim3])
        elapsed = tf.slice(concat_input, [0, 0], [n_rows, 1])

        time_factor = self.map_elapse_time(elapsed, self.hidden_dim4)

        update_gate = tf.sigmoid(
            tf.matmul(features, self.Wz_dec2) + tf.matmul(prev_h, self.Uz_dec2) + self.bz_dec2)
        reset_gate = tf.sigmoid(
            tf.matmul(features, self.Wr_dec2) + tf.matmul(prev_h, self.Ur_dec2) + self.br_dec2)
        # Shrink the previous state according to the elapsed time before it is gated.
        decayed_h = tf.multiply(tf.matmul(time_factor, self.Wd_dec2), prev_h)
        candidate_h = tf.sigmoid(
            tf.matmul(features, self.Wh_dec2) + tf.matmul(reset_gate * decayed_h, self.Uh_dec2) + self.bh_dec2)

        # Equivalent to (1 - z) * decayed_h + z * candidate_h.
        return decayed_h - update_gate * decayed_h + update_gate * candidate_h

    # 每个输出层的操作
    # encoder1的输出
    def get_output1(self, h):
        """Linear output projection of encoder layer 1: ``h * Wo1 + bo1``."""
        return tf.matmul(h, self.Wo1) + self.bo1

    # encoder2的输出
    def get_output2(self, h):
        """Linear output projection of encoder layer 2: ``h * Wo2 + bo2``."""
        return tf.matmul(h, self.Wo2) + self.bo2

    # visit的输出
    def get_outputv(self, h):
        """Linear output projection of the visit layer: ``h * Wov + bov``."""
        return tf.matmul(h, self.Wov) + self.bov

    # decoder1的输出
    def get_output3(self, h):
        """Linear output projection of decoder layer 1: ``h * Wo3 + bo3``."""
        return tf.matmul(h, self.Wo3) + self.bo3

    # 最终输出（decoder2的输出）
    def get_output(self, h):
        """Final linear output projection (decoder layer 2): ``h * Wo + bo``."""
        return tf.matmul(h, self.Wo) + self.bo

    # encoder1的输出h
    def get_encoder1_h(self):
        """Run encoder layer 1 over the placeholder inputs.

        Returns the hidden states of all time steps for every sample in the batch,
        stacked by ``tf.scan`` along the leading (time) axis.
        """
        embedded = self.get_input(self.one_hot_input)
        n_samples = tf.shape(embedded)[0]
        # tf.scan walks along axis 0, so the sequence axis has to come first:
        # [seq_length x batch_size x input_dim]
        steps_first = tf.transpose(embedded, perm=[1, 0, 2])
        times = tf.transpose(self.time)  # [seq_length x batch_size]

        h0 = tf.zeros([n_samples, self.hidden_dim1], tf.float32)

        # Give the times a trailing axis of size 1 so they can be placed in front of the features.
        time_dims = tf.shape(times)
        times = tf.reshape(times, [time_dims[0], time_dims[1], 1])
        cell_input = tf.concat([times, steps_first], 2)  # [seq_length x batch_size x input_dim+1]

        return tf.scan(self.TGRU_encoder_cell1, cell_input, initializer=h0, name='encoder1_h')

    # encoder2的输出h
    def get_encoder2_h(self):
        """Run encoder layer 2 on top of encoder layer 1.

        Returns the hidden states of all time steps for every sample in the batch,
        stacked by ``tf.scan`` along the leading (time) axis.
        """
        layer1_h = self.get_encoder1_h()
        layer1_out = tf.map_fn(self.get_output1, layer1_h)

        n_samples = tf.shape(layer1_h)[1]
        times = tf.transpose(self.time)  # [seq_length x batch_size]
        h0 = tf.zeros([n_samples, self.hidden_dim2], tf.float32)

        # Give the times a trailing axis of size 1 so they can be placed in front of the features.
        time_dims = tf.shape(times)
        times = tf.reshape(times, [time_dims[0], time_dims[1], 1])
        cell_input = tf.concat([times, layer1_out], 2)  # [seq_length x batch_size x output_dim1+1]

        return tf.scan(self.TGRU_encoder_cell2, cell_input, initializer=h0, name='encoder2_h')

    # visit的输出h
    def get_visit_h(self):
        """Run the visit layer on top of encoder layer 2.

        The code sequence is regrouped into ``visit_num`` visits of ``visit_length``
        codes each, so the fed sequence length must equal their product.
        Returns the visit hidden states of all visits, stacked by ``tf.scan``.
        """
        layer2_h = self.get_encoder2_h()
        layer2_out = tf.map_fn(self.get_output2, layer2_h)

        n_samples = tf.shape(layer2_h)[1]
        times = tf.transpose(self.time)  # [seq_length x batch_size]

        # [batch_size x seq_length x info_dim] --> [seq_length x batch_size x info_dim]
        infos = tf.transpose(self.info, [1, 0, 2])
        # Give the times a trailing axis of size 1: [seq_length x batch_size x 1]
        time_dims = tf.shape(times)
        times = tf.reshape(times, [time_dims[0], time_dims[1], 1])

        # [seq_length x batch_size x 1+info_dim+output_dim2]
        per_code = tf.concat([times, infos, layer2_out], 2)

        h0 = tf.zeros([n_samples, self.vhidden_dim], tf.float32)
        # The number and length of the visits are known, so split the sequence axis:
        # [visit_num x visit_length x batch_size x 1+info_dim+output_dim2]
        per_code_dims = tf.shape(per_code)
        per_visit = tf.reshape(per_code, [self.visit_num, self.visit_length, per_code_dims[1], per_code_dims[2]])

        return tf.scan(self.TGRU_visit_cell, per_visit, initializer=h0, name='visit')



    # 获得visit和patient的表达
    def get_patient_vector(self):
        visit_h = self.get_visit_vector()
        # 最后一组h，把h数组反向取第一组
        pv = tf.reverse(visit_h,[0])[0, :, :]
        return visit_h,pv

    # 得到decoder第1层的h
    # 此时的输入是visit的h经过output后的结果,，要注意对一个visit的output展开成这个visit的code个
    def get_decoder1_h(self):
        """Run decoder layer 1 on the visit-layer outputs.

        Each visit's output is repeated ``visit_length`` times so that the decoder
        again sees one input per code. Returns the hidden states of all time steps,
        stacked by ``tf.scan``.
        """
        visit_h = self.get_visit_h()
        per_visit_out = tf.map_fn(self.get_outputv, visit_h)
        # Expand the per-visit outputs back to one entry per code.
        per_code_out = []
        for visit_idx in range(self.visit_num):
            per_code_out.extend(per_visit_out[visit_idx] for _ in range(self.visit_length))

        n_samples = tf.shape(visit_h)[1]
        times = tf.transpose(self.time)  # [seq_length x batch_size]
        h0 = tf.zeros([n_samples, self.hidden_dim3], tf.float32)

        # Give the times a trailing axis of size 1: [seq_length x batch_size x 1]
        time_dims = tf.shape(times)
        times = tf.reshape(times, [time_dims[0], time_dims[1], 1])
        # The list of per-code tensors is packed into one tensor by tf.concat.
        cell_input = tf.concat([times, per_code_out], 2)  # [seq_length x batch_size x voutput_dim+1]

        return tf.scan(self.TGRU_decoder_cell1, cell_input, initializer=h0, name='decoder1_h')

    # decoder2的输出h
    def get_decoder2_h(self):
        """Run decoder layer 2 on top of decoder layer 1.

        Returns the hidden states of all time steps for every sample in the batch,
        stacked by ``tf.scan`` along the leading (time) axis.
        """
        layer3_h = self.get_decoder1_h()
        layer3_out = tf.map_fn(self.get_output3, layer3_h)

        n_samples = tf.shape(layer3_h)[1]
        times = tf.transpose(self.time)  # [seq_length x batch_size]
        h0 = tf.zeros([n_samples, self.hidden_dim4], tf.float32)

        # Give the times a trailing axis of size 1: [seq_length x batch_size x 1]
        time_dims = tf.shape(times)
        times = tf.reshape(times, [time_dims[0], time_dims[1], 1])
        cell_input = tf.concat([times, layer3_out], 2)  # [seq_length x batch_size x output_dim3+1]

        return tf.scan(self.TGRU_decoder_cell2, cell_input, initializer=h0, name='decoder2_h')

    # 获得整个AE的输出
    def get_decoder_outputs(self):
        """Return the auto-encoder output for every time step.

        The final projection is applied to each decoder-layer-2 hidden state and
        the result is transposed from time-major back to batch-major.
        """
        per_step = tf.map_fn(self.get_output, self.get_decoder2_h())
        return tf.transpose(per_step, perm=[1, 0, 2])

    # 获得预测和输入的距离使用MSE
    def get_reconstruction_loss(self):
        """Mean squared error between the one-hot input and the decoder outputs."""
        squared_error = tf.square(self.one_hot_input - self.get_decoder_outputs())
        return tf.reduce_mean(squared_error)

    def get_cross_loss(self):
        """Mean softmax cross-entropy between the one-hot input (labels) and the decoder outputs (logits)."""
        logits = self.get_decoder_outputs()
        per_code_loss = tf.nn.softmax_cross_entropy_with_logits(labels=self.one_hot_input, logits=logits)
        return tf.reduce_mean(per_code_loss)

    def map_elapse_time(self, t, dim):
        """Map an elapsed time ``t`` to the decay factor ``1 / log(t + 2.7183)``.

        ``dim`` is accepted for the callers' convenience but is not used.
        """
        return tf.div(tf.constant(1, dtype=tf.float32),
                      tf.log(t + tf.constant(2.7183, dtype=tf.float32)),
                      name='Log_elapse_time')

    # 获取code，visit，patient的表达
    def get_c_v_p(self):
        # code的最终向量是one-hot 变换w变换后的值
        code_vector = self.get_input(self.one_hot_input)
        # visit的向量是visit层的h
        visit_vector = self.get_visit_h()
        # patient的向量是最后一个visit的h
        patient_vector = self.get_patient_vector()
        return code_vector, visit_vector, patient_vector



    # The methods below compute the visit and patient representations of every patient in the data set.
    def TGRU_visit_cell2(self, prev_h, concat_input):
        """One step of the visit-level time-aware GRU for already-averaged visits.

        Args:
            prev_h: visit hidden state of the previous visit.
            concat_input: [batch_size x 1+info_dim+output_dim2]; column 0 is the
                elapsed time, then ``info_dim`` info columns, then the visit input.

        Returns:
            The visit hidden state after this visit.
        """
        n_rows = tf.shape(concat_input)[0]

        elapsed = tf.slice(concat_input, [0, 0], [n_rows, 1])
        info = tf.slice(concat_input, [0, 1], [n_rows, self.info_dim])
        features = tf.slice(concat_input, [0, 1+self.info_dim], [n_rows, self.output_dim2])

        # The GRU input is the visit input followed by the patient info.
        visit_x = tf.concat([features, info], 1)

        time_factor = self.map_elapse_time(elapsed, self.vhidden_dim)
        update_gate = tf.sigmoid(
            tf.matmul(visit_x, self.Wz_v) + tf.matmul(prev_h, self.Uz_v) + self.bz_v)
        reset_gate = tf.sigmoid(
            tf.matmul(visit_x, self.Wr_v) + tf.matmul(prev_h, self.Ur_v) + self.br_v)
        # Shrink the previous state according to the elapsed time before it is gated.
        decayed_h = tf.multiply(tf.matmul(time_factor, self.Wd_v), prev_h)
        candidate_h = tf.sigmoid(
            tf.matmul(visit_x, self.Wh_v) + tf.matmul(reset_gate * decayed_h, self.Uh_v) + self.bh_v)

        # Equivalent to (1 - z) * decayed_h + z * candidate_h.
        return decayed_h - update_gate * decayed_h + update_gate * candidate_h
    # The shapes noted in the comments below assume a patient with 39 codes in 2 visits and 259 distinct code types.
    def get_encoder1_h_forVector(self,no):
        """Run encoder layer 1 over all codes of patient ``no`` (batch size 1).

        Reads the patient's code indices and code times from the stored data set
        and returns the hidden state of every code, stacked by ``tf.scan``.
        """
        one_hot_codes = tf.one_hot(indices=self.code_inputindex[no], depth=self.one_hot_input_dim,axis=1)
        embedded = self.get_input(one_hot_codes)
        batch_size = 1

        # Wrapping in a list adds the batch axis; the transpose makes it
        # [seq_length x batch_size x input_dim].
        steps_first = tf.transpose([embedded], perm=[1, 0, 2])
        # [seq_length x batch_size]
        times = tf.transpose([tf.to_float(self.code_time[no])])
        h0 = tf.zeros([batch_size, self.hidden_dim1], tf.float32)
        # Give the times a trailing axis of size 1: [seq_length x batch_size x 1]
        time_dims = tf.shape(times)
        times = tf.reshape(times, [time_dims[0], time_dims[1], 1])
        # [seq_length x batch_size x input_dim+1]
        cell_input = tf.concat([times, steps_first], 2)

        return tf.scan(self.TGRU_encoder_cell1, cell_input, initializer=h0, name='encoder1_h_forVector')

    def get_encoder2_h_forVector(self,no):
        """Run encoder layer 2 over all codes of patient ``no`` (batch size 1).

        Returns the hidden state of every code, stacked by ``tf.scan``.
        """
        layer1_out = tf.map_fn(self.get_output1, self.get_encoder1_h_forVector(no))
        batch_size = 1
        # [seq_length x batch_size]
        times = tf.transpose([tf.to_float(self.code_time[no])])
        h0 = tf.zeros([batch_size, self.hidden_dim2], tf.float32)
        # Give the times a trailing axis of size 1: [seq_length x batch_size x 1]
        time_dims = tf.shape(times)
        times = tf.reshape(times, [time_dims[0], time_dims[1], 1])
        # [seq_length x batch_size x output_dim1+1]
        cell_input = tf.concat([times, layer1_out], 2)

        return tf.scan(self.TGRU_encoder_cell2, cell_input, initializer=h0, name='encoder2_h_forVector')

    def get_one_visit_vector(self,no):
        """Compute the visit vectors and the patient vector of patient ``no``.

        The second encoder layer's per-code outputs are averaged within each visit
        (using the stored true visit lengths, no padding) and fed through the
        visit-level GRU.

        Returns:
            (visit_vector, patient_vector): the visit hidden state of every visit
            and the hidden state of the last visit.
        """
        layer2_out = tf.map_fn(self.get_output2, self.get_encoder2_h_forVector(no))
        batch_size = 1
        visit_lengths = self.patient_visit_length[no]
        code_pos = 0  # running index into the patient's flat code sequence
        visit_inputs = []
        for visit_idx in range(self.patient_visit_num[no][0]):
            n_codes = visit_lengths[visit_idx]
            summed = tf.zeros([batch_size, self.output_dim2], dtype=tf.float32)
            for _ in range(n_codes):
                summed += layer2_out[code_pos]
                code_pos += 1
            visit_inputs.append(summed / n_codes)
        # [visit_count x batch_size]
        times = tf.transpose([tf.to_float(self.visit_time[no])])
        # [batch_size x visit_count x info_dim] --> [visit_count x batch_size x info_dim]
        infos = tf.transpose(tf.to_float([self.visit_info[no]]), [1, 0, 2])
        # Give the times a trailing axis of size 1: [visit_count x batch_size x 1]
        time_dims = tf.shape(times)
        times = tf.reshape(times, [time_dims[0], time_dims[1], 1])
        # The list of per-visit inputs is packed into one tensor by tf.concat:
        # [visit_count x batch_size x 1+info_dim+output_dim2]
        cell_input = tf.concat([times, infos, visit_inputs], 2)
        h0 = tf.zeros([batch_size, self.vhidden_dim], tf.float32)

        visit_vector = tf.scan(self.TGRU_visit_cell2, cell_input, initializer=h0, name='visit')
        # Reverse along the visit axis and take the first entry, i.e. the last visit.
        patient_vector = tf.reverse(visit_vector,[0])[0, :, :]
        return visit_vector,patient_vector

    def get_all_visit_vector(self):
        visits_vector=[]
        patients_vector=[]
        for no in range(self.patient_num):
            visit_vector, patient_vector=self.get_one_visit_vector(no)
            visits_vector.append(visit_vector)
            patients_vector.append(patient_vector)
        return visits_vector,patients_vector


In [19]:
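# Batch generator.
# Visits are cut or padded to a common length inside a batch; because the true lengths differ, each sample could also be learned on its own at first.
# All patients in a generated batch have the same number of visits and the same number of codes per visit.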
def batch_generator(outFile, visit_num, code_num, patient_num, batch_size):
    """Draw a random batch of patients whose visits are cut/padded to ``code_num`` codes.

    ``batch_size`` patient indices are drawn uniformly (with replacement) from
    ``range(patient_num)``. Only patients with exactly ``visit_num`` visits are
    kept, so the batch may hold fewer than ``batch_size`` samples, or none.

    Args:
        outFile: directory containing the pickled ``*.seqs`` files.
        visit_num: required number of visits per patient.
        code_num: number of codes every visit is cut or zero-padded to.
        patient_num: number of patients in the data set.
        batch_size: number of random draws.

    Returns:
        (batch_patient_code, batch_codespatientsinfo, batch_code_delt_dates):
        three parallel lists with one flat sequence of ``visit_num * code_num``
        entries per kept patient. Codes and dates are padded with ``0`` and the
        info with ``[0, 0]``.
    """
    def load_seqs(stem):
        # NOTE(security): pickle.load can execute arbitrary code; only read trusted files.
        # The with-block closes the file even when unpickling fails.
        with open(outFile + '/' + stem + '.seqs', 'rb') as seq_file:
            return pickle.load(seq_file)

    # Only the sequences that are actually used are loaded.
    patient_code = load_seqs('patient_code')
    codespatientsinfo = load_seqs('codespatientsinfo')
    code_delt_dates = load_seqs('code_delt_dates')
    visits_num = load_seqs('visits_num')
    codes_num = load_seqs('codes_num')

    batch_patient_code = []
    batch_code_delt_dates = []
    batch_codespatientsinfo = []

    for _ in range(batch_size):
        j = random.randint(0, patient_num - 1)
        if visits_num[j][0] != visit_num:
            continue

        code = []
        code_date = []
        code_info = []
        # Offset of the current visit inside the patient's flat code sequence:
        # the sum of the lengths of all previous visits.
        start = 0
        for k in range(visit_num):
            available = codes_num[j][k]
            keep = min(available, code_num)   # cut visits that are too long
            pad = code_num - keep             # zero-pad visits that are too short
            stop = start + keep

            code.extend(patient_code[j][start:stop])
            code.extend([0] * pad)
            code_date.extend(code_delt_dates[j][start:stop])
            code_date.extend([0] * pad)
            code_info.extend(codespatientsinfo[j][start:stop])
            code_info.extend([[0, 0] for _ in range(pad)])

            start += available

        batch_patient_code.append(code)
        batch_code_delt_dates.append(code_date)
        batch_codespatientsinfo.append(code_info)

    # Random drawing may never pick some patients; a sequential generator would
    # still be needed to guarantee full coverage.
    return batch_patient_code, batch_codespatientsinfo, batch_code_delt_dates

In [3]:
# Training parameters
learning_rate = 1e-3
iters = 20

# Network parameters
info_dim=2
one_hot_input_dim=259 # len(types) == 259 in the data, so the one-hot length has to be 259
input_dim = 200
hidden_dim1 =150
output_dim1=150
hidden_dim2=150
output_dim2=150
vhidden_dim=100
voutput_dim=150
hidden_dim3=150
output_dim3=150
hidden_dim4=150
output_dim=259

# Parameters for batch generation
outFile='SEQ'
checkpoint_dir='MODEL'
result_dir='RESULT'
batch_size=4
visit_num=2 #visit_num=random.randint()
code_num=10 #code_num=random.randint() # code_num=visit_length
patient_num=14 # total number of patients


def _load_seqs(stem):
    """Unpickle ``<outFile>/<stem>.seqs``; the file is closed even if loading fails."""
    # NOTE(security): pickle.load can execute arbitrary code; only read trusted files.
    with open(outFile + '/' + stem + '.seqs', 'rb') as seq_file:
        return pickle.load(seq_file)


# Load the data set
patient_code = _load_seqs('patient_code')
code_delt_dates = _load_seqs('code_delt_dates')
visit_delt_dates = _load_seqs('visit_delt_dates')
visits_num = _load_seqs('visits_num')
codes_num = _load_seqs('codes_num')
visitspatientsinfo = _load_seqs('visitspatientsinfo')
codespatientsinfo = _load_seqs('codespatientsinfo')

In [75]:
# Instantiate the network. The batch shape (visit_num, code_num) and layer sizes
# come first, followed by the full data set used to compute per-patient vectors.
mtgruae = MTGRU_AE(visit_num, code_num, one_hot_input_dim, input_dim, info_dim, output_dim, output_dim1, output_dim2,output_dim3, voutput_dim, vhidden_dim,hidden_dim1, hidden_dim2, hidden_dim3, hidden_dim4,visits_num,codes_num,visit_delt_dates,visitspatientsinfo,patient_num,patient_code,code_delt_dates)


In [76]:
# Objective: softmax cross-entropy (the MSE alternative is get_reconstruction_loss),
# minimised with Adam.
loss = mtgruae.get_cross_loss()
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss)

##### Train for `iters` epochs

In [83]:
# Initialise all variables in a session that stays open for the following cells.
init = tf.global_variables_initializer()
sess= tf.Session() 
sess.run(init)
# Mean loss of every epoch
Loss = np.zeros(iters)

In [113]:
# Train on randomly generated batches.
# Number of batches drawn per epoch; should be chosen so that one epoch roughly
# covers the whole data set.
steps_per_epoch = 10
for epoch in range(iters):
    epoch_loss = 0.0
    steps_run = 0
    # A separate loop variable is used here: reusing the epoch index made every
    # epoch overwrite the same Loss entry.
    for step in range(steps_per_epoch):
        # Generate a batch
        xindex, info, t = batch_generator(outFile, visit_num, code_num, patient_num,batch_size)
        # batch_generator may return an empty batch; skip it.
        if len(xindex)==0: continue
        _, L = sess.run([optimizer, loss], feed_dict={mtgruae.inputindex: xindex, mtgruae.time: t, mtgruae.info:info})
        epoch_loss += L
        steps_run += 1
    # Average over the steps that actually ran, not over skipped batches.
    Loss[epoch] = epoch_loss / max(steps_run, 1)
    print('Loss: %f' %(Loss[epoch]))

Loss: 0.042512
Loss: 0.036165
Loss: 0.038223
Loss: 0.031517
Loss: 0.033212
Loss: 0.036574
Loss: 0.035183
Loss: 0.032555
Loss: 0.038279
Loss: 0.039314
Loss: 0.031534
Loss: 0.033046
Loss: 0.037364
Loss: 0.035777
Loss: 0.032443
Loss: 0.031197
Loss: 0.031380
Loss: 0.037866
Loss: 0.032057
Loss: 0.034172


In [114]:
# Fetch the code vectors (the one-hot projection matrix Wi).
# Author's note: Wi would ideally be passed through a ReLU.
Wi=sess.run(mtgruae.Wi)

In [115]:
# Compute the visit and patient representations of every patient in the data set
# (the same call can be used for new patients).
visit_vector, patient_vector = sess.run(mtgruae.get_all_visit_vector())

In [116]:
# Display the patient vector of the patient at index 1.
patient_vector[1]

array([[9.9866611e-01, 9.9999934e-01, 1.7689591e-02, 3.2384160e-01,
        9.9999988e-01, 5.5573728e-05, 9.5419500e-07, 2.2339211e-06,
        4.2910255e-03, 2.0026885e-08, 7.4299240e-01, 9.9999291e-01,
        6.9143327e-09, 3.8227860e-08, 6.3482730e-10, 9.9961030e-01,
        9.9999750e-01, 3.2532445e-01, 9.9999934e-01, 6.8571548e-10,
        6.4419675e-01, 9.9998844e-01, 3.6672046e-04, 5.8583860e-10,
        9.9999875e-01, 3.1186223e-02, 7.0244282e-01, 9.9999821e-01,
        8.4483025e-11, 1.3777556e-07, 9.3763759e-03, 5.3719965e-08,
        9.9999762e-01, 5.8757067e-02, 1.0026448e-01, 9.9491370e-01,
        9.9999982e-01, 9.9999869e-01, 9.9999529e-01, 5.2264112e-04,
        1.6715336e-03, 4.5981199e-02, 2.5390431e-02, 6.2860539e-03,
        5.1270626e-02, 9.9999732e-01, 8.7920511e-03, 8.9708613e-03,
        6.9949059e-03, 1.8713091e-02, 1.5841199e-08, 9.9738002e-01,
        5.3471480e-03, 5.0482871e-03, 3.7583418e-02, 9.8355345e-02,
        5.8662355e-02, 1.9999946e-02, 3.8527813e

In [109]:
# Save the model variables of the current session as a checkpoint.
saver = tf.train.Saver()
saver.save(sess, checkpoint_dir + '/model.ckpt') 

'MODEL/model.ckpt'

In [120]:
# Save the code, visit and patient vectors with the highest pickle protocol.
# Each file is opened in a with-block so that it is flushed and closed.
for vectors, stem in ((Wi, 'code_vector'),
                      (visit_vector, 'visit_vector'),
                      (patient_vector, 'patient_vector')):
    with open(result_dir + '/' + stem + '.seqs', 'wb') as vector_file:
        pickle.dump(vectors, vector_file, -1)

In [81]:
# Self-contained training run: re-initialises all variables and trains inside a
# session that is closed when the block ends.
# NOTE(review): this rebinds the global `sess`; cells that use `sess` afterwards
# will find it closed.
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    Loss = np.zeros(iters)
    # Number of batches drawn per epoch; should be chosen so that one epoch
    # roughly covers the whole data set.
    steps_per_epoch = 10
    for epoch in range(iters):
        epoch_loss = 0.0
        steps_run = 0
        # A separate loop variable is used here: reusing the epoch index made
        # every epoch overwrite the same Loss entry.
        for step in range(steps_per_epoch):
            # Generate a batch from the configured data directory
            xindex, info, t = batch_generator(outFile, visit_num, code_num, patient_num,batch_size)
            # batch_generator may return an empty batch; skip it.
            if len(xindex)==0: continue
            _, L = sess.run([optimizer, loss], feed_dict={mtgruae.inputindex: xindex, mtgruae.time: t, mtgruae.info:info})
            epoch_loss += L
            steps_run += 1
        # Average over the steps that actually ran, not over skipped batches.
        Loss[epoch] = epoch_loss / max(steps_run, 1)
        print('Loss: %f' %(Loss[epoch]))

Loss: 5.212439
Loss: 4.750262
Loss: 4.520190
Loss: 4.613481
Loss: 4.385721
Loss: 4.476635
Loss: 4.364641
Loss: 4.281379
Loss: 4.243845
Loss: 4.033826
Loss: 4.233166
Loss: 4.160229
Loss: 4.152536
Loss: 3.949671
Loss: 3.995037
Loss: 4.049084
Loss: 4.197363
Loss: 3.865101
Loss: 4.103941
Loss: 3.926943
