In [None]:
class VNet():
    def __init__(self, layers, unit_type="bin", noise_std = 0.001, model_path='./models/dbm.ckpt'):
        """Create the graph variables for a stack of fully connected layers.

        The default TensorFlow graph is reset first, so constructing a new
        instance discards any ops that were built before.

        Args:
            layers: Sequence with the number of units of every layer, from
                the input layer to the top layer.
            unit_type: ``"bin"`` selects binary (sigmoid-thresholded) sampling
                in ``sample_forward``/``sample_back``; any other value selects
                the noisy linear units.
            noise_std: Standard deviation used by the Gaussian noise layer.
            model_path: Checkpoint path used by the saver for save/restore.
        """
        tf.reset_default_graph()
        self.unit_type = unit_type
        self.layers_nums = layers
        self.noise_std = noise_std
        self.model_path = model_path
        self.loaded_model = False
        self.vae_initialized = False

        # The input is always float32: it is fed straight into tf.matmul with
        # float32 weights, and tf.matmul does not accept bool tensors. Binary
        # data can still be fed; the feed is converted to the placeholder dtype.
        self.input = tf.placeholder(tf.float32, [None, None], name="input")

        # One entry per layer: "state" is filled in by the sampling methods,
        # "b" is that layer's bias vector.
        self.layers = [
            {
                "state": None,
                "b": tf.Variable(tf.zeros([unit_count]), name="b_" + str(index)),
            }
            for index, unit_count in enumerate(layers)
        ]

        # W[i] connects layer i to layer i + 1.
        self.W = [
            tf.Variable(
                tf.random_normal((layers[i], layers[i + 1]), mean=0.0, stddev=0.01),
                name="W_" + str(i),
            )
            for i in range(len(layers) - 1)
        ]

        self.tf_saver = tf.train.Saver()
        
    def __gaussian_noise_layer(self, x):
        """Return ``x`` plus zero-mean Gaussian noise with std ``self.noise_std``.

        The noise tensor has the same (dynamic) shape as ``x``. The value of
        ``self.noise_std`` is read at the moment this op is built.
        """
        perturbation = tf.random_normal(
            shape=tf.shape(x),
            mean=0.0,
            stddev=self.noise_std,
            dtype=tf.float32,
        )
        return x + perturbation
    
    def sample_forward(self, start_from=0, steps_num=None):
        if(steps_num == None or steps_num > len(self.layers)):
            steps_num = len(self.layers)-1
        for i in range(start_from, start_from + steps_num):
            mul = tf.matmul(self.layers[i]["state"], self.W[i]) + self.layers[i+1]["b"]
            
            if(self.unit_type == "bin") :
                self.layers[i+1]["state"] = tf.nn.relu(tf.sign(tf.sigmoid(mul) 
                                                               - np.random.rand(tf.shape(mul)[0], tf.shape(mul)[1])))
            else:
                 self.layers[i+1]["state"] = tf.nn.relu(self.__gaussian_noise_layer(mul))
                    
        return self.layers[i+1]["state"]
    
    def sample_back(self, start_from=None, steps_count=None):
        if(steps_count == None):
            steps_count = len(self.layers)-1
            start_from = len(self.layers)-1
        i = start_from
        while(i > start_from - steps_count):
            mul = tf.matmul(self.layers[i]["state"], tf.transpose(self.W[i-1])) + self.layers[i-1]["b"]
            if(self.unit_type == "bin"):
                self.layers[i-1]["state"] = tf.nn.relu(tf.sign(tf.sigmoid(mul) 
                                                               - np.random.rand(tf.shape(mul)[0], tf.shape(mul)[1])))
            else:
                 self.layers[i-1]["state"] = self.__gaussian_noise_layer(mul)
            i -= 1
        return self.layers[i]["state"]
    
    def __getNextBatch(self, batch_size):
        if(self.train_set):
            if(callable(self.train_set)):
                gen = self.train_set(batch_size)
                for result in gen:
                    yield result
                    
    def prepare_train_set(self, batch_size, epochs_count):
        if(callable(self.train_set)):
            return self.train_set(self.files, int(self.layers_nums[0]), batch_size, epochs_count, True)
    
    def train_dbm(self, train_set, batch_size, learning_rate, epochs_count, decrease_noise = 0, pcd = 1, depth=None):
        """Greedily train adjacent layer pairs with contrastive-divergence style updates.

        For every trained pair (i, i + 1) the weights ``W[i]`` and both biases
        are updated with ``assign_add`` from the difference between statistics
        of the data-driven states (v0, h0) and the resampled states (v1, h1).
        The model is saved to ``self.model_path`` at the end.

        Args:
            train_set: Callable handed to ``prepare_train_set``; the object it
                returns is advanced with ``__next__()``.
            batch_size: Forwarded to ``prepare_train_set``.
            learning_rate: Multiplier applied to every update; also printed.
            epochs_count: Number of update rounds per layer pair.
            decrease_noise: Fraction by which ``self.noise_std`` shrinks each
                epoch.
            pcd: Number of back/forward resampling rounds (and session runs)
                per epoch.
            depth: Number of layers to include; ``depth - 1`` pairs are
                trained. ``None`` or a value above ``len(self.layers) - 1``
                selects all layers.
        """
        if(depth == None or depth > len(self.layers)-1):
            depth = len(self.layers)
        self.train_set = train_set
        # A session restored by load_model() is reused; otherwise a fresh one
        # is created and all variables are initialised.
        if(self.loaded_model == False):
            self.tf_sess = tf.Session()
            self.tf_sess.run(tf.global_variables_initializer())
        # From here on self.train_set is the prepared batch source, not the
        # callable that was passed in.
        self.train_set = self.prepare_train_set(batch_size, epochs_count*(depth-1))
        print("start training "+str(learning_rate))
        for i in range(depth-1):
            # The noise level restarts at a hard-coded 0.001 for every pair,
            # regardless of the value given to the constructor.
            self.noise_std = 0.001
            # NOTE(review): one batch is fetched per layer pair and reused for
            # all epochs, although epochs_count*(depth-1) batches were
            # requested above -- confirm this is intended.
            inp = self.train_set.__next__()
            for j in  range(epochs_count):
                # Ops are rebuilt on every epoch; the Gaussian noise layer
                # reads self.noise_std when it is built, so the decayed value
                # only takes effect because of this rebuild.
                self.noise_std = self.noise_std - decrease_noise*self.noise_std
                self.layers[0]["state"] = self.input
                h0_state = self.sample_forward(0, i+1)
                v0_state = self.layers[i]["state"]
                positive = tf.matmul(tf.transpose(v0_state), h0_state)
                for k in range(pcd):                    
                    # sample_back/sample_forward overwrite the stored states of
                    # layers i and i + 1, so each round continues from the
                    # previous round's samples.
                    v1_state = self.sample_back(i+1, 1)
                    h1_state = self.sample_forward(i, 1)
                    negative = tf.matmul(tf.transpose(v1_state), h1_state)
                    w_update = self.W[i].assign_add(learning_rate*(positive - negative))
                    v_loss = tf.reduce_mean(v0_state - v1_state, 0)
                    vb_update = self.layers[i]["b"].assign_add(learning_rate*(v_loss))
                    hb_update = self.layers[i+1]["b"].assign_add(learning_rate*(tf.reduce_mean(h0_state - h1_state, 0)))
                    # `norm` is not defined in this class; presumably a
                    # module-level normalisation helper -- verify.
                    w_upd, bv_upd, bh_upd, loss = self.tf_sess.run([w_update, vb_update, hb_update, v_loss], feed_dict={
                        self.input : norm(inp)
                    })
                if(j % 100 == 0):
                    # `loss` is only assigned inside the pcd loop, so pcd == 0
                    # would leave it undefined here.
                    print("epoch: "+str(j)+" layer:"+str(i)+" loss:")
                    print(np.mean(np.power(loss, 2)))
        # NOTE(review): the purpose of this extra advance is not visible here;
        # presumably it lets the batch source finish -- confirm.
        self.train_set.__next__()
        self.tf_saver.save(self.tf_sess, self.model_path)
        
    def load_model(self, path=None):
        """Open a new session and restore the saved variables into it.

        Args:
            path: Checkpoint to restore from; ``None`` means
                ``self.model_path``.

        After a successful restore ``self.tf_sess`` holds the live session
        and ``self.loaded_model`` is ``True``.
        """
        # Release a session this object already owns instead of leaking it.
        previous = getattr(self, "tf_sess", None)
        if previous is not None:
            previous.close()

        self.tf_sess = tf.Session()
        # Initialise everything first; presumably this gives a value to
        # variables the checkpoint does not cover -- the restore below then
        # overwrites the saved ones.
        self.tf_sess.run(tf.global_variables_initializer())
        checkpoint = path if path is not None else self.model_path
        self.tf_saver.restore(self.tf_sess, checkpoint)
        self.loaded_model = True
        
    def encode(self, data):
        """Encode ``data`` with the weights stored at ``self.model_path``.

        When the VAE head has been initialised the result of
        ``vae_recognize`` is evaluated, otherwise the result of a full
        ``sample_forward`` pass.

        Args:
            data: Batch fed to the ``self.input`` placeholder.

        Returns:
            The evaluated encoding as returned by ``Session.run``.
        """
        self.layers[0]["state"] = self.input
        res = self.vae_recognize() if self.vae_initialized else self.sample_forward()
        # Use a local, short-lived session. Binding it to self.tf_sess would
        # leave a closed session behind and break a session that load_model()
        # or train_dbm() opened earlier.
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            self.tf_saver.restore(sess, self.model_path)
            return sess.run(res, feed_dict={self.input: data})
    
    def decode(self, data):
        """Decode ``data`` with the weights stored at ``self.model_path``.

        ``data`` is fed through the ``self.input`` placeholder. When the VAE
        head has been initialised it is treated as the latent code ``z`` and
        passed through ``vae_generate``; otherwise it is treated as the state
        of the top layer and passed through a full ``sample_back``.

        Args:
            data: Batch of latent codes or top-layer states.

        Returns:
            The evaluated reconstruction as returned by ``Session.run``.
        """
        if self.vae_initialized:
            self.z = self.input
            res = self.vae_generate()
        else:
            self.layers[-1]["state"] = self.input
            res = self.sample_back()
        # Use a local, short-lived session. Binding it to self.tf_sess would
        # leave a closed session behind and break a session that load_model()
        # or train_dbm() opened earlier.
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            self.tf_saver.restore(sess, self.model_path)
            return sess.run(res, feed_dict={self.input: data})
        
    def xavier_init(self, fan_in, fan_out, constant=1):
        """Return a ``(fan_in, fan_out)`` float32 tensor of uniform random values.

        Values are drawn from ``[-limit, limit)`` with
        ``limit = constant * sqrt(6 / (fan_in + fan_out))``.
        """
        limit = constant * np.sqrt(6.0 / (fan_in + fan_out))
        return tf.random_uniform(
            (fan_in, fan_out),
            minval=-limit,
            maxval=limit,
            dtype=tf.float32,
        )

    def init_vae(self, latent_size):
        print(self.vae_initialized)
        if(self.vae_initialized != True):
            print("init vae")
            self.latent_size = latent_size
            self.decode_layers = []
            for i in range(len(self.layers)-1):
                self.decode_layers.append({
                    "state" : None,
                    "b" : self.layers[i]["b"]
                })
                
            self.decode_W = []
            for i in range(len(self.layers)-1):
                self.decode_W.append(tf.transpose(self.W[i]))
            self.z_in_mean_W = self.xavier_init(self.layers_nums[-1], latent_size);
            self.z_in_mean_b = tf.Variable(tf.zeros([latent_size], dtype=tf.float32));
            self.z_out_mean_W = self.xavier_init(latent_size, self.layers_nums[-1]);
            self.z_out_mean_b = tf.Variable(tf.zeros([self.layers_nums[-1]], dtype=tf.float32));
            self.z_signa_squared_W = self.xavier_init(self.layers_nums[-1], latent_size);
            self.z_signa_squared_b = tf.Variable(tf.zeros([latent_size], dtype=tf.float32))
            self.class_w = self.xavier_init(latent_size, len(self.files));
            self.class_b = tf.Variable(tf.zeros([len(self.files)], dtype=tf.float32))
            self.vae_initialized = True
    
    def vae_recognize(self):
        """Build the encoder half of the VAE and return a sampled latent code.

        Runs a full ``sample_forward`` pass, projects the top-layer state to
        ``self.z_mean`` and ``self.z_signa_squared`` and returns
        ``z_mean + sqrt(exp(z_signa_squared)) * eps`` with ``eps`` drawn from
        a standard normal distribution. Because ``exp`` is applied to it,
        ``z_signa_squared`` acts as a log-variance.

        Returns:
            Tensor with the same shape as ``self.z_mean``.
        """
        encoded = self.sample_forward()
        self.z_mean = tf.matmul(encoded, self.z_in_mean_W) + self.z_in_mean_b
        self.z_signa_squared = tf.matmul(encoded, self.z_signa_squared_W) + self.z_signa_squared_b

        # Take the noise shape from z_mean itself so the op works for any
        # batch size and does not depend on self.batch_size having been set.
        eps = tf.random_normal(tf.shape(self.z_mean), 0, 1, dtype=tf.float32)
        # z = mu + sigma * epsilon
        sigma = tf.sqrt(tf.exp(self.z_signa_squared))
        return tf.add(self.z_mean, tf.multiply(sigma, eps))
        
    def vae_generate(self):
        """Project ``self.z`` to the top layer and sample back down the stack.

        The top layer's state is overwritten with
        ``z @ z_out_mean_W + z_out_mean_b``; the result of a full
        ``sample_back`` pass is returned.
        """
        projected = tf.matmul(self.z, self.z_out_mean_W)
        top_state = projected + self.z_out_mean_b
        self.layers[-1]["state"] = top_state
        return self.sample_back()
        
    def train_vae(self, train_set, batch_size, learning_rate, epochs_count):
        """Train the VAE objective with Adam and save the model.

        The cost is the mean squared reconstruction error plus a
        0.0001-weighted latent term built from ``self.z_mean`` and
        ``self.z_signa_squared``. Training starts from the checkpoint at
        ``self.model_path`` and writes back to the same path.

        Args:
            train_set: Callable handed to ``prepare_train_set``; the object it
                returns is advanced with ``__next__()`` once per epoch.
            batch_size: Stored in ``self.batch_size`` and forwarded to
                ``prepare_train_set``.
            learning_rate: Learning rate of the Adam optimizer.
            epochs_count: Number of optimisation steps.

        NOTE(review): the session is bound to ``self.tf_sess`` by the ``with``
        statement, so ``self.tf_sess`` refers to a session that has left its
        context once this method returns -- confirm callers do not reuse it.
        """
        self.train_set = train_set
        self.batch_size = batch_size
        self.layers[0]["state"] = self.input

        # Encoder and decoder graphs.
        self.z = self.vae_recognize()
        reconstruction = self.vae_generate()

        squared_error = tf.square(reconstruction-self.input)
        reconstr_loss = tf.reduce_mean(squared_error)
        latent_terms = (1 + self.z_signa_squared
                        - tf.square(self.z_mean)
                        - tf.exp(self.z_signa_squared))
        latent_loss = -0.0001 * tf.reduce_sum(latent_terms, 1)
        # Average over the batch.
        self.cost = tf.reduce_mean(reconstr_loss + latent_loss)
        train_op = tf.train.AdamOptimizer(learning_rate).minimize(self.cost)

        self.train_set = self.prepare_train_set(batch_size, epochs_count)
        print("start vae training")
        with tf.Session() as self.tf_sess:
            self.tf_sess.run(tf.global_variables_initializer())
            self.tf_saver.restore(self.tf_sess, self.model_path)
            for epoch in range(epochs_count):
                # `norm` is not defined in this class; presumably a
                # module-level normalisation helper.
                feed = {self.input: norm(self.train_set.__next__())}
                if epoch % 100 != 0:
                    self.tf_sess.run(train_op, feed_dict=feed)
                    continue
                # Every 100th step also fetches and reports the cost.
                cost, _ = self.tf_sess.run([self.cost, train_op], feed_dict=feed)
                print("epoch: "+str(epoch)+" loss:"+str(cost))
            self.train_set.__next__()
            self.tf_saver.save(self.tf_sess, self.model_path)
    
    def regression(self, train_set, batch_size, epoch_count):
        """Train a softmax classifier on top of the encoded top-layer state.

        Each step encodes a batch with ``sample_forward``, computes logits as
        ``res @ self.rW + self.rb`` and minimises sparse softmax cross-entropy
        against integer labels with gradient descent (learning rate 0.05).
        The model is saved to ``self.model_path`` at the end.

        Args:
            train_set: Callable handed to ``prepare_train_set``; the object it
                returns must yield ``(data, labels)`` pairs.
            batch_size: Forwarded to ``prepare_train_set``.
            epoch_count: Number of optimisation steps.

        NOTE(review): ``self.rW`` and ``self.rb`` are not assigned anywhere in
        this class; they must be provided from outside before calling.
        """
        # NOTE(review): tf.stop_gradient returns a new tensor and the results
        # are discarded here, so these calls presumably do not stop the
        # optimizer from updating the biases and weights -- confirm.
        for i in range(len(self.layers_nums)):
            tf.stop_gradient(self.layers[i]["b"])
            if(i < len(self.layers_nums)-1):
                tf.stop_gradient(self.W[i])
        file_labels = tf.placeholder(tf.int64, [None], name="Labels")
        self.train_set = train_set
        if(self.loaded_model == False):
            self.load_model()
        # NOTE(review): prepare_train_set as defined in this class accepts
        # only (batch_size, epochs_count); this three-argument call does not
        # match that signature.
        self.train_set = self.prepare_train_set(batch_size, epoch_count*(len(self.layers)), True)
        for j in range(epoch_count):
            # The forward pass, loss and optimizer ops are rebuilt on every
            # iteration.
            self.layers[0]["state"] = self.input
            res = self.sample_forward()
            mul = tf.matmul(res, self.rW) + self.rb
            cross_entropy = tf.reduce_mean(
                tf.nn.sparse_softmax_cross_entropy_with_logits(logits=mul, labels=file_labels))
            train_step = tf.train.GradientDescentOptimizer(0.05).minimize(cross_entropy)
            batch = self.train_set.__next__()
            # batch[0] is the input data, batch[1] the integer labels.
            # `norm` is not defined in this class; presumably a module-level
            # normalisation helper -- verify.
            err, step = self.tf_sess.run([cross_entropy, train_step], feed_dict = {
                self.input : norm(batch[0]),
                file_labels : batch[1]
            })
            if(j % 100 == 0):
                print("epoch: "+str(j)+" loss:")
                print(err)
        # NOTE(review): the purpose of this extra advance is not visible here.
        self.train_set.__next__()
        self.tf_saver.save(self.tf_sess, self.model_path)

    def transform(self, inp, target, k):
        """Shift the encoding of ``inp`` along a classifier row and decode it.

        The input is encoded with ``sample_forward``; row ``target`` of the
        transposed ``self.rW``, scaled by ``k``, is added to the top-layer
        state, and the result is decoded with ``sample_back``. The selected
        row is also evaluated and printed.

        Args:
            inp: Batch fed to the ``self.input`` placeholder.
            target: Row index into the transposed ``self.rW``.
            k: Scale applied to the selected row.

        Returns:
            The evaluated output of the backward pass.

        NOTE(review): ``self.rW`` is not assigned anywhere in this class; it
        must be provided from outside before calling.
        """
        self.layers[0]["state"] = self.input
        top_code = self.sample_forward()
        shift = k*tf.transpose(self.rW)[target]
        self.layers[-1]["state"] = top_code + shift
        decoded_op = self.sample_back()

        # Restore weights lazily if no model has been loaded yet.
        if(self.loaded_model == False):
            self.load_model()

        fetched = self.tf_sess.run(
            [top_code, decoded_op],
            feed_dict={self.input: inp},
        )
        print(self.tf_sess.run(tf.transpose(self.rW)[target]))
        return fetched[1]