<a href="https://colab.research.google.com/github/RoozbehSanaei/deep-learning-notebooks/blob/master/hologan.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
#@title Utilities
!git clone https://github.com/dimatura/binvox-rw-py
print("*")
import numpy as np
import tensorflow as tf
import tensorflow.contrib as tf_contrib
import tensorflow.contrib.slim as slim
import os
# from tools.ops import *
import math
import glob
import scipy.io
import time
from __future__ import division
import tarfile
import zlib
import io
from PIL import Image
import pprint
import random
import scipy.misc
import sys
import glob
sys.path.insert(1, 'binvox-rw-py/')
import binvox_rw

def get_weight(weight_name, weight_dict):
    """
    Look up a pretrained array by name.

    :param weight_name: key to look up.
    :param weight_dict: mapping of names to arrays, or None when no pretrained weights exist.
    :return: ``weight_dict[weight_name]`` when present; None when the key is missing
        or ``weight_dict`` is None.

    NOTE(review): a second ``get_weight(shape, gain=..., ...)`` is defined later in this
    file and replaces this one at module level once that definition has executed.
    """
    return None if weight_dict is None else weight_dict.get(weight_name)

def res_block_3d(input, out_channels=64, scope = 'res_block', kernel=[3, 3, 3], prob = 0.5,  stride=[1, 1, 1], weight_dict=None):
    """
    3-D residual block: conv -> ReLU -> conv, then add the block input back.

    :param input: 5-D input tensor (tf.nn.conv3d input). It is added to the second
        conv's output, so its shape must be compatible with that output (matching
        channels, and a stride that preserves spatial size).
    :param out_channels: number of filters in both convolutions.
    :param scope: variable scope; also the prefix for ``weight_dict`` keys.
    :param kernel: 3-D kernel size for both convolutions (list or tuple).
    :param prob: unused; dropout is disabled in this block.
    :param stride: 3-D stride for both convolutions (list or tuple).
    :param weight_dict: optional mapping of pretrained arrays, looked up under
        ``scope + 'con1_3X3_weights'``, ``scope + 'con1_3X3_biases'``,
        ``scope + 'conv2_3x3_weights'`` and ``scope + 'conv2_3x3_biases'``.
    :return: float32 tensor ``conv2(relu(conv1(input))) + input``.

    NOTE(review): the calls below use the keyword-style ``conv3d(input_, num_outputs,
    pad=..., kernel_size=..., stride=..., scope=..., weight_initializer=...)``. This file
    later redefines ``conv3d(input_, output_dim, k_h=..., name=...)`` under the same name,
    and with that definition these calls raise TypeError. TODO: give one of the two
    ``conv3d`` definitions a distinct name.
    """
    def _pretrained(suffix):
        # Looked up inline instead of via get_weight(): a later
        # get_weight(shape, gain, ...) in this file shadows the dict-lookup helper
        # and would treat the key string as a shape.
        if weight_dict is None:
            return None
        return weight_dict.get(scope + suffix)

    with tf.variable_scope(scope):
        net = tf.nn.relu(conv3d(input, out_channels, kernel_size=list(kernel), stride=list(stride), pad="SAME", scope="con1_3X3",
                                weight_initializer=_pretrained('con1_3X3_weights'),
                                bias_initializer=_pretrained('con1_3X3_biases'),
                                weight_initializer_type=tf.contrib.layers.xavier_initializer()))
        net = conv3d(net, out_channels, kernel_size=list(kernel), stride=list(stride), pad="SAME", scope="conv2_3x3",
                     weight_initializer=_pretrained('conv2_3x3_weights'),
                     bias_initializer=_pretrained('conv2_3x3_biases'),
                     weight_initializer_type=tf.contrib.layers.xavier_initializer())
    # Skip connection; both operands cast to float32 before adding.
    return tf.add(tf.cast(net, tf.float32), tf.cast(input, tf.float32))


def res_block_2d(input, out_channels=64, scope = 'res_block', kernel=[3, 3], prob = 0.5,  stride=[1, 1], weight_dict=None):
    """
    2-D residual block: conv -> ReLU -> conv, then add the block input back.

    :param input: 4-D input tensor (tf.nn.conv2d input). It is added to the second
        conv's output, so its shape must be compatible with that output (matching
        channels, and a stride that preserves spatial size).
    :param out_channels: number of filters in both convolutions.
    :param scope: variable scope; also the prefix for ``weight_dict`` keys.
    :param kernel: 2-D kernel size for both convolutions (list or tuple).
    :param prob: unused; dropout is disabled in this block.
    :param stride: 2-D stride for both convolutions (list or tuple).
    :param weight_dict: optional mapping of pretrained arrays, looked up under
        ``scope + 'con1_3X3_weights'``, ``scope + 'con1_3X3_biases'``,
        ``scope + 'conv2_3x3_weights'`` and ``scope + 'conv2_3x3_biases'``.
    :return: float32 tensor ``conv2(relu(conv1(input))) + input``.

    NOTE(review): the calls below use the keyword-style ``conv2d(input_, num_outputs,
    kernel_size=..., stride=..., pad=..., scope=..., weight_initializer=...)``. This file
    later redefines ``conv2d(input_, output_dim, k_h=..., name=...)`` under the same name,
    and with that definition these calls raise TypeError. TODO: give one of the two
    ``conv2d`` definitions a distinct name.
    """
    def _pretrained(suffix):
        # Looked up inline instead of via get_weight(): a later
        # get_weight(shape, gain, ...) in this file shadows the dict-lookup helper
        # and would treat the key string as a shape.
        if weight_dict is None:
            return None
        return weight_dict.get(scope + suffix)

    with tf.variable_scope(scope):
        net = tf.nn.relu(conv2d(input, out_channels, kernel_size=list(kernel), stride=list(stride), pad="SAME", scope="con1_3X3",
                                weight_initializer=_pretrained('con1_3X3_weights'),
                                bias_initializer=_pretrained('con1_3X3_biases'),
                                weight_initializer_type=tf.contrib.layers.xavier_initializer()))
        net = conv2d(net, out_channels, kernel_size=list(kernel), stride=list(stride), pad="SAME", scope="conv2_3x3",
                     weight_initializer=_pretrained('conv2_3x3_weights'),
                     bias_initializer=_pretrained('conv2_3x3_biases'),
                     weight_initializer_type=tf.contrib.layers.xavier_initializer())
    # Skip connection; both operands cast to float32 before adding.
    return tf.add(tf.cast(net, tf.float32), tf.cast(input, tf.float32))


def bias_variable(shape, bias_initializer=None, trainable=True):
    """
    Get the variable named 'biases' in the current variable scope via tf.get_variable.

    :param shape: bias shape; only used when ``bias_initializer`` is None, in which
        case the variable starts as zeros of this shape.
    :param bias_initializer: optional initial value passed straight to tf.get_variable.
    :param trainable: whether the variable is trainable.
    :return: the 'biases' variable.
    """
    if bias_initializer is not None:
        return tf.get_variable(name='biases', initializer=bias_initializer, trainable=trainable)
    zeros = tf.constant(0.0, shape=shape)
    return tf.get_variable(name='biases', initializer=zeros, trainable=trainable)


def _conv_init_vars(net, out_channels, filter_size, transpose=False):
    """
    Build a square convolution filter as a plain tf.Variable.

    The filter is drawn from a truncated normal with stddev 0.1 and seed 1.

    :param net: tensor whose static shape unpacks into exactly four dimensions;
        only the last (channel) dimension is used.
    :param out_channels: number of output channels.
    :param filter_size: spatial size of the square kernel.
    :param transpose: if True use layout [k, k, out_channels, in_channels]
        (transposed-conv layout), else [k, k, in_channels, out_channels].
    :return: the new float32 tf.Variable.
    """
    _, _, _, in_channels = [dim.value for dim in net.get_shape()]
    channel_dims = [out_channels, in_channels] if transpose else [in_channels, out_channels]
    initial = tf.truncated_normal([filter_size, filter_size] + channel_dims, stddev=0.1, seed=1)
    return tf.Variable(initial, dtype=tf.float32)

def conv2d(input_, num_outputs, kernel_size=(4, 4), stride=(1, 1), pad='SAME', if_bias=True, trainable=True, reuse=False,
           scope='conv2d', weight_initializer=None, bias_initializer=None,
           weight_initializer_type=tf.random_normal_initializer(stddev=0.02)):
    """
    2-D convolution with optional bias, with variables created under ``scope``.

    :param input_: 4-D input tensor; its last static dimension is the input channel count.
    :param num_outputs: number of output channels.
    :param kernel_size: (kh, kw) as a list or tuple.
    :param stride: (sh, sw) as a list or tuple; expanded to [1, sh, sw, 1].
    :param pad: padding string passed to tf.nn.conv2d.
    :param if_bias: add a 'biases' variable when True.
    :param trainable: trainable flag for the created variables.
    :param reuse: passed to tf.variable_scope.
    :param weight_initializer: optional initial value for 'weights' (its shape is used).
    :param bias_initializer: optional initial value for 'biases'.
    :param weight_initializer_type: initializer used when ``weight_initializer`` is None.
    :return: convolution output tensor.

    NOTE(review): a later ``conv2d(input_, output_dim, k_h=..., ...)`` in this file
    reuses this name; once it has executed, callers get that definition instead.
    """
    print(scope)
    # Normalise to lists so tuple arguments can be concatenated below.
    kernel_size = list(kernel_size)
    stride = list(stride)
    with tf.variable_scope(scope, reuse = reuse):
        if weight_initializer is None:
            print("Initializing weights")
            w = tf.get_variable(name='weights',
                                shape = kernel_size +  [input_.get_shape()[-1]] + [num_outputs],
                                initializer = weight_initializer_type,
                                dtype = tf.float32, trainable=trainable)
        else:
            print("Loading weights")
            w = tf.get_variable(name='weights',
                                initializer = weight_initializer,
                                dtype = tf.float32, trainable=trainable)

        conv = tf.nn.conv2d(input_, w,
                            padding = pad,
                            strides = [1] + stride + [1])

        if if_bias:
            if bias_initializer is None:
                print("Initializing biases")
                conv=conv+bias_variable([num_outputs], trainable=trainable)
            else:
                print("Loading biases")
                conv=conv+bias_variable([num_outputs], trainable=trainable, bias_initializer=bias_initializer)

        return conv


def conv2d_transpose(x, num_outputs, kernel_size = (4,4), stride= (1,1), pad='SAME', if_bias=True,
                     reuse=False, scope = "conv2d_transpose", trainable = True, weight_initializer=None,
                     bias_initializer=None, weight_initializer_type=tf.random_normal_initializer(stddev=0.02)):
    """
    2-D transposed convolution with optional bias, with variables created under ``scope``.

    :param x: 4-D input tensor; its last static dimension is the input channel count.
    :param num_outputs: number of output channels.
    :param kernel_size: (kh, kw) as a list or tuple.
    :param stride: (sh, sw) as a list or tuple.
    :param pad: padding passed to tf.nn.conv2d_transpose. The output size computed
        here (input size * stride) is the size 'SAME' padding produces.
    :param if_bias: add a 'biases' variable when True.
    :param reuse: passed to tf.variable_scope.
    :param trainable: trainable flag for the created variables.
    :param weight_initializer: optional initial value for 'weights'.
    :param bias_initializer: optional initial value for 'biases'.
    :param weight_initializer_type: initializer used when ``weight_initializer`` is None.
    :return: tensor of shape [batch, H * sh, W * sw, num_outputs].
    """
    print(scope)
    # The tuple defaults were previously concatenated directly with lists, which
    # raised TypeError; normalising to lists makes both tuples and lists work.
    kernel_size = list(kernel_size)
    stride = list(stride)
    with tf.variable_scope(scope, reuse = reuse):
        if weight_initializer is None:
            print("Initializing weights")
            # Transposed-conv filter layout: [kh, kw, output_channels, in_channels].
            w = tf.get_variable(name='weights',
                                shape = kernel_size + [num_outputs] + [x.get_shape()[-1]],
                                initializer=weight_initializer_type,
                                dtype = tf.float32, trainable=trainable)
        else:
            print("Loading weights")
            w = tf.get_variable(name='weights',
                                initializer=weight_initializer,
                                dtype = tf.float32, trainable=trainable)

        output_shape = [tf.shape(x)[0], tf.shape(x)[1] * stride[0], tf.shape(x)[2] * stride[1], num_outputs]

        conv_trans = tf.nn.conv2d_transpose(x, w,
                                            output_shape = output_shape,
                                            strides = [1] + stride + [1],
                                            padding = pad)

        if if_bias:
            if bias_initializer is None:
                print("Initializing biases")
                conv_trans = conv_trans + bias_variable([num_outputs], trainable=trainable)
            else:
                print("Load biases")
                conv_trans = conv_trans + bias_variable([num_outputs], trainable=trainable, bias_initializer=bias_initializer)

        return conv_trans

def conv2d_specnorm(input_, num_outputs, kernel_size=(4, 4), stride=(1, 1), pad='SAME', if_bias=True, trainable=True, reuse=False,
           scope='conv2d', weight_initializer=None, bias_initializer=None, u_weight=None,
           weight_initializer_type=tf.random_normal_initializer(stddev=0.02)):
    """
    2-D convolution whose kernel is passed through ``spectral_norm`` before use.

    :param input_: 4-D input tensor; its last static dimension is the input channel count.
    :param num_outputs: number of output channels.
    :param kernel_size: (kh, kw) as a list or tuple.
    :param stride: (sh, sw) as a list or tuple; expanded to [1, sh, sw, 1].
    :param pad: padding string passed to tf.nn.conv2d.
    :param if_bias: add a 'biases' variable when True.
    :param trainable: trainable flag for the created variables.
    :param reuse: passed to tf.variable_scope.
    :param weight_initializer: optional initial value for 'weights'.
    :param bias_initializer: optional initial value for 'biases'.
    :param u_weight: optional power-iteration vector forwarded to ``spectral_norm``.
    :param weight_initializer_type: initializer used when ``weight_initializer`` is None.
    :return: convolution output tensor.
    """
    print(scope)
    # Normalise to lists so tuple arguments can be concatenated below.
    kernel_size = list(kernel_size)
    stride = list(stride)
    with tf.variable_scope(scope, reuse = reuse):
        if weight_initializer is None:
            print("Initializing weights")
            w = tf.get_variable(name='weights',
                                shape = kernel_size +  [input_.get_shape()[-1]] + [num_outputs],
                                initializer = weight_initializer_type,
                                dtype = tf.float32, trainable=trainable)
        else:
            print("Loading weights")
            w = tf.get_variable(name='weights',
                                initializer = weight_initializer,
                                dtype = tf.float32, trainable=trainable)

        conv = tf.nn.conv2d(input_, spectral_norm(w, 1, u_weight=u_weight),
                            padding = pad,
                            strides = [1] + stride + [1])

        if if_bias:
            if bias_initializer is None:
                print("Initializing biases")
                conv=conv+bias_variable([num_outputs], trainable=trainable)
            else:
                print("Loading biases")
                conv=conv+bias_variable([num_outputs], trainable=trainable, bias_initializer=bias_initializer)

        return conv

