# In [1]:
##---------------------------------------
## Training based on posterior and MMD cost function
##_______________________________________

import argparse
#import _pickle as pickle
import pickle
import math
import numpy as np
import random
import tensorflow as tf
import pdb
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
tf.compat.v1.disable_eager_execution()


"""
Give the training images from the MNIST dataset
"""
def loadMNIST():

    # Downloaded from http://deeplearning.net/data/mnist/mnist.pkl.gz
    with open('mnist.pkl', 'rb') as f:    
        train_data, val_data, test_data = pickle.load(f,encoding="bytes")
    train_x, train_y = train_data

    return train_x
"""
Give the training images from the histopathology dataset
"""
def loadHPTLOG():

    # Downloaded from https://github.com/jmtomczak/vae_householder_flow/tree/master/datasets/histopathologyGray
    with open('histopathology.pkl', 'rb') as f:    
        train_HPTLOG = pickle.load(f,encoding="bytes")
    right=train_HPTLOG.values()
    right=list(right)        #the sample numbers in realstic samples from generator
    right=np.array(right)
    HPTLOG=right[1]
    HPTLOG=np.array(HPTLOG)
    HPTLOG=HPTLOG.reshape([6800,784])

    return HPTLOG

"""
Give the training images from the MRI_meningioma dataset
"""
def loadMRImeningioma():

    #Original dataset is given by: https://www.kaggle.com/datasets/masoudnickparvar/brain-tumor-mri-dataset
    data=np.loadtxt('Training_MRI_meningioma.txt')
    data_min=np.min(data, keepdims=True)
    data_max=np.max(data, keepdims=True)
    scale_data=(data-data_min)/(data_max-data_min)

    return scale_data
"""
Give the training images from the MRI_meningioma dataset
"""
def loadMRI():

    #Original dataset is given by: https://www.kaggle.com/datasets/masoudnickparvar/brain-tumor-mri-dataset
    data=np.loadtxt('Training_MRI_total.txt')
    data_min=np.min(data, keepdims=True)
    data_max=np.max(data, keepdims=True)
    scale_data=(data-data_min)/(data_max-data_min)

    return scale_data
"""
Give the training images from the cropped LFW dataset
"""
def loadLFW():

    # 32x32 version of grayscale cropped LFW
    # Original dataset is given by: http://conradsanderson.id.au/lfwcrop/
    return np.load('lfw.npy')


"""
Posterior of real data (x)
"""
def x_pos( x):
    from scipy import stats
    from numpy.random import choice
    # batch size for the training
    batch_size = 1000
    input_dim=784
    # generate images from the provided uniform samples
    n=1000 # Determine the number of terms in DP approximation (N)
    d=input_dim
    m=batch_size
    a =.05
    g2_pos=np.zeros((1, n)); tstar2_pos =np.zeros((n, d))
    g22_pos= np.zeros((1, n)); pp2_pos= np.zeros((1, n))
    y21_pos=np.random.exponential(scale=1.0, size=n+1)

    for i in range(0, n):
        g2_pos[0,i]=np.sum(y21_pos[0:i+1])/np.sum(y21_pos)
        g22_pos[0,i]=stats.gamma.ppf(1-g2_pos[0,i], (a+m)/n, 1)-1
        u=np.random.uniform(low=0.0, high=1.0, size=1)
        if u<a/(a+m):
            tstar2_pos[i-1 ,0:d]=np.random.multivariate_normal(np.zeros(d), np.identity(d), 1)
        else:
            r2=np.random.randint(m, size=1)
            tstar2_pos[i-1 ,0:d]=x[r2]


    p22_pos=g22_pos/np.sum(g22_pos)
    v2_pos=choice(np.arange(0, n, 1, dtype=int), n,p=p22_pos[0])
    tstarDP2_pos=tstar2_pos[v2_pos]

    X_star2_pos=tstarDP2_pos; J_star2_pos=p22_pos
    return X_star2_pos,J_star2_pos

