# 3D GAN

This notebook is an attempt to re-implement the 3D-GAN paper.

Paper: http://3dgan.csail.mit.edu/papers/3dgan_nips.pdf

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

In [2]:
class Linear(tf.keras.layers.Layer):
    """Fully connected layer computing ``matmul(inputs, W) + B``, no activation.

    Args:
        neurons: Number of output units (size of the last output axis).
        **kwargs: Forwarded unchanged to ``tf.keras.layers.Layer``.
    """

    def __init__(self, neurons, **kwargs):
        super().__init__(**kwargs)
        self.neurons = neurons

    def build(self, input_shape):
        """Create the kernel and bias once the input feature size is known."""
        fan_in = input_shape[-1]

        # Kernel entries are drawn from a normal distribution (mean 0, std 1).
        self.W = self.add_weight(
            name='weight',
            shape=(fan_in, self.neurons),
            initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=1.0),
            trainable=True,
        )
        # Bias starts at zero; its leading axis of size 1 broadcasts over rows.
        self.B = self.add_weight(
            name='bias',
            shape=(1, self.neurons),
            initializer='zeros',
            trainable=True,
        )

    def call(self, inputs):
        """Return the affine projection of ``inputs``."""
        projected = tf.matmul(inputs, self.W)
        return tf.add(projected, self.B)