def conv2d_transpose_specNorm(x, num_outputs, kernel_size = (4,4), stride= (1,1), pad='SAME', if_bias=True,
                     reuse=False, scope = "conv2d_transpose", trainable = True, weight_initializer=None,
                     bias_initializer=None, u_weight=None, weight_initializer_type=tf.random_normal_initializer(stddev=0.02)):
    """
    2-D transposed convolution whose kernel is passed through ``spectral_norm``.

    :param x: 4-D input tensor; its last static dimension is the input channel count.
    :param num_outputs: number of output channels.
    :param kernel_size: (kh, kw) as a list or tuple.
    :param stride: (sh, sw) as a list or tuple.
    :param pad: padding passed to tf.nn.conv2d_transpose. The output size computed
        here (input size * stride) is the size 'SAME' padding produces.
    :param if_bias: add a 'biases' variable when True.
    :param reuse: passed to tf.variable_scope.
    :param trainable: trainable flag for the created variables.
    :param weight_initializer: optional initial value for 'weights'.
    :param bias_initializer: optional initial value for 'biases'.
    :param u_weight: optional power-iteration vector forwarded to ``spectral_norm``.
    :param weight_initializer_type: initializer used when ``weight_initializer`` is None.
    :return: tensor of shape [batch, H * sh, W * sw, num_outputs].
    """
    print(scope)
    # The tuple defaults were previously concatenated directly with lists, which
    # raised TypeError; normalising to lists makes both tuples and lists work.
    kernel_size = list(kernel_size)
    stride = list(stride)
    with tf.variable_scope(scope, reuse = reuse):
        if weight_initializer is None:
            print("Initializing weights")
            # Transposed-conv filter layout: [kh, kw, output_channels, in_channels].
            w = tf.get_variable(name='weights',
                                shape = kernel_size + [num_outputs] + [x.get_shape()[-1]],
                                initializer=weight_initializer_type,
                                dtype = tf.float32, trainable=trainable)
        else:
            print("Loading weights")
            w = tf.get_variable(name='weights',
                                initializer=weight_initializer,
                                dtype = tf.float32, trainable=trainable)

        output_shape = [tf.shape(x)[0], tf.shape(x)[1] * stride[0], tf.shape(x)[2] * stride[1], num_outputs]

        conv_trans = tf.nn.conv2d_transpose(x, spectral_norm(w, 1, u_weight),
                                            output_shape = output_shape,
                                            strides = [1] + stride + [1],
                                            padding = pad)

        if if_bias:
            if bias_initializer is None:
                print("Initializing biases")
                conv_trans = conv_trans + bias_variable([num_outputs], trainable=trainable)
            else:
                print("Load biases")
                conv_trans = conv_trans + bias_variable([num_outputs], trainable=trainable, bias_initializer=bias_initializer)

        return conv_trans

def conv3d(input_, num_outputs, pad = "SAME", reuse = False, kernel_size = (4, 4, 4), stride = (2, 2, 2), if_bias= True,
           trainable = True, scope = "conv3d", weight_initializer=None, bias_initializer=None, weight_initializer_type = tf.random_normal_initializer(stddev=0.02)):
    """
    3-D convolution with optional bias, with variables created under ``scope``.

    :param input_: 5-D input tensor; its last static dimension is the input channel count.
    :param num_outputs: number of output channels.
    :param pad: padding string passed to tf.nn.conv3d.
    :param reuse: passed to tf.variable_scope.
    :param kernel_size: 3-element kernel size as a list or tuple.
    :param stride: 3-element stride as a list or tuple (default 2 in every dimension).
    :param if_bias: add a 'biases' variable when True.
    :param trainable: trainable flag for the created variables.
    :param weight_initializer: optional initial value for 'weights'.
    :param bias_initializer: optional initial value for 'biases'.
    :param weight_initializer_type: initializer used when ``weight_initializer`` is None.
    :return: convolution output tensor.

    NOTE(review): a later ``conv3d(input_, output_dim, k_h=..., ...)`` in this file
    reuses this name; once it has executed, callers get that definition instead.
    """
    print(scope)
    # Normalise to lists so tuple arguments can be concatenated below.
    kernel_size = list(kernel_size)
    stride = list(stride)
    with tf.variable_scope(scope, reuse = reuse):
        if weight_initializer is None:
            print("Initialise weight")
            w = tf.get_variable(name='weights',
                                trainable=trainable,
                                shape=kernel_size + [input_.get_shape()[-1]] + [num_outputs],
                                initializer=weight_initializer_type,
                                dtype=tf.float32)
        else:
            print("Loading weight")
            w = tf.get_variable(name='weights',
                                trainable=trainable,
                                initializer=weight_initializer,
                                dtype=tf.float32)

        conv = tf.nn.conv3d(input_, w,
                            padding = pad,
                            strides = [1] + stride + [1])

        if if_bias:
            if bias_initializer is None:
                print("Initialise bias")
                conv = conv + bias_variable([num_outputs], trainable=trainable)
            else:
                print("Loading bias")
                conv = conv + bias_variable([num_outputs], trainable=trainable, bias_initializer=bias_initializer)

        return conv



def conv3d_transpose(x, num_output, kernel_size = (4, 4, 4), stride= (1, 1, 1), pad='SAME', if_bias=True,
                     reuse=False, scope = "conv3d_transpose", trainable = True, weight_initializer=None,
                     bias_initializer=None, weight_initializer_type=tf.random_normal_initializer(stddev=0.02)):
    """
    3-D transposed convolution with optional bias, with variables created under ``scope``.

    :param x: 5-D input tensor; its last static dimension is the input channel count.
    :param num_output: number of output channels.
    :param kernel_size: 3-element kernel size as a list or tuple.
    :param stride: 3-element stride as a list or tuple.
    :param pad: padding passed to tf.nn.conv3d_transpose. The output size computed
        here (input size * stride) is the size 'SAME' padding produces.
    :param if_bias: add a 'biases' variable when True.
    :param reuse: passed to tf.variable_scope.
    :param trainable: trainable flag for the created variables.
    :param weight_initializer: optional initial value for 'weights'.
    :param bias_initializer: optional initial value for 'biases'.
    :param weight_initializer_type: initializer used when ``weight_initializer`` is None.
    :return: tensor of shape [batch, D1 * s0, D2 * s1, D3 * s2, num_output].
    """
    # The previous defaults (4,4)/(1,1) were 2-D for a 3-D op. Concatenating a tuple
    # with a list raised TypeError, and stride[2] would raise IndexError. The defaults
    # are now 3-D and normalised to lists, so tuple and list arguments both work.
    print(scope)
    kernel_size = list(kernel_size)
    stride = list(stride)
    with tf.variable_scope(scope, reuse = reuse):
        if weight_initializer is None:
            print("Initializing weights")
            # Transposed-conv filter layout: kernel dims + [output_channels, in_channels].
            w = tf.get_variable(name='weights',
                                shape = kernel_size + [num_output] + [x.get_shape().as_list()[-1]],
                                initializer=weight_initializer_type,
                                dtype = tf.float32, trainable=trainable)
        else:
            print("Loading weights")
            w = tf.get_variable(name='weights',
                                initializer=weight_initializer,
                                dtype = tf.float32, trainable=trainable)
        print("W " + str(w.get_shape()))
        output_shape = [tf.shape(x)[0], tf.shape(x)[1] * stride[0], tf.shape(x)[2] * stride[1], tf.shape(x)[3] * stride[2], num_output]

        conv_trans = tf.nn.conv3d_transpose(x, w,
                                            output_shape = output_shape,
                                            strides = [1] + stride + [1],
                                            padding = pad)

        if if_bias:
            if bias_initializer is None:
                print("Initializing biases")
                conv_trans = conv_trans + bias_variable([num_output], trainable=trainable)
            else:
                print("Load biases")
                conv_trans = conv_trans + bias_variable([num_output], trainable=trainable, bias_initializer=bias_initializer)

        return conv_trans

def conv3d_transpose_specNorm(x, num_output, kernel_size = (4, 4, 4), stride= (1, 1, 1), pad='SAME', if_bias=True,
                     reuse=False, scope = "conv3d_transpose", trainable = True, weight_initializer=None,
                     bias_initializer=None, u_weight=None, weight_initializer_type=tf.random_normal_initializer(stddev=0.02)):
    """
    3-D transposed convolution whose kernel is passed through ``spectral_norm``.

    :param x: 5-D input tensor; its last static dimension is the input channel count.
    :param num_output: number of output channels.
    :param kernel_size: 3-element kernel size as a list or tuple.
    :param stride: 3-element stride as a list or tuple.
    :param pad: padding passed to tf.nn.conv3d_transpose. The output size computed
        here (input size * stride) is the size 'SAME' padding produces.
    :param if_bias: add a 'biases' variable when True.
    :param reuse: passed to tf.variable_scope.
    :param trainable: trainable flag for the created variables.
    :param weight_initializer: optional initial value for 'weights'.
    :param bias_initializer: optional initial value for 'biases'.
    :param u_weight: optional power-iteration vector forwarded to ``spectral_norm``.
    :param weight_initializer_type: initializer used when ``weight_initializer`` is None.
    :return: tensor of shape [batch, D1 * s0, D2 * s1, D3 * s2, num_output].
    """
    # The previous defaults (4,4)/(1,1) were 2-D for a 3-D op. Concatenating a tuple
    # with a list raised TypeError, and stride[2] would raise IndexError. The defaults
    # are now 3-D and normalised to lists, so tuple and list arguments both work.
    print(scope)
    kernel_size = list(kernel_size)
    stride = list(stride)
    with tf.variable_scope(scope, reuse = reuse):
        if weight_initializer is None:
            print("Initializing weights")
            # Transposed-conv filter layout: kernel dims + [output_channels, in_channels].
            w = tf.get_variable(name='weights',
                                shape = kernel_size + [num_output] + [x.get_shape().as_list()[-1]],
                                initializer=weight_initializer_type,
                                dtype = tf.float32, trainable=trainable)
        else:
            print("Loading weights")
            w = tf.get_variable(name='weights',
                                initializer=weight_initializer,
                                dtype = tf.float32, trainable=trainable)
        print("W " + str(w.get_shape()))
        output_shape = [tf.shape(x)[0], tf.shape(x)[1] * stride[0], tf.shape(x)[2] * stride[1], tf.shape(x)[3] * stride[2], num_output]

        conv_trans = tf.nn.conv3d_transpose(x, spectral_norm(w, 1, u_weight),
                                            output_shape = output_shape,
                                            strides = [1] + stride + [1],
                                            padding = pad)

        if if_bias:
            if bias_initializer is None:
                print("Initializing biases")
                conv_trans = conv_trans + bias_variable([num_output], trainable=trainable)
            else:
                print("Load biases")
                conv_trans = conv_trans + bias_variable([num_output], trainable=trainable, bias_initializer=bias_initializer)

        return conv_trans

def fully_connected(input_, output_size, scope = 'fully_connected', if_bias=True,
                    weight_initializer=None, bias_initializer=None,  reuse = False, trainable = True,
                    weight_initializer_type=tf.random_normal_initializer(stddev=0.02)):
    """
    Dense layer ``input_ @ weights (+ biases)`` with variables created under ``scope``.

    :param input_: 2-D tensor or NumPy array; its last dimension is the input size.
    :param output_size: number of output units.
    :param scope: variable scope name.
    :param if_bias: add a 'biases' variable when True.
    :param weight_initializer: optional initial value for 'weights'.
    :param bias_initializer: optional initial value for 'biases'.
    :param reuse: passed to tf.variable_scope.
    :param trainable: trainable flag for the created variables.
    :param weight_initializer_type: initializer used when ``weight_initializer`` is None.
    :return: output tensor of the matmul (plus bias).
    """
    print(scope)
    # isinstance (rather than an exact type check) also accepts ndarray subclasses.
    if isinstance(input_, np.ndarray):
        shape = input_.shape
    else:
        shape = input_.get_shape().as_list()
    with tf.variable_scope(scope, reuse = reuse):
        if weight_initializer is None:
            print("Initializing weights")
            matrix = tf.get_variable("weights", [shape[-1], output_size], initializer=weight_initializer_type, dtype = tf.float32, trainable=trainable)
        else:
            print("Loading weights")
            matrix = tf.get_variable("weights", initializer=weight_initializer, dtype=tf.float32, trainable=trainable)

        fc = tf.matmul(input_, matrix)
        if if_bias:
            if bias_initializer is None:
                print("Initializing biases")
                fc = fc + bias_variable([output_size], trainable=trainable)
            else:
                print("Load biases")
                fc = fc + bias_variable([output_size], bias_initializer=bias_initializer, trainable=trainable)
        return fc

def save_txt_file(pred, name, SAVE_DIR):
    """
    Write every item of ``pred`` on its own line to ``<SAVE_DIR>/<name>.txt``.

    Items are rendered with ``str.format``; an existing file is overwritten.
    """
    target = os.path.join(SAVE_DIR, "{0}.txt".format(name))
    with open(target, 'w') as handle:
        handle.writelines("{0}\n".format(item) for item in pred)

def transform_tensor_to_image(tensor):
    """
    Swap axes 1 and 2 of a 4-D tensor, then reverse the new axis 1.

    A (N, A, B, C) input yields (N, B, A, C).
    """
    return tf.transpose(tensor, perm=[0, 2, 1, 3])[:, ::-1]

def transform_voxel_to_match_image(tensor):
    """
    Swap axes 1 and 2 of a 5-D tensor, then reverse the new axis 1.

    A (N, A, B, D, C) input yields (N, B, A, D, C).
    """
    return tf.transpose(tensor, perm=[0, 2, 1, 3, 4])[:, ::-1]