# Module-level configuration plus a one-off posterior draw for a random MNIST
# batch.  NOTE: ``batch_size`` is read as a global by
# DataSpaceNetwork.computeLoss; the other names are not read by the functions
# defined here (trainDataSpaceNetwork uses its own locals of the same names).
batch_size, input_dim, image_side, num_examples = 1000, 784, 28, 50000
train_x = loadMNIST()
batch_indices = np.random.randint(num_examples, size=batch_size)
batch_x = train_x[batch_indices]
posterior = x_pos(batch_x)
batch_pos, J_star2_pos = posterior[0], -posterior[1][0]

"""
Return a TF variable with zeros of provided shape
"""
def zeros(shape):

    return tf.Variable(tf.zeros(shape))

"""
Return a TF variable with numbers drawn from a normal distribution of zero mean
and given standard deviation
"""
def normal(shape, std_dev):

    return tf.Variable(tf.random.normal(shape, stddev = std_dev))

class ReLULayer():
    """Fully connected layer computing relu(x @ W + b)."""

    def __init__(self, input_dim, output_dim):
        """Create the layer parameters.

        input_dim:  Dimension of inputs to the layer
        output_dim: Dimension of outputs of the layer
        """
        # weights ~ N(0, 1/input_dim); biases start at zero
        weight_std = 1.0 / math.sqrt(input_dim)
        self.W = normal([input_dim, output_dim], weight_std)
        self.b = zeros([output_dim])

    def forward(self, x):
        """Return the layer's activations for the input batch ``x``."""
        pre_activation = tf.matmul(x, self.W) + self.b
        return tf.nn.relu(pre_activation)

class SigmoidLayer():
    """Fully connected layer with dropout on its input and a sigmoid output.

    Dropout is applied on every forward() call (there is no train/inference
    switch); with the default retention probability of 1.0 the dropout rate
    is 0.
    """

    def __init__(self, input_dim, output_dim, dropout_prob = 1.0):
        """Create the layer parameters.

        input_dim:    Dimension of inputs to the layer
        output_dim:   Dimension of outputs of the layer
        dropout_prob: Fraction of inputs retained by dropout
        """
        weight_std = 1.0 / math.sqrt(input_dim)
        self.W = normal([input_dim, output_dim], weight_std)
        self.b = zeros([output_dim])

        # kept for forward(), which converts it into a dropout rate
        self.dropout_prob = dropout_prob

    def forward(self, x):
        """Return sigmoid(dropout(x) @ W + b) for the input batch ``x``."""
        dropped = tf.nn.dropout(x, rate=1 - self.dropout_prob)
        pre_activation = tf.matmul(dropped, self.W) + self.b
        return tf.sigmoid(pre_activation)

