# Style Gan 2

This is an attempt to re-implement the paper Style-GAN 2

Paper: https://arxiv.org/pdf/1912.04958.pdf

Other Resources:
* https://github.com/NVlabs/stylegan2
* https://nn.labml.ai/gan/stylegan

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

In [2]:
class MinibatchStddev(tf.keras.layers.Layer):
    '''
    Mini Batch Standard Deviation
    
    Args:
        epsilon (float): small value to avoid division by 0. | default -> 1e-8
        
    Input: (Feature Maps)
        Input is a single 4-Dimensional Tensor. Each dimension indicating
        (batch_size, height, width, feature_maps)
        
    Output: (Feature Maps)
        Output is a single 4-Dimensional Tensor. Each dimension indicating
        (batch_size, height, width, feature_maps + 1)
        
        extra feature is added which is the mini batch standard deviation (along axis 0)
    
    '''
    def __init__(self, epsilon: float = 1e-8, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon
        
    def call(self, inputs):
        inp_shp = tf.shape(inputs)
        
        mean = tf.reduce_mean(inputs, keepdims = True, axis = 0)
        std = tf.sqrt(tf.math.reduce_mean(tf.square(inputs - mean), keepdims = True, axis = 0) + self.epsilon)
        avg_std = tf.reduce_mean(std, keepdims = True)
        tiled = tf.tile(avg_std, (inp_shp[0], inp_shp[1], inp_shp[2], 1))
        combined = tf.concat([inputs, tiled], axis = -1)
        return combined

In [3]:
class Linear(tf.keras.layers.Layer):
    '''
    Linear Layer
    
    Args:
        neurons (int): Define the number of neurons in the layer here.
        gain (float): Define gain here for weight scaling. | default -> np.sqrt(2)
        
    Inputs: (Linear)
        Input is a single 4-Dimensional Tensor. Each dimension indicating
        (batch_size, input_neurons)
        
    Outputs: (Linear)
        output is a single 4-Dimensional Tensor. Each dimension indicating
        (batch_size, neurons)
        
    '''
    def __init__(self, neurons: int, gain: float = np.sqrt(2), **kwargs):
        super().__init__(**kwargs)
        self.neurons = neurons
        self.gain = gain
        
    def build(self, input_shape: tf.TensorShape):
        inp_neurons = input_shape[-1]
        
        init = tf.keras.initializers.RandomNormal(mean = 0.0, stddev = 1.0)
        self.W = self.add_weight(shape = (inp_neurons, self.neurons), initializer = init, 
                                 trainable = True, name = 'Weight')
        self.B = self.add_weight(shape = (self.neurons, ), initializer = 'zeros', 
                                 trainable = True, name = 'Bias')
        
        self.w_scale = self.gain * tf.math.rsqrt(tf.cast(inp_neurons, tf.float32))
        
    def call(self, inputs):
        return tf.add(tf.matmul(inputs, self.W * self.w_scale), self.B)

In [5]:
class Conv2D(tf.keras.layers.Layer):
    '''
    Convolution 2-Dimensional Layer
    
    Args:
        filters (int): Define the number of output feature_maps/filters here.
        kernel_size (tuple): Define the size of kernel here. | default-> (3, 3)
        strides (tuple): Define the size of strides here. | default -> (1, 1)
        gain (float): Define gain here for weight scaling. | default -> np.sqrt(2)
    
    Inputs: (Feature Maps)
        Input is a single 4-Dimensional Tensor. Each dimension indicating
        (batch_size, height, width, feature_maps)
        
    Outputs: (Feature Maps)
        Output is a single 4-Dimensional Tensor. Each dimension indicating
        (batch_size, height, width, filters)
        
    '''
    def __init__(self, filters: int, kernel_size: tuple = (3, 3), strides: tuple = (1, 1), gain: float = 1e-8, 
                 **kwargs):
        super().__init__(**kwargs)
        self.filters = filters
        self.kernel_size = kernel_size
        self.strides = strides
        self.gain = gain
        self.padding = 'SAME' if (kernel_size[0]-1)//2 else 'VALID'
        
    def build(self, input_shape: tf.TensorShape):
        inp_chn = input_shape[-1]
        
        init = tf.keras.initializers.RandomNormal(mean = 0.0, stddev = 1.0)
        self.W = self.add_weight(shape = (self.kernel_size[0], self.kernel_size[1], inp_chn, self.filters), 
                                 initializer = init, trainable = True, name = 'Weight')
        self.B = self.add_weight(shape = (self.filters, ), initializer = 'zeros', 
                                 trainable = True, name = 'Bias')
        
        fan_in = tf.cast(self.kernel_size[0] * self.kernel_size[1] * inp_chn, tf.float32)
        self.w_scale = self.gain * tf.math.rsqrt(fan_in)
        
    def call(self, inputs):
        return tf.add(tf.nn.conv2d(inputs, self.W * self.w_scale, self.strides, self.padding, 'NHWC'), self.B)

In [6]:
class Conv2DMod(tf.keras.layers.Layer):
    '''
    Convolutional 2-Dimensional Layer with Modulated & Demodulated Weights
    
    Args:
        filters (int): Define the number of output feature_maps/filters here.
        kernel_size (tuple): Define the size of kernel here. | default-> (3, 3)
        strides (tuple): Define the size of strides here. | default -> (1, 1)
        demodulate (bool): Define whether to demodulate the weights. | default -> True
        epsilon (float): small value to avoid division by 0. | default -> 1e-8
        
    Inputs:
        Inputs are :-
            (1) `Feature Maps`: 4-Dimensional Tensor. Each dimension indicating (batch_size, height, width, feature_map)
            (2) `Mapped Output`: 2-Dimensional Tensor. Each dimension indicating (batch_size, latent_dim)
            
    Outputs: (Feature Maps)
        Output is a 4-Dimensional Tensor. Each dimension indicating
        (batch_size, height, width, filters)
        
    '''
    def __init__(self, filters: int, kernel_size: tuple = (3, 3), strides: tuple = (1, 1), demodulate: bool = True, 
                 epsilon: float = 1e-8, **kwargs):
        super().__init__(**kwargs)
        self.filters = filters
        self.kernel_size = kernel_size
        self.strides = strides
        self.padding = 'SAME' if (kernel_size[0]-1)//2 else 'VALID'
        self.demodulate = demodulate
        self.epsilon = epsilon
        
    def build(self, input_shapes):
        inp_chn = input_shapes[0][-1]
        
        init = tf.keras.initializers.RandomNormal(mean = 0.0, stddev = 1.0)
        self.W = self.add_weight(shape = (self.kernel_size[0], self.kernel_size[1], inp_chn, self.filters), 
                                 initializer = init, trainable = True, name = 'Weight')
        self.linear = Linear(neurons = inp_chn)
        
    def call(self, inputs):
        x, m = inputs
                
        s = self.linear(m)[:, tf.newaxis, tf.newaxis, :, tf.newaxis]
        w = tf.expand_dims(self.W, axis = 0)
        w *= s
        
        if self.demodulate:
            w *= tf.math.rsqrt(tf.math.reduce_sum(tf.square(w), keepdims = True, axis = [1, 2, 3]) + self.epsilon)
            
        x = tf.transpose(x, perm = [0, 3, 1, 2])
        x = tf.reshape(x, (1, -1, tf.shape(x)[2], tf.shape(x)[3]))
        
        w = tf.transpose(w, perm = [1, 2, 3, 0, 4])
        w = tf.reshape(w, (tf.shape(w)[0], tf.shape(w)[1], tf.shape(w)[2], -1))
        
        out = tf.nn.conv2d(x, w, self.strides, self.padding, 'NCHW')
        
        out = tf.reshape(out, (-1, self.filters, tf.shape(out)[2], tf.shape(out)[3]))
        out = tf.transpose(out, perm = [0, 2, 3, 1])
        return out

In [7]:
class AddNoise(tf.keras.layers.Layer):
    '''
    Add Noise
    
    Args:
        None
        
    Inputs:
        Inputs are :-
            (1) `Feature Maps`: 4-Dimensional Tensor. Each dimensions indicating (batch_size, height, width, feature_map)
            (2) `Noise`: 4-Dimensional Tensor. Each dimensions indicating (batch_size, height, width, 1)
            
    Output: (Feature Maps)
        Output is a 4-Dimensional Tensor. Each dimensions indicating 
        (batch_size, height, width, feature_map)
    '''
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        pass
    
    def build(self, input_shape):
        inp_chn = input_shape[0][-1]
        
        init = tf.keras.initializers.RandomNormal(mean = 0.0, stddev = 1.0)
        self.W = self.add_weight(shape = (1, 1, 1, inp_chn), initializer = init, 
                                 trainable = True, name = 'Noise_Weight')
        
    def call(self, inputs):
        return tf.add(inputs[0], tf.multiply(self.W, inputs[1]))

In [8]:
class AddBias(tf.keras.layers.Layer):
    '''
    Add Bias
    
    Args:
        None
        
    Inputs: (Feature Maps)
        Input is a 4-Dimensional Tensor. Each dimensions indicating 
        (batch_size, height, width, feature_map)
        
    Outputs: (Feature Maps)
        Output is a 4-Dimensional Tensor. Each dimensions indicating
        (batch_size, height, width, feature_map)
        
    '''
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        pass
    
    def build(self, input_shape):
        inp_chn = input_shape[-1]
        self.B = self.add_weight(shape = (1, 1, 1, inp_chn), initializer = 'zeros', 
                                 trainable = True, name = 'Bias')
        
    def call(self, inputs):
        return tf.add(inputs, self.B)

In [9]:
class ToRGB(tf.keras.layers.Layer):
    '''
    Convert To RGB Image (with 3 channels)
    
    Args:
        None
        
    Inputs:
        Inputs are :-
            (1) `Feature maps`: 4-Dimensional Tensor. Each dimensions indicating (batch_size, height, width, feature_map)
            (2) `Latent Dimension`: 2-Dimensional Tensor. Each dimensions indicating (batch_size, latent_dim)
        
    Outputs:
        Output is a 4-Dimensional Tensor. Each dimension indicating
        (batch_size, height, width, 3)
    '''
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.conv = Conv2DMod(filters = 3, kernel_size = (1, 1), strides = (1, 1), demodulate = False)
        self.act = tf.keras.layers.LeakyReLU(alpha = 0.2)
        self.add_bias = AddBias()
    
    def call(self, inputs):
        return self.act(self.add_bias(self.conv(inputs)))

In [10]:
class FromRGB(tf.keras.layers.Layer):
    def __init__(self, filters: int, **kwargs):
        super().__init__(**kwargs)
        self.conv = Conv2D(filters = filters, kernel_size = (1, 1), strides = (1, 1), gain = 1.0)
        self.act = tf.keras.layers.LeakyReLU(alpha = 0.2)
        
    def call(self, inputs):
        return self.act(self.conv(inputs))

In [11]:
class StyleBlock(tf.keras.layers.Layer):
    '''
    Style Block
    
    Args:
        filters (int): Define the number of output feature_maps/filters here.
        
    Inputs:
        Inputs are :-
            (1) `Feature maps`: 4-Dimensional Tensor. Each dimensions indicating (batch_size, height, width, feature_map)
            (2) `Mapped Output`: 2-Dimensional Tensor. Each dimensions indicating (batch_size, latent_dim)
            (3) `Noise`: 4-Dimensional Tensor. Each dimensions indicating (batch_size, height, width, 1)
            
    Outputs: (Feature Maps)
        Output is a 4-Dimensional Tensor. Each dimension indicating
        (batch_size, height, width, filters)
            
    '''
    def __init__(self, filters: int, **kwargs):
        super().__init__(**kwargs)
        self.conv = Conv2DMod(filters = filters, kernel_size = (3, 3), strides = (1, 1), demodulate = True)
        self.add_noise = AddNoise()
        self.add_bias = AddBias()
        self.act = tf.keras.layers.LeakyReLU(alpha = 0.2)
        
    def call(self, inputs):
        x, m, n = inputs
        return self.act(self.add_bias(self.add_noise([self.conv([x, m]), n])))

In [12]:
class SkipGeneratorBlock(tf.keras.layers.Layer):
    '''
    Skip Generator Block
    
    Args:
        filters (int): Define the number of output feature_maps/filters here.
        
    Inputs:
        Inputs are :-
            (1) `Feature Maps`: 4-Dimensional Tensor. Each dimensions indicating (batch_size, height, width, feature_map)
            (2) `Mapped Output`: 2-Dimensional Tensor. Each dimensions indicating (batch_size, latent_dim)
            (3) `Noise`: 2-Dimensional Tensor. Each dimensions indicating (batch_size, latent_dim)
            (4) `Previous RGB Output`: 4-Dimensional Tensor. Each dimensions indicating (batch_size, height, width, 3)
            
    Outputs:
        Outputs are :-
            (1) `Feature Maps`: 4-Dimensional Tensor. Each dimensions indicating (batch_size*2, height, width*2, filters)
            (2) `RGB`: 4-Dimensional Tensor. Each dimensions indicating (batch_size, height*2, width*2, 3)
            
    '''
    def __init__(self, filters: int, **kwargs):
        super().__init__(**kwargs)
        self.up_sample = tf.keras.layers.UpSampling2D(size = (2, 2), interpolation = 'bilinear')
        self.up_sample_rgb = tf.keras.layers.UpSampling2D(size = (2, 2), interpolation = 'bilinear')
        
        self.style_block_1 = StyleBlock(filters = filters)
        self.style_block_2 = StyleBlock(filters = filters)
        
        self.to_rgb = ToRGB()
    
    def call(self, inputs):
        x, m, n, prev_rgb = inputs
        
        prev_rgb = self.up_sample_rgb(prev_rgb)
        
        out = self.up_sample(x)
        out = self.style_block_1([out, m, n])
        out = self.style_block_2([out, m, n])
        
        rgb = self.to_rgb([out, m])
        next_rgb = tf.add(rgb, prev_rgb)
        
        return out, next_rgb

In [13]:
class SkipDiscriminatorBlock(tf.keras.layers.Layer):
    '''
    Skip Discriminator Block
    
    Args:
        filters (int): Define the number of output feature_maps/filters here.
        
    Inputs: 
        Inputs are :-
            (1) `Feature Maps`: 4-Dimensional Tensor. Each dimensions indicating (batch_size, height, width, feature_map)
            (2) `Feature Maps`: 4-Dimensional Tensor. Each dimensions indicating (batch_size, height, width, feature_maps)
            
    Outputs:
        Outputs are :-
            (1) `Feature Maps`: 4-Dimensional Tensor. Each dimensions indicating (batch_size, height//2, width//2, filters)
            (2) `Feature Maps`: 4-Dimensional Tensor. Each dimensions indicating (batch_size, height//2, width//2, filters)
    '''
    def __init__(self, filters: int, first_block: bool = False, **kwargs):
        super().__init__(**kwargs)
        self.first_block = first_block
        
        self.from_rgb = FromRGB(filters = filters)
        self.conv_1 = Conv2D(filters = filters, kernel_size = (3, 3), strides = (1, 1))
        self.act_1 = tf.keras.layers.LeakyReLU(alpha = 0.2)
        self.conv_2 = Conv2D(filters = filters, kernel_size = (3, 3), strides = (1, 1))
        self.act_2 = tf.keras.layers.LeakyReLU(alpha = 0.2)
        
        if not self.first_block:
            self.down_sample_d = tf.keras.layers.AveragePooling2D()
            self.conv_d = Conv2D(filters = filters, kernel_size = (1, 1), strides = (1, 1))
            self.act_d = tf.keras.layers.LeakyReLU(alpha = 0.2)
        
        self.down_sample_1 = tf.keras.layers.AveragePooling2D()
        self.down_sample_2 = tf.keras.layers.AveragePooling2D()
        
    def call(self, inputs):
        img = inputs

        out_1 = self.down_sample_1(img[0])
        x1 = self.from_rgb(img[0])
        x1 = self.act_1(self.conv_1(x1))
        x1 = self.act_2(self.conv_2(x1))
        x1 = self.down_sample_2(x1)
        if not self.first_block:
            x2 = self.down_sample_d(img[1])
            x2 = self.act_d(self.conv_d(x2))
            out_2 = tf.add(x1, x2)
        else:
            out_2 = x1

        return out_1, out_2

In [14]:
class ResidualGeneratorBlock(tf.keras.layers.Layer):
    '''
    Residual Generator Block
    
    Args:
        filters (int): Define the number of output feature_maps/filters here.
        
    Inputs:
        Inputs are :-
            (1) `Feature Maps`: 4-Dimensional Tensor. Each dimensions indicating (batch_size, height, width, feature_map)
            (2) `Mapped Output`: 2-Dimensional Tensor. Each dimensions indicating (batch_size, latent_dim)
            (3) `Noise`: 2-Dimensional Tensor. Each dimensions indicating (batch_size, latent_dim)
        
    Outputs: (Feature Maps)
        Output is a 4-Dimensional Tensor. Each dimensions indicating (batch_size, height*2, width*2, filters)
        
    '''
    def __init__(self, filters: int, **kwargs):
        super().__init__(**kwargs)
        self.up_sample_1 = tf.keras.layers.UpSampling2D(size = (2, 2), interpolation = 'bilinear')
        self.up_sample_2 = tf.keras.layers.UpSampling2D(size = (2, 2), interpolation = 'bilinear')
        
        self.style_block_1 = StyleBlock(filters = filters)
        self.style_block_2 = StyleBlock(filters = filters)
        
        #self.conv = Conv2DMod(filters = filters, kernel_size = (1, 1), strides = (1, 1), demodulate = False)
        self.conv = Conv2D(filters = filters, kernel_size = (1, 1), strides = (1, 1), gain = 1.0)
        self.act = tf.keras.layers.LeakyReLU(alpha = 0.2)
        
    def call(self, inputs):
        x, m, n = inputs
        
        x1 = self.up_sample_1(x)
        x1 = self.act(self.conv([x1, m]))
        
        x2 = self.up_sample_2(x)
        x2 = self.style_block_1([x2, m, n])
        x2 = self.style_block_2([x2, m, n])
        
        out = tf.add(x1, x2)
        return out

In [15]:
class ResidualDiscriminatorBlock(tf.keras.layers.Layer):
    '''
    Residual Discriminator Block
    
    Args:
        filters (int): Define the number of output feature_maps/filters here.
        
    Inputs: (Feature Maps)
        Input is a 4-Dimensional Tensor. Each dimensions indicating 
        (batch_size, height, width, feature_map)
        
    Outputs: (Featute Maps)
        Output is a 4-Dimensional Tensor. Each dimensions indicating 
        (batch_size, height//2, width//2, feature_map)
        
    '''
    def __init__(self, filters: float, **kwargs):
        super().__init__(**kwargs)
        self.conv_1 = Conv2D(filters = filters, kernel_size = (3, 3), strides = (1, 1), gain = np.sqrt(2))
        self.act_1 = tf.keras.layers.LeakyReLU(alpha = 0.2)
        self.conv_2 = Conv2D(filters = filters, kernel_size = (3, 3), strides = (1, 1), gain = np.sqrt(2))
        self.act_2 = tf.keras.layers.LeakyReLU(alpha = 0.2)
        self.down_sample = tf.keras.layers.AveragePooling2D()
        
        self.down_sample_r = tf.keras.layers.AveragePooling2D()
        self.conv_r = Conv2D(filters = filters, kernel_size = (1, 1), strides = (1, 1), gain = np.sqrt(2))
        self.act_r = tf.keras.layers.LeakyReLU(alpha = 0.2)
        
        self.scale = tf.math.rsqrt(2.0)
        
    def call(self, inputs):
        
        x1 = self.act_1(self.conv_1(inputs))
        x1 = self.act_2(self.conv_2(x1))
        x1 = self.down_sample(x1)
        
        x2 = self.act_r(self.conv_r(self.down_sample_r(inputs)))
        
        out = tf.add(x1, x2) * self.scale
        return out

In [16]:
class Normalize(tf.keras.layers.Layer):
    '''
    Normalize
    
    Args:
        p (int): Define the exponent value in the norm formulation. | default -> 2
        axis (int): Define the dimension to reduce. | default -> 1
        epsilon (float): Define small value to avoid division by zero. | default -> 1e-8
        
    Inputs: (Linear)
        Input is a 2-Dimensional Tensor. Each dimensions indicating
        (batch_size, latent_dim)
        
    Outputs: (Linear)
        Output is a 2-Dimensional Tensor. Each dimensions indicating
        (batch_size, latent_dim)
        
    '''
    def __init__(self, p = 2, axis = 1, epsilon: float = 1e-8, **kwargs):
        super().__init__(**kwargs)
        if p == 1:
            self.norm = lambda x: x * tf.math.rsqrt(tf.reduce_sum(tf.abs(x), keepdims = True, axis = axis) + epsilon)
        elif p == 2:
            self.norm = lambda x: x * tf.math.rsqrt(tf.reduce_sum(tf.square(x), keepdims = True, axis = axis) + epsilon)
        else:
            raise ValueError('`p` value should be 1 or 2.\n\t1 indicates L1 Norm.\n\t2 indicates L2 Norm.')
        
    def call(self, inputs):
        return self.norm(inputs)

In [17]:
class LinearMapping(tf.keras.layers.Layer):
    '''
    Linear Mapping
    
    Args:
        latent_dim (int): Define the latent dimension. | default -> 512
        
    Inputs: (Linear)
        Input is a 2-Dimensional Tensor. Each dimensions indicating
        (batch_size, latent_dim)
        
    Outputs: (Linear)
        Output is a 2-Dimensional Tensor. Each dimensions indicating
        (batch_size, latent_dim)
        
    '''
    def __init__(self, latent_dim: int = 512, **kwargs):
        super().__init__(**kwargs)
        self.norm = Normalize(p = 2, axis = 1)
        
        self.linear_1 = Linear(neurons = latent_dim)
        self.act_1 = tf.keras.layers.LeakyReLU(alpha = 0.2)
        self.linear_2 = Linear(neurons = latent_dim)
        self.act_2 = tf.keras.layers.LeakyReLU(alpha = 0.2)
        self.linear_3 = Linear(neurons = latent_dim)
        self.act_3 = tf.keras.layers.LeakyReLU(alpha = 0.2)
        self.linear_4 = Linear(neurons = latent_dim)
        self.act_4 = tf.keras.layers.LeakyReLU(alpha = 0.2)
        self.linear_5 = Linear(neurons = latent_dim)
        self.act_5 = tf.keras.layers.LeakyReLU(alpha = 0.2)
        self.linear_6 = Linear(neurons = latent_dim)
        self.act_6 = tf.keras.layers.LeakyReLU(alpha = 0.2)
        self.linear_7 = Linear(neurons = latent_dim)
        self.act_7 = tf.keras.layers.LeakyReLU(alpha = 0.2)
        self.linear_8 = Linear(neurons = latent_dim)
        self.act_8 = tf.keras.layers.LeakyReLU(alpha = 0.2)
        
    def call(self, inputs):
        x = self.norm(inputs)
        
        x = self.act_1(self.linear_1(x))
        x = self.act_2(self.linear_2(x))
        x = self.act_3(self.linear_3(x))
        x = self.act_4(self.linear_4(x))
        x = self.act_5(self.linear_5(x))
        x = self.act_6(self.linear_6(x))
        x = self.act_7(self.linear_7(x))
        x = self.act_8(self.linear_8(x))
        
        return x

In [18]:
class GAN(tf.keras.models.Model):
    def __init__(self, latent_dim: int = 512, gen_res: int = 1024, d_steps: int = 1, drift_weight: float = 0.001, 
                 seperate_latent: bool = True, gp_weight: float = 10.0, **kwargs):
        super().__init__(**kwargs)
        self.latent_dim = latent_dim
        self.gen_res = gen_res
        self.num_block = int(np.log2(gen_res) - 1)
        self.seperate_latent = seperate_latent
        self.gp_weight = gp_weight
        self.d_steps = d_steps
        self.drift_weight = drift_weight
        self.FILTERS = [512, 512, 512, 512, 256, 128, 64, 32, 16]
        
    def call(self, inputs):
        return
        
    def __init_generator(self):
        
        if self.seperate_latent:
            latent_inp = [
                tf.keras.layers.Input(shape = (self.latent_dim), dtype = tf.float32, name = f'latent_inp_{i}')
                for i in range(self.num_block)
            ]
        else:
            latent_inp = tf.keras.layers.Input(shape = (self.latent_dim), dtype = tf.float32, name = 'latent_inp')
            
        const_inp = tf.keras.layers.Input(shape = (4, 4, 512), dtype = tf.float32, name = 'constant_input_4x4x512')
        
        noise_inp = [
            tf.keras.layers.Input(shape = (2**(i+2), 2**(i+2), 1), dtype = tf.float32, 
                                  name = f'noise_input_{2**(i+2)}x{2**(i+2)}') for i in range(self.num_block)
        ]
        
        
        linear_mapping = LinearMapping(self.latent_dim)
        if self.seperate_latent:
            mapped = []
            for lp in latent_inp:
                mapped.append(linear_mapping(lp))
        else:
            mapped = linear_mapping(latent_inp)
            
            
        if self.gen_type == 'skip':
            
            if self.seperate_latent:
                x = StyleBlock(filters = self.FILTERS[0])([const_inp, mapped[0], noise_inp[0]])
                rgb = ToRGB()([x, mapped[0]])
                
                for i in range(1, self.num_block):
                    x, rgb = SkipGeneratorBlock(filters = self.FILTERS[i])([x, mapped[i], noise_inp[i], rgb])
                    
                return tf.keras.models.Model(latent_inp + [const_inp] + noise_inp, rgb, 
                                             name = f'skip_generator_{self.gen_res}x{self.gen_res}')
            
            else:
                x = StyleBlock(filters = self.FILTERS[0])([const_inp, mapped, noise_inp[0]])
                rgb = ToRGB()([x, mapped])
                
                for i in range(1, self.num_block):
                    x, rgb = SkipGeneratorBlock(filters = self.FILTERS[i])([x, mapped, noise_inp[i], rgb])
                    
                return tf.keras.models.Model([latent_inp, const_inp] + noise_inp, rgb, 
                                             name = f'skip_generator_{self.gen_res}x{self.gen_res}')
            
        elif self.gen_type == 'residual':
            
            if self.seperate_latent:
                x = StyleBlock(filters = self.FILTERS[0])([const_inp, mapped[0], noise_inp[0]])
                
                for i in range(1, self.num_block):
                    x = ResidualGeneratorBlock(filters = self.FILTERS[i])([x, mapped[i], noise_inp[i]])
                
                out = ToRGB()([x, mapped[i]])
                return tf.keras.models.Model(latent_inp + [const_inp] + noise_inp, out, 
                                             name = f'residual_generator_{self.gen_res}x{self.gen_res}')
            
            else:
                x = StyleBlock(filters = self.FILTERS[0])([const_inp, mapped, noise_inp[0]])
                
                for i in range(1, self.num_block):
                    x = ResidualGeneratorBlock(filters = self.FILTERS[i])([x, mapped, noise_inp[i]])
                
                out = ToRGB()([x, mapped])
                return tf.keras.models.Model([latent_inp, const_inp] + noise_inp, out, 
                                             name = f'residual_generator_{self.gen_res}x{self.gen_res}')
            
    def __init_discriminator(self):
        
        inp = tf.keras.layers.Input(shape = (self.gen_res, self.gen_res, 3), dtype = tf.float32, 
                                    name = f'discriminator_input_{self.gen_res}x{self.gen_res}x3')
        
        
        if self.disc_type == 'residual':
            x = FromRGB(filters = self.FILTERS[-1])(inp)
            
            for i in range(self.num_block-1, 0, -1):
                x = ResidualDiscriminatorBlock(filters = self.FILTERS[i])(x)
                
            x = MinibatchStddev()(x)
            x = Conv2D(filters = self.FILTERS[0], kernel_size = (3, 3), strides = (2, 2))(x)
            x = tf.keras.layers.LeakyReLU(alpha = 0.2)(x)
            
            x = tf.keras.layers.Flatten()(x)
            x = Linear(neurons = 1, gain = 1.0)(x)
            return tf.keras.models.Model(inp, x, name = f'residual_discriminator_{self.gen_res}x{self.gen_res}')
        
        elif self.disc_type == 'skip':
            out_1, out_2 = SkipDiscriminatorBlock(filters = self.FILTERS[-2], first_block = True)([inp])
            for i in range(self.num_block-2, 0, -1):
                out_1, out_2 = SkipDiscriminatorBlock(filters = self.FILTERS[i])([out_1, out_2])
                
            out = MinibatchStddev()(out_2)
            out = Conv2D(filters = self.FILTERS[0], kernel_size = (3, 3), strides = (2, 2))(out)
            out = tf.keras.layers.LeakyReLU(alpha = 0.2)(out)
            
            out = tf.keras.layers.Flatten()(out)
            out = Linear(neurons = 1, gain = 1.0)(out)
            
            return tf.keras.models.Model(inp, out, name = f'skip_discriminator_{self.gen_res}x{self.gen_res}')
            
    def compile(self, optimizer, gen_type: str = 'skip', disc_type: str = 'residual'):
        super().compile()
        self.optimizer = optimizer
        self.gen_type = gen_type.lower()
        self.disc_type = disc_type.lower()
        
        self.generator = self.__init_generator()
        self.discriminator = self.__init_discriminator()
        
        self.step = 0
        
    def generator_loss(self, disc_gen_out):
        return -tf.reduce_mean(disc_gen_out)
    
    def discriminator_loss(self, disc_real_out, disc_gen_out):
        return tf.reduce_mean(disc_gen_out) - tf.reduce_mean(disc_real_out)
    
    def gradient_penalty(self, real_img, gen_img):
        
        batch_size = tf.shape(real_img)[0]
        epsilon = tf.random.uniform((batch_size, 1, 1, 1), minval = 0.0, maxval = 1.0)
        interpolated_img = ((1 - epsilon) * real_img) + (epsilon * gen_img)
        
        with tf.GradientTape() as gp_tape:
            gp_tape.watch(interpolated_img)
            out = self.discriminator(interpolated_img)
        
        grads = gp_tape.gradient(out, [interpolated_img])[0]
        norm = tf.square(tf.reduce_mean(tf.square(grads), keepdims = True, axis = [1, 2, 3]))
        gp = tf.reduce_mean(tf.square(norm - 1)) * self.gp_weight
        return gp
    
    def drift_loss(self, disc_real_out):
        return tf.reduce_mean(tf.square(disc_real_out)) * self.drift_weight
    

#     def path_length_penalty(self, gen_out, w, beta = 0.995):
#         y = tf.random.normal(tf.shape(gen_out))
        
#         img_size = tf.cast(gen_out.shape[1] * gen_out.shape[2], tf.float32)
#         out = tf.math.reduce_sum(gen_out * y) * tf.math.rsqrt(img_size)
        
#         grads = tf.gradients(out, w)
#         norm = tf.sqrt(tf.math.reduce_mean(tf.math.reduce_sum(tf.square(grads), axis = 2), axis = 1))
        
#         pl_mean = tf.reduce_mean(norm)
        
        
#         return tf.square(out - pl_mean)
        
    
    def generate_noise(self, batch_size):
        return [tf.random.normal((batch_size, 2**(i+2), 2**(i+2), 1)) for i in range(self.num_block)]
    
    def generate_latent(self, batch_size):
            return [tf.random.normal((batch_size, self.latent_dim)) 
                    for _ in range(self.num_block)] if self.seperate_latent else tf.random.normal((batch_size, 
                                                                                                         self.latent_dim))

        
    def train_step(self, real_img):
        if isinstance(real_img, tuple):
            real_img = real_img[0]
        batch_size = tf.shape(real_img)[0]
            
        for _ in range(self.d_steps):
            
            const_inp = tf.ones((batch_size, 4, 4, 512))
            latent_inp = self.generate_latent(batch_size)
            noise_inp = self.generate_noise(batch_size)
            gen_inp = latent_inp + [const_inp] + noise_inp if self.seperate_latent else [latent_inp, const_inp] + noise_inp
            with tf.GradientTape() as disc_tape:
                gen_out = self.generator(gen_inp, training = True)
                
                disc_real_out = self.discriminator(real_img, training = True)
                disc_gen_out = self.discriminator(gen_out, training = True)
                
                gp_loss = self.gradient_penalty(real_img, gen_out)
                drf_loss = self.drift_loss(disc_real_out)
                _disc_loss = self.discriminator_loss(disc_real_out, disc_gen_out)
                # pl = self.path_length_penalty(gen_out, latent_inp)
                disc_loss = _disc_loss + gp_loss + drf_loss # + pl
                
            disc_grads = disc_tape.gradient(disc_loss, self.discriminator.trainable_variables)
            self.optimizer.apply_gradients(zip(disc_grads, self.discriminator.trainable_variables))
            
        
        const_inp = tf.ones((batch_size, 4, 4, 512))
        latent_inp = self.generate_latent(batch_size)
        noise_inp = self.generate_noise(batch_size)
        gen_inp = latent_inp + [const_inp] + noise_inp if self.seperate_latent else [latent_inp, const_inp] + noise_inp
        with tf.GradientTape() as gen_tape:
            gen_out = self.generator(gen_inp, training = True)
            disc_gen_out = self.discriminator(gen_out, training = True)
            gen_loss = self.generator_loss(disc_gen_out)
            
        gen_grads = gen_tape.gradient(gen_loss, self.generator.trainable_variables)
        self.optimizer.apply_gradients(zip(gen_grads, self.generator.trainable_variables))
        
        return {'disc_loss': disc_loss, 'gp_loss': gp_loss, 'drf_loss': drf_loss, '_disc_loss': _disc_loss, 'gen_loss': gen_loss}

In [19]:
path = r'E:\Image Datasets\Celeb A\Dataset\img_align_celeba'
image_size = 256
BATCH_SIZE = 1
EPOCHS = 10
STEPS_PER_EPOCH = 1000

train_data_generator = tf.keras.preprocessing.image.ImageDataGenerator(
    preprocessing_function = lambda x: tf.cast((x/127.5)-1, tf.float32))

train_data = train_data_generator.flow_from_directory(directory  = path, target_size = (image_size, image_size),
                                                 shuffle = True, batch_size = BATCH_SIZE, class_mode = 'binary')

Found 202599 images belonging to 1 classes.


In [20]:
gan = GAN(seperate_latent = True, gen_res = image_size)
gan.compile(
    optimizer = tf.keras.optimizers.Adam(),
    gen_type = 'skip', disc_type = 'residual'
)

In [21]:
gan.fit(train_data, epochs = EPOCHS, steps_per_epoch = STEPS_PER_EPOCH, batch_size = BATCH_SIZE)

Epoch 1/10
  45/1000 [>.............................] - ETA: 3:36 - disc_loss: 10.0000 - gp_loss: 10.0000 - drf_loss: 3.2238e-11 - _disc_loss: -1.4675e-07 - gen_loss: 2.1068e-05

KeyboardInterrupt: 

In [None]:
# plt.imshow((((gan.generator([tf.random.normal((1, 512)) for _ in range(7)] + [tf.ones((1, 4, 4, 512))] + 
#               [tf.random.normal((1, 4, 4, 1)), tf.random.normal((1, 8, 8, 1)), tf.random.normal((1, 16, 16, 1)), 
#                tf.random.normal((1, 32, 32, 1)), tf.random.normal((1, 64, 64, 1)), tf.random.normal((1, 128, 128, 1)), 
#                tf.random.normal((1, 256, 256, 1))])[0])+1)*127.5).numpy().astype('uint8'))

In [None]:
# gan = GAN(seperate_latent = True)
# gan.compile(optimizer=None, gen_type = 'skip')
# gan.compile(optimizer=None, gen_type = 'residual')

In [None]:
# gan.generator().summary()
# gan.discriminator().summary(line_length = 200)

In [None]:
# tf.keras.utils.plot_model(gan.generator(), show_shapes = True, dpi = 64)
# tf.keras.utils.plot_model(gan.discriminator(), show_shapes = True, dpi = 64)

In [None]:
# tf.keras.utils.plot_model(gan.generator(), show_shapes = True, dpi = 64)