def transform_image_to_match_voxel(tensor):
    """
    Swap axes 1 and 2 of a 4-D tensor, then reverse the new axis 1.

    This is the same operation as ``transform_tensor_to_image``.
    """
    swapped = tf.transpose(tensor, perm=[0, 2, 1, 3])
    return swapped[:, ::-1]

def np_transform_tensor_to_image(tensor):
    """
    Swap axes 1 and 2 of a 4-D array with NumPy.

    NOTE(review): unlike the TensorFlow ``transform_tensor_to_image``, this does not
    reverse axis 1 afterwards; confirm whether that difference is intended.
    """
    return np.transpose(tensor, (0, 2, 1, 3))

# Resolve summary helpers across TensorFlow API versions: try the old top-level
# names first (tf.image_summary, tf.train.SummaryWriter, ...), falling back to the
# tf.summary.* names when they are absent.
try:
    image_summary = tf.image_summary
    scalar_summary = tf.scalar_summary
    histogram_summary = tf.histogram_summary
    merge_summary = tf.merge_summary
    SummaryWriter = tf.train.SummaryWriter
except AttributeError:
    # Only a missing attribute signals the newer API; other errors should surface.
    image_summary = tf.summary.image
    scalar_summary = tf.summary.scalar
    histogram_summary = tf.summary.histogram
    merge_summary = tf.summary.merge
    SummaryWriter = tf.summary.FileWriter

def sigmoid_cross_entropy_with_logits(x, y):
    """
    Element-wise sigmoid cross-entropy that works with both TF keyword conventions.

    :param x: logits tensor.
    :param y: target tensor.
    :return: per-element loss tensor.
    """
    try:
        return tf.nn.sigmoid_cross_entropy_with_logits(logits=x, labels=y)
    except TypeError:
        # Older TensorFlow names the target argument ``targets`` and rejects
        # ``labels`` with a TypeError. Other errors (e.g. shape mismatches) now
        # propagate instead of being masked by a second, misleading call.
        return tf.nn.sigmoid_cross_entropy_with_logits(logits=x, targets=y)

#===========================================================================================================
#Activation functions
#===========================================================================================================

def lrelu(x, leak=0.2, name="lrelu"):
    """
    Leaky ReLU computed as ``max(x, leak * x)``.

    The max form equals the usual leaky ReLU only for 0 <= leak <= 1.
    ``name`` is accepted for API compatibility but is not used.
    """
    leaked = leak * x
    return tf.maximum(x, leaked)

#===========================================================================================================
#Normalization
#===========================================================================================================
def AdaIn(features, scale, bias):
    """
    Adaptive instance normalization for 4-D or 5-D tensors.

    Mean and variance are taken over every axis except the first and the last
    (keep_dims=True). ``features`` is normalised with ``rsqrt(variance + 1e-8)``, then
    multiplied by ``scale`` and shifted by ``bias``.

    :features: tensor to normalise.
    :scale: per-instance, per-channel multiplier (the "style" sigma in style transfer);
        reshaped to the shape of the moments, so it must hold batch * channels values.
    :bias: per-instance, per-channel offset (the "style" mean), reshaped the same way.
    :return: the re-scaled, re-centred tensor.
    """
    ndims = len(features.get_shape())
    spatial_axes = list(range(1, ndims - 1))
    mean, variance = tf.nn.moments(features, spatial_axes, keep_dims=True)
    inv_std = tf.rsqrt(variance + 1e-8)
    target_shape = tf.shape(mean)
    gamma = tf.reshape(scale, target_shape)
    beta = tf.reshape(bias, target_shape)
    return gamma * ((features - mean) * inv_std) + beta

def instance_norm(input, name="instance_norm", return_mean=False):
    """
    Instance normalization over axes 1 and 2 with a learned per-channel scale/offset.

    Adapted from https://github.com/xhujoy/CycleGAN-tensorflow/blob/master/module.py

    :param input: 4-D tensor; axis 3 is treated as channels.
    :param name: variable scope holding the 'scale' (init ~N(1, 0.02)) and 'offset'
        (init 0) variables.
    :param return_mean: also return the computed mean and variance when True.
    :return: normalised tensor, or ``(normalised, mean, variance)``.
    """
    with tf.variable_scope(name):
        channels = input.get_shape()[3]
        gamma = tf.get_variable("scale", [channels],
                                initializer=tf.random_normal_initializer(1.0, 0.02, dtype=tf.float32))
        beta = tf.get_variable("offset", [channels], initializer=tf.constant_initializer(0.0))
        mean, variance = tf.nn.moments(input, axes=[1, 2], keep_dims=True)
        normalized = (input - mean) * tf.rsqrt(variance + 1e-5)
        out = gamma * normalized + beta
        return (out, mean, variance) if return_mean else out

def l2_norm(v, eps=1e-12):
    """
    Divide ``v`` by its global L2 norm (summed over all elements).

    ``eps`` is added to the norm to avoid division by zero.
    """
    norm = tf.reduce_sum(v ** 2) ** 0.5
    return v / (norm + eps)

def spectral_norm(w, iteration=1, u_weight=None):
    """
    Divide ``w`` by a power-iteration estimate of its largest singular value.

    ``w`` is flattened to a matrix of shape (-1, w.shape[-1]) for the estimate. The
    updated right vector is written back into ``u`` via ``u.assign``, under a control
    dependency of the returned tensor.

    :param w: weight tensor with a fully known static shape.
    :param iteration: number of power-iteration steps; must be >= 1 (1 is usually enough).
    :param u_weight: optional existing vector to use instead of creating a
        non-trainable variable named "u" in the current scope. It must support
        ``.assign``.
    :return: the normalised weights, reshaped back to ``w``'s static shape.
    :raises ValueError: if ``iteration`` < 1 (no estimate would be computed).
    """
    if iteration < 1:
        raise ValueError("spectral_norm needs at least one power iteration, got %r" % (iteration,))

    w_shape = w.shape.as_list()
    w_mat = tf.reshape(w, [-1, w_shape[-1]])
    if u_weight is None:
        u = tf.get_variable("u", [1, w_shape[-1]], initializer=tf.truncated_normal_initializer(), trainable=False)
    else:
        u = u_weight

    u_hat = u
    v_hat = None
    for _ in range(iteration):
        # One power-iteration step: alternate multiplying by w^T and w, renormalising.
        v_ = tf.matmul(u_hat, tf.transpose(w_mat))
        v_hat = l2_norm(v_)

        u_ = tf.matmul(v_hat, w_mat)
        u_hat = l2_norm(u_)

    sigma = tf.matmul(tf.matmul(v_hat, w_mat), tf.transpose(u_hat))
    w_norm = w_mat / sigma

    # Persist the estimate of u whenever the normalised weights are evaluated.
    with tf.control_dependencies([u.assign(u_hat)]):
        w_norm = tf.reshape(w_norm, w_shape)

    return w_norm


#===========================================================================================================
#Convolutions
#===========================================================================================================
def conv_out_size_same(size, stride):
    """
    Return the output length of a 'SAME'-padded strided convolution: ceil(size / stride).

    The division is done in floating point and the result is returned as an int.
    """
    ratio = float(size) / float(stride)
    return int(math.ceil(ratio))


def get_weight(shape, gain=np.sqrt(2), use_wscale=False, fan_in=None):
    """
    Create a 'weight' variable with He-style scaling, std = gain / sqrt(fan_in).

    :param shape: variable shape; fan_in defaults to the product of all but the last dim.
    :param gain: numerator of the std (sqrt(2) by default).
    :param use_wscale: if True, initialise from a standard normal and multiply the
        variable by a constant 'wscale' equal to std at run time; otherwise initialise
        directly from N(0, std).
    :param fan_in: optional explicit fan-in.
    :return: the variable, or the variable times 'wscale' when ``use_wscale`` is True.

    NOTE(review): this reuses the name of the earlier ``get_weight(weight_name,
    weight_dict)`` helper in this file and replaces it once executed.
    """
    if fan_in is None:
        fan_in = np.prod(shape[:-1])
    print ("current", shape[:-1], fan_in)
    std = gain / np.sqrt(fan_in)  # He init

    if not use_wscale:
        return tf.get_variable('weight', shape=shape, initializer=tf.initializers.random_normal(0, std))
    wscale = tf.constant(np.float32(std), name='wscale')
    return tf.get_variable('weight', shape=shape, initializer=tf.initializers.random_normal()) * wscale


def conv2d(input_, output_dim,
           k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02,
           name="conv2d", padding='SAME'):
    """
    2-D convolution with truncated-normal 'weights' and zero-initialised 'biases'.

    :param input_: 4-D input tensor; its last static dim is the input channel count.
    :param output_dim: number of output channels.
    :param k_h, k_w: kernel height/width.
    :param d_h, d_w: strides along axes 1 and 2.
    :param stddev: stddev of the weight initializer.
    :param name: variable scope name.
    :param padding: padding string for tf.nn.conv2d.
    :return: convolution output plus bias.

    NOTE(review): this replaces the earlier keyword-style ``conv2d`` in this file.
    """
    with tf.variable_scope(name):
        in_channels = input_.get_shape()[-1]
        kernel = tf.get_variable('weights', [k_h, k_w, in_channels, output_dim],
                                 initializer=tf.truncated_normal_initializer(stddev=stddev))
        out = tf.nn.conv2d(input_, kernel, strides=[1, d_h, d_w, 1], padding=padding)

        biases = tf.get_variable('biases', [output_dim], initializer=tf.constant_initializer(0.0))
        # Reshaping back to the pre-bias dynamic shape does not change values.
        return tf.reshape(tf.nn.bias_add(out, biases), tf.shape(out))


def conv2d_specNorm(input_, output_dim,
                    k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02,
                    name="conv2dSpectral", padding='SAME'):
    """
    2-D convolution whose kernel is passed through ``spectral_norm`` before use.

    Parameters match the DCGAN-style ``conv2d``: kernel (k_h, k_w), strides (d_h, d_w),
    truncated-normal weights with ``stddev``, zero-initialised biases, and variables
    under scope ``name``.
    """
    with tf.variable_scope(name):
        in_channels = input_.get_shape()[-1]
        kernel = tf.get_variable('weights', [k_h, k_w, in_channels, output_dim],
                                 initializer=tf.truncated_normal_initializer(stddev=stddev))
        out = tf.nn.conv2d(input_, spectral_norm(kernel), strides=[1, d_h, d_w, 1], padding=padding)

        biases = tf.get_variable('biases', [output_dim], initializer=tf.constant_initializer(0.0))
        return tf.reshape(tf.nn.bias_add(out, biases), tf.shape(out))

def conv3d(input_, output_dim,
           k_h=5, k_w=5, k_d=5, d_h=2, d_w=2, d_d=2, stddev=0.02,
           name="conv3d", padding='SAME'):
    """
    3-D convolution with truncated-normal 'weights' and zero-initialised 'biases'.

    :param input_: 5-D input tensor; its last static dim is the input channel count.
    :param output_dim: number of output channels.
    :param k_h, k_w, k_d: kernel sizes along axes 1, 2, 3.
    :param d_h, d_w, d_d: strides along axes 1, 2, 3.
    :param stddev: stddev of the weight initializer.
    :param name: variable scope name.
    :param padding: padding string for tf.nn.conv3d.
    :return: convolution output plus bias.

    NOTE(review): this replaces the earlier keyword-style ``conv3d`` in this file.
    """
    with tf.variable_scope(name):
        in_channels = input_.get_shape()[-1]
        kernel = tf.get_variable('weights', [k_h, k_w, k_d, in_channels, output_dim],
                                 initializer=tf.truncated_normal_initializer(stddev=stddev))
        out = tf.nn.conv3d(input_, kernel, strides=[1, d_h, d_w, d_d, 1], padding=padding)

        biases = tf.get_variable('biases', [output_dim], initializer=tf.constant_initializer(0.0))
        return tf.reshape(tf.nn.bias_add(out, biases), tf.shape(out))

def conv3d_specNorm(input_, output_dim,
                    k_h=5, k_w=5, k_d=5, d_h=2, d_w=2, d_d=2, stddev=0.02,
                    name="conv3dSpectral", padding='SAME'):
    """
    3-D convolution whose kernel is passed through ``spectral_norm`` before use.

    Parameters match the DCGAN-style ``conv3d``: kernel (k_h, k_w, k_d), strides
    (d_h, d_w, d_d), truncated-normal weights with ``stddev``, zero-initialised biases,
    and variables under scope ``name``.
    """
    with tf.variable_scope(name):
        in_channels = input_.get_shape()[-1]
        kernel = tf.get_variable('weights', [k_h, k_w, k_d, in_channels, output_dim],
                                 initializer=tf.truncated_normal_initializer(stddev=stddev))
        out = tf.nn.conv3d(input_, spectral_norm(kernel), strides=[1, d_h, d_w, d_d, 1], padding=padding)

        biases = tf.get_variable('biases', [output_dim], initializer=tf.constant_initializer(0.0))
        return tf.reshape(tf.nn.bias_add(out, biases), tf.shape(out))

def deconv2d(input_, output_shape,
             k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02,
             name="deconv2d", with_w=False):
    """
    2-D transposed convolution with normal-initialised 'weights' and zero 'biases'.

    :param input_: 4-D input tensor.
    :param output_shape: full output shape passed to tf.nn.conv2d_transpose; its last
        entry is the number of output channels.
    :param k_h, k_w: kernel height/width.
    :param d_h, d_w: strides along axes 1 and 2.
    :param stddev: stddev of the weight initializer.
    :param name: variable scope name.
    :param with_w: also return the weight and bias variables when True.
    :return: output tensor, or ``(output, weights, biases)``.
    """
    with tf.variable_scope(name):
        out_channels = output_shape[-1]
        # conv2d_transpose filter layout: [height, width, output_channels, in_channels].
        w = tf.get_variable('weights', [k_h, k_w, out_channels, input_.get_shape()[-1]],
                            initializer=tf.random_normal_initializer(stddev=stddev))
        raw = tf.nn.conv2d_transpose(input_, w, output_shape=output_shape, strides=[1, d_h, d_w, 1])

        biases = tf.get_variable('biases', [out_channels], initializer=tf.constant_initializer(0.0))
        deconv = tf.reshape(tf.nn.bias_add(raw, biases), tf.shape(raw))
        return (deconv, w, biases) if with_w else deconv