class DataSpaceNetwork():
    """Generator (ReLU hidden layers, sigmoid output) trained with a weighted
    multi-bandwidth RBF-kernel MMD loss."""

    def __init__(self, dimensions, batch_size):
        """Build the layers of the network.

        dimensions: Dimensions of all the layers of the network, including
                    input and output (at least two entries)
        batch_size: Number of generated samples per batch ('N')

        Raises ValueError if fewer than two dimensions are given.
        """
        if len(dimensions) < 2:
            raise ValueError('dimensions must contain at least an input and '
                             'an output size')

        # store 'dimensions' and 'batch_size' for later use
        self.dimensions = dimensions
        self.batch_size = batch_size

        # all the layers except the last one are 'ReLU'
        self.layers = [ReLULayer(dimensions[k], dimensions[k + 1])
                       for k in range(len(dimensions) - 2)]

        # The last layer is 'Sigmoid' so outputs lie in [0, 1].  It is indexed
        # from the end so that a network with no hidden layer (two dimensions)
        # also works.
        self.layers.append(SigmoidLayer(dimensions[-2], dimensions[-1]))

    def forward(self, x):
        """Propagate the input batch ``x`` (uniform samples) through all layers."""
        h = x
        for layer in self.layers:
            h = layer.forward(h)
        return h

    def makeScaleMatrix(self, num_gen, num_orig):
        """Return the (num_gen + num_orig, 1) unweighted MMD scale column.

        The first num_gen entries are 1/num_gen and the next num_orig entries
        are -1/num_orig.  (Not used by computeLoss, which takes explicit
        weights for the data rows.)

        num_gen:  Number of samples to be generated in one pass, 'N' in the paper
        num_orig: Number of samples taken from dataset in one pass, 'M' in the paper
        """
        s1 = tf.constant(1.0 / num_gen, shape = [num_gen, 1])
        s2 = -tf.constant(1.0 / num_orig, shape = [num_orig, 1])

        return tf.concat([s1, s2], axis=0)

    def computeLoss(self, x, samples, weight, sigma = (2, 5, 10, 20, 40, 80)):
        """Return sqrt(sum_k s^T K_k s), a weighted multi-kernel MMD.

        K_k[i, j] = exp(-||X_i - X_j||^2 / (2 * sigma[k])), where X stacks the
        generated rows above the rows of ``x``.  The scale column ``s`` is
        1/batch_size for each generated row, followed by ``weight`` for the
        rows of ``x``.

        x:       Batch of data rows (trainDataSpaceNetwork feeds posterior atoms)
        samples: Uniform input samples, one generated row each; there should
                 be ``self.batch_size`` of them
        weight:  One weight per row of ``x``.  trainDataSpaceNetwork passes the
                 negated posterior weights, so ``s`` sums to 1 - 1 = 0.
        sigma:   Kernel parameters, one kernel per entry
        """
        # generate images from the provided uniform samples
        gen_x = self.forward(samples)

        # first 'N' rows are the generated ones, next rows are from the data
        X = tf.concat([gen_x, x], axis=0)

        # exponent[i, j] = -0.5 * ||X_i - X_j||^2
        XX = tf.matmul(X, tf.transpose(a=X))
        X2 = tf.reduce_sum(input_tensor=X * X, axis=1, keepdims = True)
        exponent = XX - .5 * X2 - .5 * tf.transpose(a=X2)

        # Scale column.  The data weights are reshaped with -1 rather than the
        # module-global batch_size, so the count always follows the fed weights.
        s_gen = tf.constant(1.0 / self.batch_size, shape = [self.batch_size, 1])
        s_data = tf.cast(tf.reshape(weight, [-1, 1]), tf.float32)
        s = tf.concat([s_gen, s_data], axis=0)

        # S[i, j] = s_i * s_j
        S = tf.matmul(s, tf.transpose(a=s))

        loss = 0
        for sig in sigma:
            kernel_val = tf.exp(1.0 / sig * exponent)
            loss += tf.reduce_sum(input_tensor=S * kernel_val)

        return tf.sqrt(loss)
def generateFigure(samples, num_rows, num_cols, image_side, file_name):
    """Save a num_rows x num_cols grid of grayscale images and return the figure.

    samples:    2-D array; each row is a flattened image_side x image_side image
    num_rows:   Number of grid rows
    num_cols:   Number of grid columns
    image_side: Side length (in pixels) of each square image
    file_name:  Path the figure is saved to

    The first num_rows * num_cols samples are drawn in row-major order.  If
    fewer samples are given, the remaining cells are left blank.  The figure
    is closed with plt.close after saving, so repeated calls do not accumulate
    open figures.
    """
    # squeeze=False keeps 'axes' a 2-D array even for a 1x1 grid
    figure, axes = plt.subplots(nrows = num_rows, ncols = num_cols,
                                squeeze = False)

    for index, axis in enumerate(axes.flat):
        axis.set_frame_on(False)
        axis.set_axis_off()
        if index < len(samples):
            axis.imshow(samples[index, :].reshape(image_side, image_side),
                        cmap = plt.cm.gray, interpolation = 'nearest')

    figure.savefig(file_name)
    plt.close(figure)
    return figure