In [3]:
class Conv3D(tf.keras.layers.Layer):
    """3-D convolution over NDHWC inputs with reflect padding.

    The input is reflect-padded by ``(k - 1) // 2`` voxels on both sides of
    each spatial axis (``k`` being the kernel extent on that axis) and then
    convolved with ``padding='VALID'``.

    Args:
        filters: Number of output channels.
        kernel_size: Sequence ``(depth, height, width)`` of kernel extents.
        strides: Strides handed to ``tf.nn.conv3d``. A 5-element sequence is
            passed through as-is; a single int or a 3-element
            ``(depth, height, width)`` sequence is expanded to
            ``[1, d, h, w, 1]``.
        **kwargs: Forwarded unchanged to ``tf.keras.layers.Layer``.
    """

    def __init__(self, filters, kernel_size, strides, **kwargs):
        super().__init__(**kwargs)
        self.filters = filters
        self.kernel_size = kernel_size
        self.strides = strides

    def build(self, input_shape):
        """Create the kernel and bias once the input channel count is known."""
        inp_filters = input_shape[-1]

        init = tf.keras.initializers.RandomNormal(mean=0.0, stddev=1.0)
        # tuple(...) lets kernel_size be supplied as a list as well as a tuple.
        kernel_shape = tuple(self.kernel_size) + (inp_filters, self.filters)
        self.W = self.add_weight(shape=kernel_shape, initializer=init,
                                 trainable=True, name='weight')
        self.B = self.add_weight(shape=(self.filters,), initializer='zeros',
                                 trainable=True, name='bias')

    @property
    def padding(self):
        """Per-axis ``[before, after]`` pad amounts for an NDHWC tensor."""
        # Each spatial axis uses its OWN kernel extent (depth, height, width).
        d, h, w = [(self.kernel_size[axis] - 1) // 2 for axis in range(3)]
        # Batch and channel axes are never padded.
        return [[0, 0], [d, d], [h, h], [w, w], [0, 0]]

    @property
    def _conv_strides(self):
        """Return ``self.strides`` in the 5-element form ``tf.nn.conv3d`` needs."""
        strides = self.strides
        if isinstance(strides, int):
            strides = [strides] * 3
        strides = list(strides)
        if len(strides) == 3:
            # Spatial-only strides: batch and channel strides are always 1.
            strides = [1] + strides + [1]
        return strides

    def call(self, inputs):
        """Reflect-pad ``inputs``, convolve with the kernel and add the bias."""
        x = tf.pad(inputs, self.padding, mode='REFLECT')
        conv = tf.nn.conv3d(x, self.W, self._conv_strides,
                            padding='VALID', data_format='NDHWC')
        return tf.add(conv, self.B)

In [23]:
class GAN(tf.keras.models.Model):
    """3D-GAN: a voxel generator and a discriminator trained adversarially.

    The generator maps a latent vector of size ``latent_dim`` to a
    ``(64, 64, 64, 1)`` volume with values in ``[0, 1]`` (sigmoid output).
    The discriminator takes a volume of the generator's output shape and
    returns a single unnormalised logit per sample.

    Args:
        latent_dim: Size of the latent noise vector fed to the generator.
        **kwargs: Forwarded unchanged to ``tf.keras.models.Model``.
    """

    def __init__(self, latent_dim = 200, **kwargs):
        super().__init__(**kwargs)
        self.latent_dim = latent_dim
        self.generator = self._build_generator()
        # The discriminator's input shape is taken from the generator output.
        self.gen_shp = self.generator.output_shape
        self.discriminator = self._build_discriminator()

    def call(self, inputs):
        """Generate volumes from latent vectors ``inputs``."""
        return self.generator(inputs)

    def _build_generator(self):
        """Build the generator: latent vector -> (64, 64, 64, 1) volume."""
        inp = tf.keras.layers.Input(shape=(self.latent_dim,), dtype=tf.float32,
                                    name='generator_input')

        # Project the latent vector and reshape it into a 4x4x4 seed volume.
        x = Linear(neurons=512 * 4 * 4 * 4)(inp)
        x = tf.keras.layers.Reshape((4, 4, 4, 512))(x)

        # One stride-1 block, then three stride-2 blocks: 4 -> 4 -> 8 -> 16 -> 32.
        for filters, stride in ((512, 1), (256, 2), (128, 2), (64, 2)):
            x = tf.keras.layers.Conv3DTranspose(filters=filters,
                                                kernel_size=(4, 4, 4),
                                                strides=(stride, stride, stride),
                                                padding='same')(x)
            x = tf.keras.layers.BatchNormalization()(x)
            x = tf.keras.layers.ReLU()(x)

        # Final upsampling to 64^3 with one channel, squashed to [0, 1].
        x = tf.keras.layers.Conv3DTranspose(filters=1, kernel_size=(4, 4, 4),
                                            strides=(2, 2, 2),
                                            padding='same')(x)
        x = tf.keras.layers.BatchNormalization()(x)
        x = tf.keras.layers.Activation('sigmoid')(x)

        return tf.keras.models.Model(inp, x, name='generator')

    def _build_discriminator(self):
        """Build the discriminator: volume -> one real/fake logit."""
        inp = tf.keras.layers.Input(shape=self.gen_shp[1:], dtype=tf.float32,
                                    name='discriminator_input')

        # Four stride-2 blocks halve each spatial axis at every step.
        x = inp
        for filters in (64, 128, 256, 512):
            x = tf.keras.layers.Conv3D(filters=filters, kernel_size=(4, 4, 4),
                                       strides=2, padding='same')(x)
            x = tf.keras.layers.BatchNormalization()(x)
            x = tf.keras.layers.LeakyReLU(alpha=0.2)(x)

        x = tf.keras.layers.Conv3D(filters=1, kernel_size=(4, 4, 4),
                                   strides=1, padding='same')(x)
        x = tf.keras.layers.BatchNormalization()(x)
        x = tf.keras.layers.Activation('relu')(x)

        x = tf.keras.layers.Flatten()(x)
        # No sigmoid here: the loss in `compile` is built with
        # from_logits=True, so the model must output raw logits.
        x = Linear(neurons=1)(x)

        return tf.keras.models.Model(inp, x, name='discriminator')

    def compile(self, gen_learning_rate=0.0025, disc_learning_rate=1e-5,
                beta_1=0.5):
        """Attach one Adam optimizer per sub-network and the adversarial loss.

        The defaults follow the 3D-GAN paper (generator 0.0025,
        discriminator 1e-5, beta 0.5).

        Args:
            gen_learning_rate: Adam learning rate for the generator.
            disc_learning_rate: Adam learning rate for the discriminator.
            beta_1: Adam first-moment decay used by both optimizers.
        """
        super().compile()
        self.gen_optimizer = tf.keras.optimizers.Adam(
            learning_rate=gen_learning_rate, beta_1=beta_1)
        self.disc_optimizer = tf.keras.optimizers.Adam(
            learning_rate=disc_learning_rate, beta_1=beta_1)
        self.gan_loss = tf.keras.losses.BinaryCrossentropy(from_logits=True)

    def latent_noise(self, batch_size):
        """Sample ``batch_size`` latent vectors from a standard normal."""
        return tf.random.normal((batch_size, self.latent_dim))

    def train_step(self, img_3d):
        """Run one simultaneous generator/discriminator update.

        Args:
            img_3d: Batch of real volumes, or a tuple whose first element is
                that batch. Must match the discriminator's input shape.

        Returns:
            Dict with the scalar losses under ``'gen_loss'`` and
            ``'disc_loss'``.
        """
        if isinstance(img_3d, tuple):
            img_3d = img_3d[0]
        batch_size = tf.shape(img_3d)[0]

        # Two tapes record the same forward pass so each network's gradients
        # can be taken independently afterwards.
        with tf.GradientTape() as disc_tape, tf.GradientTape() as gen_tape:
            gen_3d = self.generator(self.latent_noise(batch_size), training=True)

            disc_real_out = self.discriminator(img_3d, training=True)
            disc_gen_out = self.discriminator(gen_3d, training=True)

            # Discriminator: real -> 1, generated -> 0 (mean of both terms).
            real_loss = self.gan_loss(tf.ones_like(disc_real_out), disc_real_out)
            fake_loss = self.gan_loss(tf.zeros_like(disc_gen_out), disc_gen_out)
            disc_loss = (real_loss + fake_loss) * 0.5
            # Generator: wants its samples to be classified as real.
            gen_loss = self.gan_loss(tf.ones_like(disc_gen_out), disc_gen_out)

        gen_vars = self.generator.trainable_variables
        gen_grads = gen_tape.gradient(gen_loss, gen_vars)
        self.gen_optimizer.apply_gradients(zip(gen_grads, gen_vars))

        disc_vars = self.discriminator.trainable_variables
        disc_grads = disc_tape.gradient(disc_loss, disc_vars)
        self.disc_optimizer.apply_gradients(zip(disc_grads, disc_vars))

        return {'gen_loss': gen_loss, 'disc_loss': disc_loss}

    def train(self, data, epochs = 1):
        """Manual training loop: call ``train_step`` on every batch of ``data``.

        Prints a dot per batch and the last batch's losses after each epoch.

        Args:
            data: Iterable of batches accepted by ``train_step``; it is
                iterated once per epoch.
            epochs: Number of passes over ``data``.
        """
        for e in range(epochs):
            print(f'Epoch: {e} Starts.')
            # Stays None if `data` yields no batches, so the report below
            # does not fail on an unbound name.
            loss = None
            for img3d in data:
                loss = self.train_step(img3d)
                print('.', end='')

            print(f'\nLoss: {loss}')
            print(f'Epoch: {e} Ends.\n')
            

In [24]:
# Build the generator/discriminator pair with the default latent size (200).
g = GAN()
# Attach the two Adam optimizers and the loss defined in GAN.compile.
g.compile()

In [25]:
# Placeholder cell: `inp` is not defined anywhere in the visible notebook; the
# string literal below only marks where the training data should be assigned.
# NOTE(review): presumably `inp` holds voxel grids matching the discriminator
# input shape — confirm against the data-loading code.
'''inp = `your input`'''
# Keras `fit` is expected to drive the overridden GAN.train_step per batch.
g.fit(inp)



<tensorflow.python.keras.callbacks.History at 0x1f3ac75d5b0>