def deconv2d_specNorm(input_, output_shape,
                      k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02,
                      name="deconv2d", with_w=False):
    """
    2-D transposed convolution whose kernel is passed through ``spectral_norm``.

    Parameters match ``deconv2d``. With ``with_w`` the returned weights are the raw
    (un-normalised) variable.
    """
    with tf.variable_scope(name):
        out_channels = output_shape[-1]
        # conv2d_transpose filter layout: [height, width, output_channels, in_channels].
        w = tf.get_variable('weights', [k_h, k_w, out_channels, input_.get_shape()[-1]],
                            initializer=tf.random_normal_initializer(stddev=stddev))
        raw = tf.nn.conv2d_transpose(input_, spectral_norm(w), output_shape=output_shape,
                                     strides=[1, d_h, d_w, 1])

        biases = tf.get_variable('biases', [out_channels], initializer=tf.constant_initializer(0.0))
        deconv = tf.reshape(tf.nn.bias_add(raw, biases), tf.shape(raw))
        return (deconv, w, biases) if with_w else deconv


def deconv3d(input_, output_shape,
             k_h=5, k_w=5, k_d=5, d_h=2, d_w=2, d_d=2, stddev=0.02,
             name="deconv3d", with_w=False):
    """
    3-D transposed convolution with normal-initialised 'weights' and zero 'biases'.

    :param input_: 5-D input tensor.
    :param output_shape: full output shape passed to tf.nn.conv3d_transpose; its last
        entry is the number of output channels.
    :param k_h, k_w, k_d: kernel sizes along axes 1, 2, 3.
    :param d_h, d_w, d_d: strides along axes 1, 2, 3.
    :param stddev: stddev of the weight initializer.
    :param name: variable scope name.
    :param with_w: also return the weight and bias variables when True.
    :return: output tensor, or ``(output, weights, biases)``.
    """
    with tf.variable_scope(name):
        out_channels = output_shape[-1]
        # conv3d_transpose filter layout: [k_h, k_w, k_d, output_channels, in_channels].
        w = tf.get_variable('weights', [k_h, k_w, k_d, out_channels, input_.get_shape()[-1]],
                            initializer=tf.random_normal_initializer(stddev=stddev))
        raw = tf.nn.conv3d_transpose(input_, w, output_shape=output_shape,
                                     strides=[1, d_h, d_w, d_d, 1])

        biases = tf.get_variable('biases', [out_channels], initializer=tf.constant_initializer(0.0))
        deconv = tf.reshape(tf.nn.bias_add(raw, biases), tf.shape(raw))
        return (deconv, w, biases) if with_w else deconv

def deconv3d_specNorm(input_, output_shape,
                      k_h=5, k_w=5, k_d=5, d_h=2, d_w=2, d_d=2, stddev=0.02,
                      name="deconv3dSpectral", with_w=False):
    """
    3-D transposed convolution ('SAME' padding) whose kernel is spectrally normalised.

    Unlike ``deconv3d``, the weights use a truncated-normal initializer. With
    ``with_w`` the returned weights are the raw (un-normalised) variable.
    """
    with tf.variable_scope(name):
        out_channels = output_shape[-1]
        w = tf.get_variable('weights', [k_h, k_w, k_d, out_channels, input_.get_shape()[-1]],
                            initializer=tf.truncated_normal_initializer(stddev=stddev))
        raw = tf.nn.conv3d_transpose(input_, spectral_norm(w), output_shape=output_shape,
                                     strides=[1, d_h, d_w, d_d, 1], padding='SAME')

        biases = tf.get_variable('biases', [out_channels], initializer=tf.constant_initializer(0.0))
        deconv = tf.reshape(tf.nn.bias_add(raw, biases), tf.shape(raw))
        return (deconv, w, biases) if with_w else deconv

def linear(input_, output_size, scope=None, stddev=0.02, bias_start=0.0, with_w=False):
    """
    Dense layer ``input_ @ weights + biases`` under scope ``scope`` (default "Linear").

    :param input_: 2-D tensor; its static dim 1 is the input size.
    :param output_size: number of output units.
    :param stddev: stddev of the normal weight initializer.
    :param bias_start: constant initial bias value.
    :param with_w: also return the weight and bias variables when True.
    :return: output tensor, or ``(output, weights, biases)``.
    """
    dims = input_.get_shape().as_list()
    with tf.variable_scope(scope or "Linear"):
        matrix = tf.get_variable("weights", [dims[1], output_size], tf.float32,
                                 tf.random_normal_initializer(stddev=stddev))
        bias = tf.get_variable("biases", [output_size],
                               initializer=tf.constant_initializer(bias_start))
        out = tf.matmul(input_, matrix) + bias
        return (out, matrix, bias) if with_w else out

def linear_specNorm(input_, output_size, scope=None, stddev=0.02, bias_start=0.0, with_w=False):
    """
    Dense layer whose weight matrix is passed through ``spectral_norm``.

    Parameters match ``linear``. With ``with_w`` the returned matrix is the
    spectrally-normalised tensor, not the raw variable.
    """
    dims = input_.get_shape().as_list()
    with tf.variable_scope(scope or "Linear"):
        raw_matrix = tf.get_variable("weights", [dims[1], output_size], tf.float32,
                                     tf.random_normal_initializer(stddev=stddev))
        matrix = spectral_norm(raw_matrix)
        bias = tf.get_variable("biases", [output_size],
                               initializer=tf.constant_initializer(bias_start))
        out = tf.matmul(input_, matrix) + bias
        return (out, matrix, bias) if with_w else out


def flatten(x):
    """Collapse every dimension except the first into one, via tf.layers.flatten."""
    flat = tf.layers.flatten(x)
    return flat


def tf_repeat(x, n_repeats):
    """
    Repeat each element of ``x`` ``n_repeats`` times consecutively; return a 1-D tensor.

    For example, [a, b] with n_repeats=2 gives [a, a, b, b]. The repeat is computed
    as an outer product with a row of int32 ones, so ``x`` must be int32.
    """
    ones_row = tf.ones(shape=[1, n_repeats], dtype='int32')
    column = tf.reshape(x, (-1, 1))
    return tf.reshape(tf.matmul(column, ones_row), [-1])

def tf_interpolate(voxel, x, y, z, out_size):
    """
    Trilinear interpolation for a batch of voxel grids.

    :param voxel: 5-D grid tensor; axes 1-4 are read as height, width, depth, channels.
    :param x, y, z: sample coordinates (cast to float32). They are added element-wise
        to ``base`` below, so they are expected to be flat 1-D tensors with
        batch_size * out_height * out_width * out_depth entries.
    :param out_size: output size; only entries 1, 2 and 3 (out height, width, depth)
        are used, to size the per-batch index offsets.
    :return: tensor of shape [len(x), n_channels] with the interpolated values.

    NOTE(review): the flat index below is ``z * width * height + y * width + x``, and
    x, y, z are clipped to width-1, height-1, depth-1. This matches the row-major
    layout of ``voxel`` only when height == width == depth. Presumably the grids used
    here are cubic; confirm before using non-cubic grids.
    NOTE(review): the weights are computed from the clipped corner coordinates, so a
    sample exactly on the upper edge (e.g. x == width - 1) gets all-zero weights.
    """
    batch_size = tf.shape(voxel)[0]
    height = tf.shape(voxel)[1]
    width = tf.shape(voxel)[2]
    depth = tf.shape(voxel)[3]
    n_channels = tf.shape(voxel)[4]

    x = tf.cast(x, 'float32')
    y = tf.cast(y, 'float32')
    z = tf.cast(z, 'float32')

    out_height = out_size[1]
    out_width = out_size[2]
    out_depth = out_size[3]
    out_channel = out_size[4]  # unused

    zero = tf.zeros([], dtype='int32')
    max_y = tf.cast(height - 1, 'int32')
    max_x = tf.cast(width - 1, 'int32')
    max_z = tf.cast(depth - 1, 'int32')

    # do sampling: integer corners of the cell containing each sample point
    x0 = tf.cast(tf.floor(x), 'int32')
    x1 = x0 + 1
    y0 = tf.cast(tf.floor(y), 'int32')
    y1 = y0 + 1
    z0 = tf.cast(tf.floor(z), 'int32')
    z1 = z0 + 1

    # Keep every corner inside the grid so the gathers below stay in range.
    x0 = tf.clip_by_value(x0, zero, max_x)
    x1 = tf.clip_by_value(x1, zero, max_x)
    y0 = tf.clip_by_value(y0, zero, max_y)
    y1 = tf.clip_by_value(y1, zero, max_y)
    z0 = tf.clip_by_value(z0, zero, max_z)
    z1 = tf.clip_by_value(z1, zero, max_z)

    # 1-D tensor of base indices: the flat offset of the first voxel of each batch item.
    # tf.range(batch_size) * width * height * depth: one offset per batch item.
    # out_height * out_width * out_depth: each offset is repeated this many times, once per output sample.
    base = tf_repeat(tf.range(batch_size) * width * height * depth, out_height * out_width * out_depth)

    # Add the Z contribution of each index

    base_z0 = base + z0 * width * height
    base_z1 = base + z1 * width * height
    # Add the Y contribution on top of Z
    base_z0_y0 = base_z0 + y0 * width
    base_z0_y1 = base_z0 + y1 * width
    base_z1_y0 = base_z1 + y0 * width
    base_z1_y1 = base_z1 + y1 * width

    # Add the X contribution: the four corners on the z0 slice
    idx_a = base_z0_y0 + x0
    idx_b = base_z0_y1 + x0
    idx_c = base_z0_y0 + x1
    idx_d = base_z0_y1 + x1
    # ... and the four corners on the z1 slice
    idx_e = base_z1_y0 + x0
    idx_f = base_z1_y1 + x0
    idx_g = base_z1_y0 + x1
    idx_h = base_z1_y1 + x1

    # use indices to lookup voxels in the flattened grid; each row keeps all channels
    voxel_flat = tf.reshape(voxel, [-1, n_channels])
    voxel_flat = tf.cast(voxel_flat, 'float32')
    Ia = tf.gather(voxel_flat, idx_a)
    Ib = tf.gather(voxel_flat, idx_b)
    Ic = tf.gather(voxel_flat, idx_c)
    Id = tf.gather(voxel_flat, idx_d)
    Ie = tf.gather(voxel_flat, idx_e)
    If = tf.gather(voxel_flat, idx_f)
    Ig = tf.gather(voxel_flat, idx_g)
    Ih = tf.gather(voxel_flat, idx_h)

    # and finally calculate interpolated values
    x0_f = tf.cast(x0, 'float32')
    x1_f = tf.cast(x1, 'float32')
    y0_f = tf.cast(y0, 'float32')
    y1_f = tf.cast(y1, 'float32')
    z0_f = tf.cast(z0, 'float32')
    z1_f = tf.cast(z1, 'float32')

    # Corner weights on the z0 slice; each is the volume of the opposite sub-cell.
    # expand_dims makes them [N, 1] so they broadcast over channels.
    wa = tf.expand_dims(((x1_f - x) * (y1_f - y) * (z1_f-z)), 1)
    wb = tf.expand_dims(((x1_f - x) * (y - y0_f) * (z1_f-z)), 1)
    wc = tf.expand_dims(((x - x0_f) * (y1_f - y) * (z1_f-z)), 1)
    wd = tf.expand_dims(((x - x0_f) * (y - y0_f) * (z1_f-z)), 1)
    # Corner weights on the z1 slice
    we = tf.expand_dims(((x1_f - x) * (y1_f - y) * (z-z0_f)), 1)
    wf = tf.expand_dims(((x1_f - x) * (y - y0_f) * (z-z0_f)), 1)
    wg = tf.expand_dims(((x - x0_f) * (y1_f - y) * (z-z0_f)), 1)
    wh = tf.expand_dims(((x - x0_f) * (y - y0_f) * (z-z0_f)), 1)


    output = tf.add_n([wa * Ia, wb * Ib, wc * Ic, wd * Id,  we * Ie, wf * If, wg * Ig, wh * Ih])
    return output

def tf_voxel_meshgrid(height, width, depth, homogeneous = False):
    """Return the integer voxel coordinates of a (depth, height, width) grid.

    The result is a float32 tensor with one column per voxel: rows are x, y, z
    (shape (3, depth*height*width)), plus a fourth row of ones when
    `homogeneous` is True. Columns are ordered with x varying fastest, then y,
    then z.
    """
    with tf.variable_scope('voxel_meshgrid'):
        # With 'ij' indexing the first output follows the first range (depth -> z)
        # and the last follows the last range (width -> x), hence the z, y, x order.
        ranges = [tf.range(n, dtype = tf.float32) for n in (depth, height, width)]
        z_t, y_t, x_t = tf.meshgrid(*ranges, indexing='ij')
        # Flatten each coordinate volume into a single (1, N) row.
        coord_rows = [tf.reshape(axis_t, (1, -1)) for axis_t in (x_t, y_t, z_t)]
        if homogeneous:
            coord_rows.append(tf.ones_like(coord_rows[0]))
        return tf.concat(coord_rows, axis=0)