def trainDataSpaceNetwork(dataset, num_iterations=40001, iteration_break=1000):
    """Train the data-space MMD network on posterior draws from a dataset.

    In each iteration:
      1. draw a random batch of training images;
      2. replace it with x_pos's posterior approximation (atoms plus weights);
      3. take one Adam step on DataSpaceNetwork.computeLoss.
    After training, 10x10 generated images are saved to
    '<dataset>_data_space_MMD_pos40000.png'.

    dataset:         One of 'mnist', 'lfw', 'histopathology',
                     'MRImeningioma', 'MRI'
    num_iterations:  Number of optimizer steps
    iteration_break: The cost is printed every this many iterations

    Returns (gen_samples, figg): a batch of generated images and the return
    value of generateFigure.
    Raises ValueError for an unknown dataset name.
    """
    # Batch size for the training.  x_pos draws 1000 atoms per call by
    # default, and they are fed into placeholders of batch_size rows.
    batch_size = 1000

    # input_dim, image_side, num_examples and loader for each dataset
    configs = {
        'mnist':          (784,  28, 50000, loadMNIST),
        'lfw':            (1024, 32, 13000, loadLFW),
        'histopathology': (784,  28, 6800,  loadHPTLOG),
        'MRImeningioma':  (2500, 50, 1339,  loadMRImeningioma),
        'MRI':            (2500, 50, 5712,  loadMRI),
    }
    if dataset not in configs:
        raise ValueError('unknown dataset %r; expected one of %s'
                         % (dataset, ', '.join(sorted(configs))))
    input_dim, image_side, num_examples, loader = configs[dataset]
    train_x = loader()

    # dimensions of the moment matching network
    data_space_dims = [10, 64, 256, 256, input_dim]
    data_space_network = DataSpaceNetwork(data_space_dims, batch_size)

    # placeholders for the posterior batch, the uniform samples and the
    # per-row posterior weights respectively
    x       = tf.compat.v1.placeholder("float", [batch_size, input_dim])
    samples = tf.compat.v1.placeholder("float", [batch_size, data_space_dims[0]])
    weight  = tf.compat.v1.placeholder("float", [batch_size, 1])

    # cost of the network, and optimizer for the cost
    cost      = data_space_network.computeLoss(x, samples, weight)
    optimizer = tf.compat.v1.train.AdamOptimizer().minimize(cost)

    # generator for the network
    generate = data_space_network.forward(samples)

    # initialize_all_variables is a deprecated alias of this initializer
    init = tf.compat.v1.global_variables_initializer()

    # the context manager closes the session once sampling is done
    with tf.compat.v1.Session() as sess:
        sess.run(init)

        for i in range(num_iterations):

            # sample a random batch and replace it by its posterior draw
            batch_indices = np.random.randint(num_examples, size = batch_size)
            batch_x       = train_x[batch_indices]
            batch_pos, J_pos = x_pos(batch_x)
            # negated weights: computeLoss subtracts the data term
            J_star2_pos   = -J_pos[0].reshape(-1, 1)
            batch_uniform = np.random.uniform(low = -1.0, high = 1.0,
                size = (batch_size, data_space_dims[0]))
            feed = {samples: batch_uniform, x: batch_pos, weight: J_star2_pos}

            # print out the cost after every 'iteration_break' iterations
            if i % iteration_break == 0:
                curr_cost = sess.run(cost, feed_dict = feed)
                print('Cost at iteration ' + str(i+1) + ': ' + str(curr_cost))

            # optimize the network
            sess.run(optimizer, feed_dict = feed)

        # generate samples from the trained network
        batch_uniform = np.random.uniform(low = -1.0, high = 1.0,
            size = (batch_size, data_space_dims[0]))
        gen_samples = sess.run(generate, feed_dict = {samples: batch_uniform})

    # generate figure of generated samples
    num_rows = 10; num_cols = 10
    file_name = dataset + '_data_space_MMD_pos40000.png'
    figg = generateFigure(gen_samples, num_rows, num_cols, image_side, file_name)
    return gen_samples, figg

"""
Train code space network on the given dataset

dataset: Either 'mnist' or 'lfw', indicating the dataset
"""    
out=trainDataSpaceNetwork('mnist')


# KeyboardInterrupt:  (notebook output captured when the cell above was interrupted)