def tf_rotation_around_grid_centroid(view_params, shapenet_viewer = False):
    """Build batched homogeneous rotation (and scale) matrices from view parameters.

    The rotation is R = R_z(elevation) @ R_y(azimuth): the points are first
    rotated by the azimuth about the y axis (the up vector), then by the
    elevation about the z axis. Unless `shapenet_viewer` is True, pi/2 is
    subtracted from the azimuth first.

    :param view_params: tensor of shape (batch, 2) or (batch, >=3). Column 0 is
        the azimuth and column 1 the elevation (passed straight to tf.cos /
        tf.sin, so in radians); column 2, when present, is a uniform scale.
    :param shapenet_viewer: skip the -pi/2 azimuth offset when True.
    :return: the (batch, 4, 4) rotation matrices if `view_params` statically has
        exactly 2 columns; otherwise a tuple (rotation, scale) of (batch, 4, 4)
        homogeneous matrices.
    """
    view_params = tf.convert_to_tensor(view_params)
    batch_size = tf.shape(view_params)[0]

    azimuth    = tf.reshape(view_params[:, 0], (batch_size, 1, 1))
    elevation  = tf.reshape(view_params[:, 1], (batch_size, 1, 1))

    if shapenet_viewer == False:
        azimuth = (azimuth - tf.constant(math.pi * 0.5))

    #========================================================
    #Because tensorflow does not allow tensor item replacement
    #A new matrix needs to be created from scratch by concatenating different vectors into rows and stacking them up
    #Batch Rotation Y matrixes
    ones = tf.ones_like(azimuth)
    zeros = tf.zeros_like(azimuth)
    batch_Rot_Y = tf.concat([
        tf.concat([tf.cos(azimuth),  zeros, -tf.sin(azimuth), zeros], axis=2),
        tf.concat([zeros, ones,  zeros,zeros], axis=2),
        tf.concat([tf.sin(azimuth),  zeros, tf.cos(azimuth), zeros], axis=2),
        tf.concat([zeros, zeros, zeros, ones], axis=2)], axis=1)

    #Batch Rotation Z matrixes (elevation)
    batch_Rot_Z = tf.concat([
        tf.concat([tf.cos(elevation),  tf.sin(elevation),  zeros, zeros], axis=2),
        tf.concat([-tf.sin(elevation), tf.cos(elevation),  zeros, zeros], axis=2),
        tf.concat([zeros, zeros, ones,  zeros], axis=2),
        tf.concat([zeros, zeros, zeros, ones], axis=2)], axis=1)

    transformation_matrix = tf.matmul(batch_Rot_Z, batch_Rot_Y)

    # The column count must be checked on the STATIC shape: comparing the
    # tensor tf.shape(...)[1] with a Python int never evaluates to True in
    # graph mode, so the rotation-only branch was unreachable.
    static_shape = view_params.get_shape()
    n_cols = static_shape.as_list()[1] if static_shape.ndims is not None else None
    if n_cols == 2:
        return transformation_matrix

    #Batch Scale matrixes:
    scale = tf.reshape(view_params[:, 2], (batch_size, 1, 1))
    batch_Scale= tf.concat([
        tf.concat([scale,  zeros,  zeros, zeros], axis=2),
        tf.concat([zeros, scale,  zeros, zeros], axis=2),
        tf.concat([zeros, zeros,  scale,  zeros], axis=2),
        tf.concat([zeros, zeros,  zeros, ones], axis=2)], axis=1)
    return transformation_matrix, batch_Scale

def tf_rotation_resampling(voxel_array, transformation_matrix, params, Scale_matrix = None, size=64, new_size=128):
    """
    Batch transformation and resampling function.

    For each batch element builds the homogeneous matrix
        M = T_new_inv @ T_translate @ Scale_matrix @ transformation_matrix @ T
    where T moves the centre of the `size`^3 input grid to the origin and
    T_new_inv moves the origin to the centre of the bigger `new_size`^3 output
    grid. Every output voxel coordinate is mapped through M^-1 and the input is
    sampled there with `tf_interpolate`.

    :param voxel_array: batch of voxels. Shape = [batch_size, height, width, depth, features];
        the features dimension must be known statically.
    :param transformation_matrix: rotation matrices. Shape = [batch_size, 4, 4]
    :param params: view parameters. Shape = [batch_size, >=6]; columns 3, 4 and 5
        are the x, y and z translations.
    :param Scale_matrix: optional scale matrices. Shape = [batch_size, 4, 4];
        left out of the product (i.e. identity) when None.
    :param size: original size of the voxel array
    :param new_size: size of the resampled array
    :return: tuple (target, grid_transform): the resampled voxels,
        Shape = [batch_size, new_size, new_size, new_size, features], and the
        sampling coordinates, Shape = [batch_size, 3, new_size**3].
        None is returned if tf.errors.InvalidArgumentError is raised while
        building the resampling ops.
    """
    batch_size = tf.shape(voxel_array)[0]
    n_channels = voxel_array.get_shape()[4].value
    #Aligning the centroid of the object (voxel grid) to origin for rotation,
    #then move the centroid back to the original position of the grid centroid
    T = tf.constant([[1,0,0, -size * 0.5],
                  [0,1,0, -size * 0.5],
                  [0,0,1, -size * 0.5],
                  [0,0,0,1]])
    T = tf.tile(tf.reshape(T, (1, 4, 4)), [batch_size, 1, 1])

    # However, since the rotated grid might be out of bound for the original grid size,
    # move the rotated grid to a new bigger grid
    T_new_inv = tf.constant([[1, 0, 0, new_size * 0.5],
                             [0, 1, 0, new_size * 0.5],
                             [0, 0, 1, new_size * 0.5],
                             [0, 0, 0, 1]])
    T_new_inv = tf.tile(tf.reshape(T_new_inv, (1, 4, 4)), [batch_size, 1, 1])

    # Add the actual shifting in x, y and z dimension according to input param
    x_shift = tf.reshape(params[:, 3], (batch_size, 1, 1))
    y_shift = tf.reshape(params[:, 4], (batch_size, 1, 1))
    z_shift = tf.reshape(params[:, 5], (batch_size, 1, 1))
    # Because tensorflow does not allow tensor item replacement, the batch
    # translation matrix is assembled by concatenating rows.
    ones = tf.ones_like(x_shift)
    zeros = tf.zeros_like(x_shift)

    T_translate = tf.concat([
        tf.concat([ones, zeros, zeros, x_shift], axis=2),
        tf.concat([zeros, ones, zeros, y_shift], axis=2),
        tf.concat([zeros, zeros, ones, z_shift], axis=2),
        tf.concat([zeros, zeros, zeros, ones], axis=2)], axis=1)

    # Same left-to-right product as before; a missing scale matrix is the identity.
    total_M = tf.matmul(T_new_inv, T_translate)
    if Scale_matrix is not None:
        total_M = tf.matmul(total_M, Scale_matrix)
    total_M = tf.matmul(tf.matmul(total_M, transformation_matrix), T)

    try:
        # Inverse mapping: for each output voxel find where to sample the input.
        total_M = tf.matrix_inverse(total_M)

        total_M = total_M[:, 0:3, :] #Ignore the homogenous coordinate so the results are 3D vectors
        grid = tf_voxel_meshgrid(new_size, new_size, new_size, homogeneous=True)
        grid = tf.tile(tf.expand_dims(grid, 0), [batch_size, 1, 1])
        grid_transform = tf.matmul(total_M, grid)
        x_s_flat = tf.reshape(grid_transform[:, 0, :], [-1])
        y_s_flat = tf.reshape(grid_transform[:, 1, :], [-1])
        z_s_flat = tf.reshape(grid_transform[:, 2, :], [-1])
        input_transformed = tf_interpolate(voxel_array, x_s_flat, y_s_flat, z_s_flat,[batch_size, new_size, new_size, new_size, n_channels])
        target = tf.reshape(input_transformed, [batch_size, new_size, new_size, new_size, n_channels])

        return target, grid_transform
    # The error classes live under tf.errors; `tf.InvalidArgumentError` is not
    # a TF1 top-level symbol, so the old clause failed with AttributeError.
    except tf.errors.InvalidArgumentError:
        return None

def tf_3D_transform(voxel_array, view_params, size=64, new_size=128, shapenet_viewer=False):
    """
    Rotate, scale and translate a batch of voxel grids according to view parameters.

    :param voxel_array: batch of voxels. Shape = [batch_size, height, width, depth, features]
    :param view_params: view parameters. Shape = [batch_size, >=6]; columns 0-2
        (azimuth, elevation, scale) build the rotation/scale matrices and
        columns 3-5 are the x, y, z translations.
    :param size: original size of the voxel array
    :param new_size: size of the resampled array
    :param shapenet_viewer: forwarded to tf_rotation_around_grid_centroid
    :return: transformed voxel array of size new_size along each spatial axis
    """
    rotation, scale = tf_rotation_around_grid_centroid(
        view_params[:, :3], shapenet_viewer=shapenet_viewer)
    resampled, _sample_grid = tf_rotation_resampling(
        voxel_array, rotation, params=view_params, Scale_matrix=scale,
        size=size, new_size=new_size)
    return resampled

def generate_random_rotation_translation(batch_size, elevation_low=10, elevation_high=170, azimuth_low=0, azimuth_high=359,
                                         transX_low=-3, transX_high=3,
                                         transY_low=-3, transY_high=3,
                                         transZ_low=-3, transZ_high=3,
                                         scale_low=1.0, scale_high=1.0,
                                         with_translation=False, with_scale=False):
    """Sample random view parameters for a batch.

    Returns a float64 array of shape (batch_size, 6) whose columns are
    [azimuth, elevation, scale, shift_x, shift_y, shift_z]:

    * azimuth: an integer number of degrees drawn from [azimuth_low, azimuth_high),
      converted to radians.
    * elevation: (90 - d) degrees in radians, where d is an integer drawn from
      [elevation_low, elevation_high).
    * scale: 1.0 unless `with_scale`, in which case ONE value drawn uniformly
      from [scale_low, scale_high) is shared by the whole batch.
    * shift_x/y/z: 0 unless `with_translation`, in which case each is drawn
      uniformly and independently per sample from [low, high).
    """
    params = np.zeros((batch_size, 6))
    # `float` instead of `np.float`: the latter was only an alias of the builtin
    # and was removed in NumPy 1.24.
    azimuth = np.random.randint(azimuth_low, azimuth_high, batch_size).astype(float) * math.pi / 180.0
    temp = np.random.randint(elevation_low, elevation_high, batch_size)
    elevation = (90. - temp.astype(float)) * math.pi / 180.0
    params[:, 0] = azimuth
    params[:, 1] = elevation

    if with_translation:
        # Draw order (x, y, z) is kept so seeded runs reproduce earlier results.
        params[:, 3] = transX_low + np.random.random(batch_size) * (transX_high - transX_low)
        params[:, 4] = transY_low + np.random.random(batch_size) * (transY_high - transY_low)
        params[:, 5] = transZ_low + np.random.random(batch_size) * (transZ_high - transZ_low)

    if with_scale:
        params[:, 2] = float(np.random.uniform(scale_low, scale_high))
    else:
        params[:, 2] = 1.0

    return params


# Module-wide pretty-printer instance.
pp = pprint.PrettyPrinter()


def get_stddev(x, k_h, k_w):
    """Return 1 / sqrt(k_w * k_h * C), where C is the last entry of x.get_shape().

    Presumably a fan-in based init std-dev for a k_h x k_w kernel over C input
    channels — TODO confirm against its callers.
    """
    fan_in = k_w * k_h * x.get_shape()[-1]
    return 1 / math.sqrt(fan_in)

def show_all_variables():
  """Print slim's summary of all trainable variables in the default graph."""
  slim.model_analyzer.analyze_vars(tf.trainable_variables(), print_info=True)

def get_image(image_path, input_height, input_width,
              resize_height=64, resize_width=64,
              crop=True):
  """Load the image at `image_path` and pass it through `transform`.

  The remaining arguments are forwarded to `transform` unchanged.
  """
  pixels = load_webp(image_path)
  return transform(pixels, input_height, input_width,
                   resize_height, resize_width, crop)

def load_webp(img_path):
    """Read an image file with Pillow and return its pixels as a NumPy array.

    Despite the name, any format Pillow can open is accepted. The file is
    opened in a context manager so its handle is released before returning;
    ``Image.open`` opens lazily and does not close the file itself, which can
    leak handles when many images are loaded in a training loop.
    """
    with Image.open(img_path) as im:
        # np.asarray forces Pillow to load the pixel data while the file is still open.
        return np.asarray(im)


def merge(images, size):
  """Tile a batch of images into a single grid image, filling rows left to right.

  :param images: array-like of shape (N, H, W, C) with C in {1, 3, 4}, or
      (N, H, W) for single-channel images.
  :param size: (rows, cols) of the grid; it must provide at least N cells.
  :return: array of shape (rows*H, cols*W, C) for C in {3, 4}, or
      (rows*H, cols*W) for single-channel input. The dtype of `images` is
      kept, so uint8 input stays uint8 instead of turning into float64.
  :raises ValueError: for any other layout, or when the grid is too small.
  """
  images = np.asarray(images)
  if images.ndim == 3:
    # (N, H, W) batch without a channel axis: treat it as single-channel.
    images = images[..., np.newaxis]
  if images.ndim != 4 or images.shape[3] not in (1, 3, 4):
    raise ValueError('in merge(images,size) images parameter '
                     'must have dimensions: NxHxW or NxHxWxC with C in (1, 3, 4)')

  rows, cols = size[0], size[1]
  n, h, w, c = images.shape
  if n > rows * cols:
    raise ValueError('in merge(images,size) a %dx%d grid cannot hold %d images'
                     % (rows, cols, n))

  if c == 1:
    img = np.zeros((h * rows, w * cols), dtype=images.dtype)
  else:
    img = np.zeros((h * rows, w * cols, c), dtype=images.dtype)
  for idx, image in enumerate(images):
    i = idx % cols
    j = idx // cols
    # Single-channel tiles are written into the 2-D canvas without their channel axis.
    img[j * h:j * h + h, i * w:i * w + w] = image[:, :, 0] if c == 1 else image
  return img


def center_crop(x, crop_h, crop_w,
                resize_h=64, resize_w=64):
  """Cut a crop_h x crop_w window from the centre of `x` and resize it.

  :param x: image array indexed as [row, column, ...].
  :param crop_h: height of the centre window.
  :param crop_w: width of the centre window; defaults to `crop_h` when None.
  :param resize_h: output height.
  :param resize_w: output width.
  :return: the window resized with scipy.misc.imresize to (resize_h, resize_w).
  """
  if crop_w is None:
    crop_w = crop_h
  height, width = x.shape[:2]
  top = int(round((height - crop_h) / 2.))
  left = int(round((width - crop_w) / 2.))
  window = x[top:top + crop_h, left:left + crop_w]
  return scipy.misc.imresize(window, [resize_h, resize_w])

def transform(image, input_height, input_width,
              resize_height=64, resize_width=64, crop=True):
  """Resize (optionally centre-cropping first) and rescale an image for the network.

  With `crop`, a centre window of input_height x input_width is taken first;
  otherwise the whole image is resized. A channel axis is added to 2-D
  results, at most the first 3 channels are kept, and values are mapped by
  v / 127.5 - 1 (so [0, 255] becomes [-1, 1]).
  """
  if not crop:
    resized = scipy.misc.imresize(image, [resize_height, resize_width])
  else:
    resized = center_crop(image, input_height, input_width,
                          resize_height, resize_width)
  if len(resized.shape) != 3:
    # Binary mask / single-channel image without a channel axis.
    resized = np.expand_dims(resized, -1)
  return np.array(resized)[:, :, :3] / 127.5 - 1.

def inverse_transform(images):
  """Map values from [-1, 1] back to [0, 1] (undoes the v / 127.5 - 1 scaling up to a factor of 255)."""
  return 0.5 * (images + 1.)

def image_manifold_size(num_images):
  """Return (rows, cols) = (floor(sqrt(n)), ceil(sqrt(n))) for a grid of `num_images`.

  Asserts that the grid holds exactly `num_images` images.
  """
  root = np.sqrt(num_images)
  rows, cols = int(np.floor(root)), int(np.ceil(root))
  assert rows * cols == num_images
  return rows, cols

def to_bool(value):
    """Convert `value` to a bool by inspecting ``str(value).lower()``.

    True  for: True, 1, "1", "true" (any case, e.g. "TRue"), "yes", "y", "t".
    False for: False, 0, 0.0, None, [], {}, "", "0", "false" (any case),
               "no", "n", "f".

    :raises ValueError: (a subclass of Exception) for any other value.
    """
    # Previously only "true"/"false" were accepted although the documented
    # values above were promised; those two still map exactly as before.
    text = str(value).lower()
    if text in ("true", "1", "yes", "y", "t"):
        return True
    if text in ("false", "0", "0.0", "no", "n", "f", "", "none", "[]", "{}"):
        return False
    raise ValueError('Invalid value for boolean conversion: ' + str(value))

fatal: destination path 'binvox-rw-py' already exists and is not an empty directory.
*


In [0]:
from __future__ import division
import os
import sys
from glob import glob
import json
import shutil
import glob



#----------------------------------------------------------------------------

class HoloGAN(object):
  def __init__(self, sess, input_height=108, input_width=108, crop=True,
         output_height=64, output_width=64,
         gf_dim=64, df_dim=64,
         c_dim=3, dataset_name='lsun',
         input_fname_pattern='*.webp'):

    self.sess = sess
    self.crop = crop

    self.input_height = input_height
    self.input_width = input_width
    self.output_height = output_height
    self.output_width = output_width

    self.gf_dim = gf_dim
    self.df_dim = df_dim
    self.c_dim = c_dim

    self.dataset_name = dataset_name
    self.input_fname_pattern = input_fname_pattern
    self.data = glob.glob(os.path.join(IMAGE_PATH, self.input_fname_pattern))
    self.checkpoint_dir = LOGDIR

  def build(self, build_func_name):
      build_func = eval("self." + build_func_name)
      build_func()

  def build_HoloGAN(self):
    """Create HoloGAN's placeholders, networks, losses, summaries and saver.

    Configuration is read from the global `cfg`:
      * 'generator' / 'discriminator': names of methods of this class that build
        the networks;
      * 'view_func': an expression evaluated to the view-sampling function;
      * 'z_dim', 'style_disc', 'DStyle_lambda', 'lambda_latent'.

    Sets the placeholders self.view_in ([None, 6]), self.inputs and self.z; the
    generator output self.G; real/fake discriminator outputs (self.D,
    self.D_logits, self.D_, self.D_logits_); the latent prediction
    self.Q_c_given_x; the losses self.d_loss, self.g_loss, self.q_loss (plus
    self.d_h1_loss..self.d_h4_loss when style_disc is enabled); scalar
    summaries; the variable lists self.d_vars / self.g_vars; and self.saver.
    """
    self.view_in = tf.placeholder(tf.float32, [None, 6], name='view_in')
    self.inputs = tf.placeholder(tf.float32, [None, self.output_height, self.output_width, self.c_dim], name='real_images')
    self.z = tf.placeholder(tf.float32, [None, cfg['z_dim']], name='z')
    inputs = self.inputs

    # getattr can only select a method of this model; eval("self." + name)
    # would run any expression found in the config.
    gen_func = getattr(self, cfg['generator'])
    dis_func = getattr(self, cfg['discriminator'])
    # NOTE(review): this still eval()s a config string (it names a module-level
    # function); only load trusted config files.
    self.gen_view_func = eval(cfg['view_func'])

    self.G = gen_func(self.z, self.view_in)

    use_style_disc = str(cfg["style_disc"]).lower() == "true"
    if use_style_disc:
        print("Style Disc")
        (self.D, self.D_logits, _,
         self.d_h1_r, self.d_h2_r, self.d_h3_r, self.d_h4_r) = dis_func(inputs, cont_dim=cfg['z_dim'], reuse=False)
        (self.D_, self.D_logits_, self.Q_c_given_x,
         self.d_h1_f, self.d_h2_f, self.d_h3_f, self.d_h4_f) = dis_func(self.G, cont_dim=cfg['z_dim'], reuse=True)

        def style_loss(real_logits, fake_logits):
            # Style logits of real images are labelled 1, those of generated images 0.
            real_term = tf.reduce_mean(sigmoid_cross_entropy_with_logits(real_logits, tf.ones_like(real_logits)))
            fake_term = tf.reduce_mean(sigmoid_cross_entropy_with_logits(fake_logits, tf.zeros_like(fake_logits)))
            return cfg["DStyle_lambda"] * (real_term + fake_term)

        self.d_h1_loss = style_loss(self.d_h1_r, self.d_h1_f)
        self.d_h2_loss = style_loss(self.d_h2_r, self.d_h2_f)
        self.d_h3_loss = style_loss(self.d_h3_r, self.d_h3_f)
        self.d_h4_loss = style_loss(self.d_h4_r, self.d_h4_f)
    else:
        self.D, self.D_logits, _ = dis_func(inputs, cont_dim=cfg['z_dim'], reuse=False)
        self.D_, self.D_logits_, self.Q_c_given_x = dis_func(self.G, cont_dim=cfg['z_dim'], reuse=True)

    # Standard non-saturating GAN losses on the image-level logits.
    self.d_loss_real = tf.reduce_mean(sigmoid_cross_entropy_with_logits(self.D_logits, tf.ones_like(self.D)))
    self.d_loss_fake = tf.reduce_mean(sigmoid_cross_entropy_with_logits(self.D_logits_, tf.zeros_like(self.D_)))
    self.d_loss = self.d_loss_real + self.d_loss_fake
    self.g_loss = tf.reduce_mean(sigmoid_cross_entropy_with_logits(self.D_logits_, tf.ones_like(self.D_)))

    if use_style_disc:
        print("Style disc")
        self.d_loss = self.d_loss + self.d_h1_loss + self.d_h2_loss + self.d_h3_loss + self.d_h4_loss

    # Identity loss: the discriminator's latent prediction for G(z) should match z.
    # It is added to both the D and the G objectives.
    self.q_loss = cfg["lambda_latent"] * tf.reduce_mean(tf.square(self.Q_c_given_x - self.z))
    self.d_loss = self.d_loss + self.q_loss
    self.g_loss = self.g_loss + self.q_loss

    self.d_loss_real_sum = scalar_summary("d_loss_real", self.d_loss_real)
    self.d_loss_fake_sum = scalar_summary("d_loss_fake", self.d_loss_fake)
    self.g_loss_sum = scalar_summary("g_loss", self.g_loss)
    self.d_loss_sum = scalar_summary("d_loss", self.d_loss)

    t_vars = tf.trainable_variables()

    # NOTE(review): substring matching — any variable whose full name contains
    # 'd_' (resp. 'g_') anywhere is assigned to D (resp. G); confirm no variable
    # name matches both or neither.
    self.d_vars = [var for var in t_vars if 'd_' in var.name]
    self.g_vars = [var for var in t_vars if 'g_' in var.name]

    self.saver = tf.train.Saver()

  def train_HoloGAN(self, config):
      """Train HoloGAN: per batch, one discriminator step followed by two generator steps.

      Requires build_HoloGAN() to have run. Reads the global `cfg` (batch size,
      Adam betas, learning rates and schedule, latent and view-sampling
      settings) and the module-level LOGDIR, IMAGE_PATH and OUTPUT_DIR.

      :param config: object with a `dataset` name and a `train_size` cap on the
          number of images used per epoch. "cats" and "cars" images are resized
          without centre-cropping; other training batches follow `self.crop`
          (the fixed sample batch always crops).

      Learning rates stay at cfg['d_eta'] / cfg['g_eta'] for the first
      cfg['epoch_step'] epochs and then decay linearly towards 0 at
      cfg['max_epochs']. Every 1000 steps a checkpoint is saved and a grid of
      samples for a fixed (z, view) batch is written to OUTPUT_DIR.
      """
      self.d_lr_in = tf.placeholder(tf.float32, None, name='d_eta')
      self.g_lr_in = tf.placeholder(tf.float32, None, name='g_eta')

      # The optimisers read the fed learning-rate placeholders; with the
      # constant cfg values the per-epoch decay below had no effect.
      d_optim = tf.train.AdamOptimizer(self.d_lr_in, beta1=cfg['beta1'], beta2=cfg['beta2']).minimize(self.d_loss, var_list=self.d_vars)
      g_optim = tf.train.AdamOptimizer(self.g_lr_in, beta1=cfg['beta1'], beta2=cfg['beta2']).minimize(self.g_loss, var_list=self.g_vars)

      tf.global_variables_initializer().run()

      # NOTE(review): sys.argv[1] is the config path when run as a script; inside a
      # notebook kernel it is a kernel argument instead — confirm before running here.
      shutil.copyfile(sys.argv[1], os.path.join(LOGDIR, 'config.json'))
      self.g_sum = merge_summary([self.d_loss_fake_sum, self.g_loss_sum])
      self.d_sum = merge_summary([self.d_loss_real_sum, self.d_loss_sum])
      self.writer = SummaryWriter(LOGDIR, self.sess.graph)

      def sample_views():
          # Keyword arguments: passed positionally, the scale range used to land in
          # generate_random_rotation_translation's x-translation slots and
          # z_low/z_high in its scale slots.
          # NOTE(review): with_scale is driven by cfg['with_translation'] — confirm intended.
          return self.gen_view_func(cfg['batch_size'],
                                    elevation_low=cfg['ele_low'], elevation_high=cfg['ele_high'],
                                    azimuth_low=cfg['azi_low'], azimuth_high=cfg['azi_high'],
                                    scale_low=cfg['scale_low'], scale_high=cfg['scale_high'],
                                    transX_low=cfg['x_low'], transX_high=cfg['x_high'],
                                    transY_low=cfg['y_low'], transY_high=cfg['y_high'],
                                    transZ_low=cfg['z_low'], transZ_high=cfg['z_high'],
                                    with_translation=False,
                                    with_scale=to_bool(str(cfg['with_translation'])))

      def load_images(files, crop):
          return [get_image(image_file,
                            input_height=self.input_height,
                            input_width=self.input_width,
                            resize_height=self.output_height,
                            resize_width=self.output_width,
                            crop=crop) for image_file in files]

      def decayed(base_lr, epoch):
          # Linear decay computed from the BASE rate; the previous code multiplied
          # the already-decayed rate again every epoch, compounding the decay.
          if epoch < cfg['epoch_step']:
              return base_lr
          return base_lr * (cfg['max_epochs'] - epoch) / (cfg['max_epochs'] - cfg['epoch_step'])

      def save_samples(path, ren_img):
          try:
              grid = merge(ren_img, [cfg['batch_size'] // 4, 4])
          except (ValueError, IndexError):
              # The batch does not tile as (batch/4) x 4: save the first sample only.
              grid = ren_img[0]
          scipy.misc.imsave(path, grid)

      no_crop = config.dataset == "cats" or config.dataset == "cars"

      # Sample noise Z and view parameters to test during training
      sample_z = self.sampling_Z(cfg['z_dim'], str(cfg['sample_z']))
      sample_view = sample_views()
      sample_files = self.data[0:cfg['batch_size']]
      sample_images = load_images(sample_files, crop=not no_crop)

      counter = 1
      start_time = time.time()
      could_load, checkpoint_counter = self.load(self.checkpoint_dir)
      if could_load:
          counter = checkpoint_counter
          print(" [*] Load SUCCESS")
      else:
          print(" [!] Load failed...")

      self.data = glob.glob(os.path.join(IMAGE_PATH, self.input_fname_pattern))
      for epoch in range(cfg['max_epochs']):
          d_lr = decayed(cfg['d_eta'], epoch)
          g_lr = decayed(cfg['g_eta'], epoch)

          random.shuffle(self.data)
          batch_idxs = min(len(self.data), config.train_size) // cfg['batch_size']

          for idx in range(0, batch_idxs):
              batch_files = self.data[idx * cfg['batch_size']:(idx + 1) * cfg['batch_size']]
              batch_images = load_images(batch_files, crop=False if no_crop else self.crop)

              batch_z = self.sampling_Z(cfg['z_dim'], str(cfg['sample_z']))
              batch_view = sample_views()

              feed = {self.inputs: batch_images,
                      self.z: batch_z,
                      self.view_in: batch_view,
                      self.d_lr_in: d_lr,
                      self.g_lr_in: g_lr}
              # Update D network
              _, summary_str = self.sess.run([d_optim, self.d_sum], feed_dict=feed)
              self.writer.add_summary(summary_str, counter)
              # Update G network twice per D update
              for _g_step in range(2):
                  _, summary_str = self.sess.run([g_optim, self.g_sum], feed_dict=feed)
                  self.writer.add_summary(summary_str, counter)

              errD_fake = self.d_loss_fake.eval(feed)
              errD_real = self.d_loss_real.eval(feed)
              errG = self.g_loss.eval(feed)
              errQ = self.q_loss.eval(feed)

              counter += 1
              print("Epoch: [%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f, q_loss: %.8f" \
                    % (epoch, idx, batch_idxs,
                       time.time() - start_time, errD_fake + errD_real, errG, errQ))

              if np.mod(counter, 1000) == 1:
                  self.save(LOGDIR, counter)
                  feed_eval = {self.inputs: sample_images,
                               self.z: sample_z,
                               self.view_in: sample_view,
                               self.d_lr_in: d_lr,
                               self.g_lr_in: g_lr}
                  samples, d_loss, g_loss = self.sess.run(
                      [self.G, self.d_loss, self.g_loss],
                      feed_dict=feed_eval)
                  ren_img = inverse_transform(samples)
                  ren_img = np.clip(255 * ren_img, 0, 255).astype(np.uint8)
                  save_samples(os.path.join(OUTPUT_DIR, "{0}_GAN.png".format(counter)), ren_img)
                  print("[Sample] d_loss: %.8f, g_loss: %.8f" % (d_loss, g_loss))

  def sample_HoloGAN(self, config):
      """Restore the latest checkpoint and write sample grids to OUTPUT_DIR/samples.

      One latent batch is drawn and reused for every image:
        * config.rotate_azimuth: azimuth i degrees for i in
          range(cfg['azi_low'], cfg['azi_high'], 10), elevation 0;
        * config.rotate_elevation: azimuth 270 degrees, elevation (90 - i)
          degrees for i in range(cfg['ele_low'], cfg['ele_high'], 5);
        * otherwise: 10 grids with random views.
      Returns early (after printing a message) if no checkpoint can be loaded.
      """
      could_load, checkpoint_counter = self.load(self.checkpoint_dir)
      if not could_load:
          print(" [!] Load failed...")
          return
      counter = checkpoint_counter
      print(" [*] Load SUCCESS")

      SAMPLE_DIR = os.path.join(OUTPUT_DIR, "samples")
      if not os.path.exists(SAMPLE_DIR):
          os.makedirs(SAMPLE_DIR)
      sample_z = self.sampling_Z(cfg['z_dim'], str(cfg['sample_z']))
      if config.rotate_azimuth:
          low, high, step = cfg['azi_low'], cfg['azi_high'], 10
      elif config.rotate_elevation:
          low, high, step = cfg['ele_low'], cfg['ele_high'], 5
      else:
          low, high, step = 0, 10, 1

      for i in range(low, high, step):
          if config.rotate_azimuth:
              sample_view = np.tile(
                  np.array([i * math.pi / 180.0, 0 * math.pi / 180.0, 1.0, 0, 0, 0]), (cfg['batch_size'], 1))
          elif config.rotate_elevation:
              # Previously tested rotate_azimuth again, so this sweep never ran.
              sample_view = np.tile(
                  np.array([270 * math.pi / 180.0, (90 - i) * math.pi / 180.0, 1.0, 0, 0, 0]), (cfg['batch_size'], 1))
          else:
              # Keyword arguments so each cfg range reaches the matching parameter
              # of generate_random_rotation_translation.
              # NOTE(review): with_scale is driven by cfg['with_translation'] — confirm intended.
              sample_view = self.gen_view_func(cfg['batch_size'],
                                               elevation_low=cfg['ele_low'], elevation_high=cfg['ele_high'],
                                               azimuth_low=cfg['azi_low'], azimuth_high=cfg['azi_high'],
                                               scale_low=cfg['scale_low'], scale_high=cfg['scale_high'],
                                               transX_low=cfg['x_low'], transX_high=cfg['x_high'],
                                               transY_low=cfg['y_low'], transY_high=cfg['y_high'],
                                               transZ_low=cfg['z_low'], transZ_high=cfg['z_high'],
                                               with_translation=False,
                                               with_scale=to_bool(str(cfg['with_translation'])))

          feed_eval = {self.z: sample_z,
                       self.view_in: sample_view}

          samples = self.sess.run(self.G, feed_dict=feed_eval)
          ren_img = inverse_transform(samples)
          ren_img = np.clip(255 * ren_img, 0, 255).astype(np.uint8)
          try:
              grid = merge(ren_img, [cfg['batch_size'] // 4, 4])
          except (ValueError, IndexError):
              # The batch does not tile as (batch/4) x 4: save the first sample only.
              grid = ren_img[0]
          scipy.misc.imsave(
              os.path.join(SAMPLE_DIR, "{0}_samples_{1}.png".format(counter, i)),
              grid)

#=======================================================================================================================

  def sampling_Z(self, z_dim, type="uniform"):
      if str.lower(type) == "uniform":
          return np.random.uniform(-1., 1., (cfg['batch_size'], z_dim))
      else:
          return np.random.normal(0, 1, (cfg['batch_size'], z_dim))

  def linear_classifier(self, features, scope = "lin_class", stddev=0.02, reuse=False):
      """Single-output linear layer returning (sigmoid(logits), logits).

      Creates, under variable scope `scope`, a weight 'w' of shape
      [features.shape[-1], 1] initialised from N(0, stddev) and a bias 'biases'
      of shape [1] initialised to 0. `reuse` is accepted but not used here.
      """
      with tf.variable_scope(scope):
          in_dim = features.get_shape()[-1]
          weights = tf.get_variable('w', [in_dim, 1],
                                    initializer=tf.random_normal_initializer(stddev=stddev))
          bias = tf.get_variable('biases', 1, initializer=tf.constant_initializer(0.0))
          logits = tf.matmul(features, weights) + bias
          probabilities = tf.nn.sigmoid(logits)
          return probabilities, logits

  def z_mapping_function(self, z, output_channel, scope='z_mapping', act="relu", stddev=0.02):
      """Map a latent batch `z` to two tensors of width `output_channel`.

      A dense layer of width 2*output_channel (variables 'w' ~ N(0, stddev) and
      'biases' = 0 under `scope`) is followed by ReLU when `act` == "relu" and by
      lrelu otherwise; the output is split into its first and second halves.
      The generator in this class feeds the pair to AdaIn as (s, b).
      """
      doubled = output_channel * 2
      with tf.variable_scope(scope):
          weights = tf.get_variable('w', [z.get_shape()[-1], doubled],
                                    initializer=tf.random_normal_initializer(stddev=stddev))
          bias = tf.get_variable('biases', doubled, initializer=tf.constant_initializer(0.0))
          pre_activation = tf.matmul(z, weights) + bias
          out = tf.nn.relu(pre_activation) if act == "relu" else lrelu(pre_activation)
          return out[:, :output_channel], out[:, output_channel:]

#=======================================================================================================================
  def discriminator_IN(self, image,  cont_dim, reuse=False):
      """Instance-norm discriminator with a latent-code recognition head.

      :param image: batch of real or generated images.
      :param cont_dim: size of the predicted latent code.
      :param reuse: reuse the variables of the "discriminator" scope
          (build_HoloGAN passes True for the call on generated images).
      :return: (sigmoid(logits), logits, tanh(latent prediction)).
          When cfg["add_D_noise"] is true, Gaussian noise with stddev 0.02 is
          added to `image` first.
      """
      # Case-insensitive like build_HoloGAN's style_disc check, so a JSON boolean
      # (str(True) == "True") no longer silently disables the noise.
      if str(cfg["add_D_noise"]).lower() == "true":
          image = image + tf.random_normal(tf.shape(image), stddev=0.02)

      with tf.variable_scope("discriminator") as scope:
          if reuse:
              scope.reuse_variables()

          h0 = lrelu(conv2d(image, self.df_dim, name='d_h0_conv'))
          h1 = lrelu(instance_norm(conv2d_specNorm(h0, self.df_dim * 2, name='d_h1_conv'),'d_in1'))
          h2 = lrelu(instance_norm(conv2d_specNorm(h1, self.df_dim * 4, name='d_h2_conv'),'d_in2'))
          h3 = lrelu(instance_norm(conv2d_specNorm(h2, self.df_dim * 8, name='d_h3_conv'),'d_in3'))

          #Returning logits to determine whether the images are real or fake
          h4 = linear(slim.flatten(h3), 1, 'd_h4_lin')

          # Recognition network for latent variables has an additional layer
          encoder = lrelu((linear(slim.flatten(h3), 128, 'd_latent')))
          cont_vars = linear(encoder, cont_dim, "d_latent_prediction")

          return tf.nn.sigmoid(h4), h4, tf.nn.tanh(cont_vars)

  def _d_style_layer(self, x, out_channels, index, batch_size):
      """One style-discriminator stage.

      Applies a spectral-norm conv ('d_h{index}_conv'), instance norm
      ('d_in{index}') and a linear real/fake classifier ('d_h{index}_class') on
      the instance-norm mean/variance. Returns (lrelu(normalised features),
      style logits).
      """
      h = conv2d_specNorm(x, out_channels, name='d_h%d_conv' % index)
      h, h_mean, h_var = instance_norm(h, 'd_in%d' % index, True)
      h_mean = tf.reshape(h_mean, (batch_size, out_channels))
      h_var = tf.reshape(h_var, (batch_size, out_channels))
      # NOTE(review): concatenating along axis 0 stacks means and variances as
      # separate rows, so the classifier scores them independently (logits of
      # shape (2*batch, 1)). axis=1 would score them jointly but would change
      # the shape of the 'w' variable and break existing checkpoints.
      d_style = tf.concat([h_mean, h_var], 0)
      _, d_logits = self.linear_classifier(d_style, 'd_h%d_class' % index)
      return lrelu(h), d_logits

  def discriminator_IN_style_res128(self, image,  cont_dim, reuse=False):
      """Instance-norm discriminator with per-stage style classifiers.

      :param image: batch of real or generated images.
      :param cont_dim: size of the predicted latent code.
      :param reuse: reuse the variables of the "discriminator" scope
          (build_HoloGAN passes True for the call on generated images).
      :return: (sigmoid(logits), logits, tanh(latent prediction),
          d_h1_logits, d_h2_logits, d_h3_logits, d_h4_logits), the last four being
          the style-classifier logits of the four conv stages. When
          cfg["add_D_noise"] is true, Gaussian noise with stddev 0.02 is added to
          `image` first.
      """
      batch_size = tf.shape(image)[0]
      # Case-insensitive like build_HoloGAN's style_disc check, so a JSON boolean
      # (str(True) == "True") no longer silently disables the noise.
      if str(cfg["add_D_noise"]).lower() == "true":
          image = image + tf.random_normal(tf.shape(image), stddev=0.02)

      with tf.variable_scope("discriminator") as scope:
          if reuse:
              scope.reuse_variables()

          h0 = lrelu(conv2d(image, self.df_dim, name='d_h0_conv'))
          # Same variable scopes and creation order as the former copy-pasted stages.
          h1, d_h1_logits = self._d_style_layer(h0, self.df_dim * 2, 1, batch_size)
          h2, d_h2_logits = self._d_style_layer(h1, self.df_dim * 4, 2, batch_size)
          h3, d_h3_logits = self._d_style_layer(h2, self.df_dim * 8, 3, batch_size)
          h4, d_h4_logits = self._d_style_layer(h3, self.df_dim * 16, 4, batch_size)

          #Returning logits to determine whether the images are real or fake
          h5 = linear(slim.flatten(h4), 1, 'd_h5_lin')

          # Recognition network for latent variables has an additional layer
          encoder = lrelu((linear(slim.flatten(h4), 128, 'd_latent')))
          cont_vars = linear(encoder, cont_dim, "d_latent_prediction")

          return tf.nn.sigmoid(h5), h5, tf.nn.tanh(cont_vars), d_h1_logits, d_h2_logits, d_h3_logits, d_h4_logits

  def generator_AdaIN(self, z, view_in, reuse=False):
      """Build the HoloGAN generator that renders images from ``z`` and a view.

      A learnt constant 3D feature volume is upsampled with 3D deconvolutions whose
      activations are AdaIN-modulated by ``z``. The volume is then transformed
      according to ``view_in``, its depth axis is folded into the channel axis, and the
      resulting 2D feature map is upsampled with AdaIN-modulated 2D deconvolutions.

      Args:
        z: latent code tensor; its first dimension is used as the (dynamic) batch size.
        view_in: view parameters, passed straight to ``tf_3D_transform``.
        reuse: if True, reuse the variables already created under the "generator" scope.

      Returns:
        Tensor named "output": ``tanh`` of the last layer, whose requested shape is
        [batch, 64, 64, self.c_dim]; values therefore lie in (-1, 1).
      """
      batch_size = tf.shape(z)[0]
      # Spatial extents (height, width, depth) of each stage, halving from a fixed 64^3 base.
      s_h, s_w, s_d = 64, 64, 64
      s_h2, s_w2, s_d2 = conv_out_size_same(s_h, 2), conv_out_size_same(s_w, 2), conv_out_size_same(s_d, 2)
      s_h4, s_w4, s_d4 = conv_out_size_same(s_h2, 2), conv_out_size_same(s_w2, 2), conv_out_size_same(s_d2, 2)
      s_h8, s_w8, s_d8 = conv_out_size_same(s_h4, 2), conv_out_size_same(s_w4, 2), conv_out_size_same(s_d4, 2)
      s_h16, s_w16, s_d16 = conv_out_size_same(s_h8, 2), conv_out_size_same(s_w8, 2), conv_out_size_same(s_d8, 2)

      with tf.variable_scope("generator") as scope:
          if reuse:
              scope.reuse_variables()
          # A learnt constant "template" volume, shared by every sample in the batch.
          with tf.variable_scope('g_w_constant'):
              w = tf.get_variable('w', [s_h16, s_w16, s_d16, self.gf_dim * 8], initializer=tf.random_normal_initializer(stddev=0.02))
              # Repeat the learnt constant features to make a batch.
              w_tile = tf.tile(tf.expand_dims(w, 0), (batch_size, 1, 1, 1, 1))
              s0, b0 = self.z_mapping_function(z, self.gf_dim * 8, 'g_z0')
              h0 = AdaIn(w_tile, s0, b0)
              h0 = tf.nn.relu(h0)

          # 3D upsampling stages; z_mapping_function maps z to the (s, b) pair consumed by AdaIn.
          h1 = deconv3d(h0, [batch_size, s_h8, s_w8, s_d8, self.gf_dim * 2], k_h=3, k_d=3, k_w=3, name='g_h1')
          s1, b1 = self.z_mapping_function(z, self.gf_dim * 2, 'g_z1')
          h1 = AdaIn(h1, s1, b1)
          h1 = tf.nn.relu(h1)

          h2 = deconv3d(h1, [batch_size, s_h4, s_w4, s_d4, self.gf_dim * 1], k_h=3, k_d=3, k_w=3, name='g_h2')
          s2, b2 = self.z_mapping_function(z, self.gf_dim * 1, 'g_z2')
          h2 = AdaIn(h2, s2, b2)
          h2 = tf.nn.relu(h2)

          #=============================================================================================================
          # View-dependent transformation of the feature volume.
          # NOTE(review): the literal 16s equal s_h4/s_d4 for the 64^3 base; confirm their meaning
          # against tf_3D_transform's signature before changing the base resolution.
          h2_rotated = tf_3D_transform(h2, view_in, 16, 16)
          h2_rotated = transform_voxel_to_match_image(h2_rotated)
          #=============================================================================================================
          # Collapse the depth dimension into channels. Uses the stage-derived s_d4 instead of a
          # literal 16 so the reshape stays in sync with the stage sizes above.
          # Assumes transform_voxel_to_match_image returns [batch, H, W, D, C] -- TODO confirm.
          h2_2d = tf.reshape(h2_rotated, [batch_size, s_h4, s_w4, s_d4 * self.gf_dim])
          # 1x1 (transposed) convolution mixing the folded depth/channel features.
          h3 = deconv2d(h2_2d, [batch_size, s_h4, s_w4, self.gf_dim * 16], k_h=1, k_w=1, d_h=1, d_w=1, name='g_h3')
          h3 = tf.nn.relu(h3)
          #=============================================================================================================

          # 2D upsampling stages, again AdaIN-modulated by z.
          h4 = deconv2d(h3, [batch_size, s_h2, s_w2, self.gf_dim * 4], k_h=4, k_w=4, name='g_h4')
          s4, b4 = self.z_mapping_function(z, self.gf_dim * 4, 'g_z4')
          h4 = AdaIn(h4, s4, b4)
          h4 = tf.nn.relu(h4)

          h5 = deconv2d(h4, [batch_size, s_h, s_w, self.gf_dim], k_h=4, k_w=4, name='g_h5')
          s5, b5 = self.z_mapping_function(z, self.gf_dim, 'g_z5')
          h5 = AdaIn(h5, s5, b5)
          h5 = tf.nn.relu(h5)

          # Final stride-1 layer down to image channels.
          h6 = deconv2d(h5, [batch_size, s_h, s_w, self.c_dim], k_h=4, k_w=4, d_h=1, d_w=1, name='g_h6')

          output = tf.nn.tanh(h6, name="output")
          return output

  def generator_AdaIN_res128(self, z, view_in, reuse=False):
      """Build the generator variant whose output is twice the 64-pixel base size (128x128).

      Pipeline:
        1. Learnt constant 3D template, styled by z.
        2. Two AdaIN-styled 3D upsampling stages.
        3. View-dependent 3D transform.
        4. Two stride-1 3D deconvolutions.
        5. Depth folded into channels, then a 1x1 projection.
        6. Three AdaIN-styled 2D upsampling stages and a final stride-1 layer.

      Hidden activations are leaky ReLU.

      Args:
        z: latent code tensor; its first dimension is used as the (dynamic) batch size.
        view_in: view parameters, passed straight to ``tf_3D_transform``.
        reuse: if True, reuse the variables already created under the "generator" scope.

      Returns:
        Tensor named "output": ``tanh`` of the last layer, whose requested shape is
        [batch, 128, 128, self.c_dim].
      """
      n_batch = tf.shape(z)[0]

      # (height, width, depth) per stage, halving four times from a 64^3 base.
      stage_sizes = [(64, 64, 64)]
      for _ in range(4):
          stage_sizes.append(tuple(conv_out_size_same(extent, 2) for extent in stage_sizes[-1]))
      full, half, quarter, eighth, sixteenth = stage_sizes

      def _modulated(layer_fn, x, out_shape, layer_name, style_name, **kernel_args):
          # deconv -> AdaIN with a z-derived (scale, bias) -> leaky ReLU, in that op order.
          y = layer_fn(x, out_shape, name=layer_name, **kernel_args)
          gamma, beta = self.z_mapping_function(z, out_shape[-1], style_name)
          return lrelu(AdaIn(y, gamma, beta))

      def _plain(layer_fn, x, out_shape, layer_name, **kernel_args):
          # deconv -> leaky ReLU, without style modulation.
          return lrelu(layer_fn(x, out_shape, name=layer_name, **kernel_args))

      with tf.variable_scope("generator") as scope:
          if reuse:
              scope.reuse_variables()

          # Learnt constant 3D template, tiled across the batch and styled by z.
          with tf.variable_scope('g_w_constant'):
              template = tf.get_variable('w', list(sixteenth) + [self.gf_dim * 8],
                                         initializer=tf.random_normal_initializer(stddev=0.02))
              batched = tf.tile(tf.expand_dims(template, 0), (n_batch, 1, 1, 1, 1))
              gamma0, beta0 = self.z_mapping_function(z, self.gf_dim * 8, 'g_z0')
              net = lrelu(AdaIn(batched, gamma0, beta0))

          # 3D upsampling: 1/16 -> 1/8 -> 1/4 of the base resolution.
          cube3 = dict(k_h=3, k_w=3, k_d=3)
          net = _modulated(deconv3d, net, [n_batch] + list(eighth) + [self.gf_dim * 4], 'g_h1', 'g_z1', **cube3)
          net = _modulated(deconv3d, net, [n_batch] + list(quarter) + [self.gf_dim * 2], 'g_h2', 'g_z2', **cube3)

          # View-dependent transformation of the feature volume.
          net = tf_3D_transform(net, view_in, 16, 16)
          net = transform_voxel_to_match_image(net)

          # Two stride-1 3D deconvolutions applied to the transformed volume.
          same3 = dict(cube3, d_h=1, d_w=1, d_d=1)
          net = _plain(deconv3d, net, [n_batch] + list(quarter) + [self.gf_dim * 1], 'g_h2_proj1', **same3)
          net = _plain(deconv3d, net, [n_batch] + list(quarter) + [self.gf_dim], 'g_h2_proj2', **same3)

          # Fold depth into channels, then mix with a 1x1 (transposed) convolution.
          net = tf.reshape(net, [n_batch, quarter[0], quarter[1], quarter[2] * self.gf_dim])
          net = _plain(deconv2d, net, [n_batch, quarter[0], quarter[1], self.gf_dim * 16 // 2], 'g_h3',
                       k_h=1, k_w=1, d_h=1, d_w=1)

          # 2D upsampling with AdaIN: 1/4 -> 1/2 -> 1x -> 2x of the base resolution.
          square4 = dict(k_h=4, k_w=4)
          net = _modulated(deconv2d, net, [n_batch, half[0], half[1], self.gf_dim * 4], 'g_h4', 'g_z4', **square4)
          net = _modulated(deconv2d, net, [n_batch, full[0], full[1], self.gf_dim], 'g_h5', 'g_z5', **square4)
          net = _modulated(deconv2d, net, [n_batch, full[0] * 2, full[1] * 2, self.gf_dim // 2], 'g_h6', 'g_z6',
                           **square4)

          # Final stride-1 layer down to image channels, squashed by tanh.
          last = deconv2d(net, [n_batch, full[0] * 2, full[1] * 2, self.c_dim], name='g_h7', d_h=1, d_w=1, **square4)
          return tf.nn.tanh(last, name="output")

#=======================================================================================================================
  @property
  def model_dir(self):
    """Checkpoint sub-directory name: "<dataset_name>_<output_height>_<output_width>"."""
    parts = (self.dataset_name, self.output_height, self.output_width)
    return "_".join(str(part) for part in parts)

  def save(self, checkpoint_dir, step):
    """Save the session variables under ``<checkpoint_dir>/<self.model_dir>/``.

    The directory is created if missing. The checkpoint prefix is "HoloGAN.model",
    and ``step`` is forwarded to the Saver as ``global_step``.
    """
    target_dir = os.path.join(checkpoint_dir, self.model_dir)
    if not os.path.exists(target_dir):
      os.makedirs(target_dir)

    prefix = os.path.join(target_dir, "HoloGAN.model")
    self.saver.save(self.sess, prefix, global_step=step)

  def load(self, checkpoint_dir):
    """Restore the latest checkpoint for this model into ``self.sess``, if one exists.

    Looks in ``os.path.join(checkpoint_dir, self.model_dir)``, the same directory
    ``save`` writes to.

    Args:
      checkpoint_dir: root checkpoint directory.

    Returns:
      ``(True, step)`` on success. ``step`` is the last run of digits in the
      checkpoint file name, or 0 if the name contains no digits.
      ``(False, 0)`` if no checkpoint was found.
    """
    import re
    print(" [*] Reading checkpoints...")
    checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)

    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
    if not (ckpt and ckpt.model_checkpoint_path):
      print(" [*] Failed to find a checkpoint")
      return False, 0

    # Rebuild the path from the basename so the local checkpoint_dir is used rather
    # than the directory recorded in the checkpoint state.
    ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
    self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
    # The step is the last run of digits in the name (e.g. "HoloGAN.model-1234").
    # Raw string: "\d" in a plain literal is an invalid escape sequence. re.search
    # returns None instead of leaking StopIteration when the name has no digits.
    match = re.search(r"(\d+)(?!.*\d)", ckpt_name)
    counter = int(match.group(0)) if match else 0
    print(" [*] Success to read {}".format(ckpt_name))
    return True, counter


In [0]:
# Experiment configuration consumed by the cells below (keys are read via cfg[...]).
cfg = dict(
  # Data location, device and optimisation hyper-parameters.
  image_path="./celebA",
  gpu=0,
  batch_size=32,
  max_epochs=50,
  epoch_step=25,
  z_dim=128,
  d_eta=0.0001,
  g_eta=0.0001,
  beta1=0.5,
  beta2=0.999,
  # Names of model components/methods; build_func and train_func are looked up on the model in main().
  discriminator="discriminator_IN",
  generator="generator_AdaIN",
  view_func="generate_random_rotation_translation",
  train_func="train_HoloGAN",
  build_func="build_HoloGAN",
  # Switch-like options, stored as the strings "true"/"false".
  style_disc="false",
  sample_z="uniform",
  add_D_noise="false",
  DStyle_lambda=1.0,

  lambda_latent=0.0,
  # View sampling ranges (elevation/azimuth presumably in degrees -- TODO confirm).
  ele_low=70,
  ele_high=110,
  azi_low=220,
  azi_high=320,
  scale_low=1.0,
  scale_high=1.0,
  x_low=0,
  x_high=0,
  y_low=0,
  y_high=0,
  z_low=0,
  z_high=0,
  with_translation="false",
  with_scale="false",
  output_dir="./HoloGAN",
)

In [0]:
import os
import json
import sys
import tensorflow as tf
import numpy as np


# Paths taken from the configuration dictionary; the log directory lives inside OUTPUT_DIR.
IMAGE_PATH = cfg['image_path']
OUTPUT_DIR = cfg['output_dir']
LOGDIR = os.path.join(OUTPUT_DIR, "log")

# Select the configured GPU; this runs before main() creates the tf.Session.
os.environ["CUDA_VISIBLE_DEVICES"] = "{0}".format(cfg['gpu'])


# Command-line flags, parsed when tf.app.run() is called.
# The bracketed "[...]" in each help string shows the default value; the crop/train/
# input_fname_pattern entries previously advertised the wrong default.
# NOTE(review): re-running this cell in the same kernel will likely raise a duplicate-flag
# error, because flags are registered in a process-wide registry.
flags = tf.app.flags

flags.DEFINE_integer("input_height", 108, "The size of image to use (will be center cropped). [108] or [128] for celebA and lsun, [400] for chairs. Cats and Cars are already cropped")
flags.DEFINE_integer("input_width", None, "The size of image to use (will be center cropped). If None, same value as input_height [None]")
flags.DEFINE_integer("output_height", 64, "The size of the output images to produce 64 or 128")
flags.DEFINE_integer("output_width", None, "The size of the output images to produce. If None, same value as output_height [None]")
flags.DEFINE_string("dataset", "celebA", "The name of dataset [celebA, lsun, chairs, shoes, cars, cats]")
flags.DEFINE_string("input_fname_pattern", "*.jpg", "Glob pattern of filename of input images [*.jpg]")
flags.DEFINE_float("train_size", np.inf, "Number of images to train-Useful when only a subset of the dataset is needed to train the model")
flags.DEFINE_boolean("crop", True, "True for training, False for testing [True]")
flags.DEFINE_boolean("train", True, "True for training, False for testing [True]")
flags.DEFINE_boolean("rotate_azimuth", False, "Sample images with varying azimuth")
flags.DEFINE_boolean("rotate_elevation", False, "Sample images with varying elevation")
FLAGS = flags.FLAGS


def main(_):
  """Build HoloGAN from FLAGS and cfg, then either train it or sample from a checkpoint.

  Args:
    _: remaining command-line arguments handed over by ``tf.app.run`` (unused).

  Raises:
    Exception: in test mode (``FLAGS.train`` false) when no checkpoint can be loaded
      from LOGDIR.
    AttributeError: if ``cfg['train_func']`` does not name a method of the model.
  """
  pp.pprint(flags.FLAGS.__flags)
  # Widths default to the corresponding heights when not given.
  if FLAGS.input_width is None:
    FLAGS.input_width = FLAGS.input_height
  if FLAGS.output_width is None:
    FLAGS.output_width = FLAGS.output_height
  if not os.path.exists(LOGDIR):
    os.makedirs(LOGDIR)
  if not os.path.exists(OUTPUT_DIR):
    os.makedirs(OUTPUT_DIR)

  run_config = tf.ConfigProto()
  # Grow GPU memory usage on demand instead of reserving it all up front.
  run_config.gpu_options.allow_growth = True
  print("FLAGs " + str(FLAGS.dataset))
  with tf.Session(config=run_config) as sess:
    model = HoloGAN(
        sess,
        input_width=FLAGS.input_width,
        input_height=FLAGS.input_height,
        output_width=FLAGS.output_width,
        output_height=FLAGS.output_height,
        dataset_name=FLAGS.dataset,
        input_fname_pattern=FLAGS.input_fname_pattern,
        crop=FLAGS.crop)

    model.build(cfg['build_func'])

    show_all_variables()

    if FLAGS.train:
      # Resolve the configured training method by name. getattr replaces
      # eval("model." + ...): same result for a method name, but a config string
      # can no longer execute arbitrary expressions.
      train_func = getattr(model, cfg['train_func'])
      train_func(FLAGS)
    else:
      if not model.load(LOGDIR)[0]:
        raise Exception("[!] Train a model first, then run test mode")
      model.sample_HoloGAN(FLAGS)


In [0]:
# Parse the command-line flags registered above into FLAGS and run main().
# No `main` argument is passed, so tf.app.run presumably resolves `main` from the
# __main__ module (the notebook namespace here) -- confirm for your TF version.
# NOTE(review): tf.app.run is expected to end with sys.exit, which a notebook reports as SystemExit.
tf.app.run()