# ADGAN (Attribute-Decomposed GAN)

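This notebook is an attempt to re-implement ADGAN, from the paper "Controllable Person Image Synthesis with Attribute-Decomposed GAN" (Men et al., CVPR 2020), in TensorFlow 2 / Keras.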

Paper: https://arxiv.org/pdf/2003.12267.pdf

Other resources:
* https://github.com/menyifang/ADGAN (official PyTorch implementation)
* https://github.com/roimehrez/contextualLoss/blob/master/CX/CSFlow.py (reference implementation of the contextual loss)

In [1]:
import tensorflow as tf

In [2]:
class ReflectPadding2D(tf.keras.layers.Layer):
    """Reflection-pad the two spatial axes of an NHWC tensor.

    ``padding`` may be:
      * an int ``p``: ``p`` rows top/bottom and ``p`` columns left/right;
      * a pair of ints ``(ph, pw)``: symmetric padding per spatial axis;
      * a pair of pairs ``((top, bottom), (left, right))``.

    Raises:
        TypeError: if ``padding`` is not one of the forms above.
    """
    def __init__(self, padding, **kwargs):
        super().__init__(**kwargs)
        self.padding = self._normalize_padding(padding)

    @staticmethod
    def _normalize_padding(padding):
        # Convert every accepted form to ((top, bottom), (left, right)).
        if isinstance(padding, int):
            return (padding, padding), (padding, padding)
        if isinstance(padding, (tuple, list)) and len(padding) == 2:
            if isinstance(padding[0], (tuple, list)):
                return (padding[0][0], padding[0][1]), (padding[1][0], padding[1][1])
            if isinstance(padding[0], int):
                return (padding[0], padding[0]), (padding[1], padding[1])
        raise TypeError('padding must be an int, a pair of ints or a pair of int pairs, '
                        f'got {padding!r}')

    def call(self, inputs):
        # Batch and channel axes are never padded.
        return tf.pad(inputs, ((0, 0), self.padding[0], self.padding[1], (0, 0)), mode = 'REFLECT')

In [3]:
class Linear(tf.keras.layers.Layer):
    """Fully connected layer computing ``inputs @ weight + bias``.

    The weight is drawn from N(0, 0.02) and the bias starts at zero. The bias
    has shape ``(1, neurons)`` so it broadcasts over the leading batch axis.
    """
    def __init__(self, neurons, **kwargs):
        super().__init__(**kwargs)
        self.neurons = neurons

    def build(self, input_shape):
        weight_init = tf.keras.initializers.RandomNormal(mean = 0.0, stddev = 0.02)
        fan_in = input_shape[-1]
        self.weight = self.add_weight(name = 'weight', shape = (fan_in, self.neurons),
                                      initializer = weight_init, trainable = True)
        self.bias = self.add_weight(name = 'bias', shape = (1, self.neurons),
                                    initializer = 'zeros', trainable = True)

    def call(self, inputs):
        projected = tf.matmul(inputs, self.weight)
        return projected + self.bias

In [4]:
class Conv2D(tf.keras.layers.Layer):
    """2-D convolution on NHWC tensors with N(0, 0.02) kernel init and zero bias.

    Args:
        filters: number of output channels.
        kernel_size: tuple ``(kh, kw)``. It is concatenated with the channel
            dimensions to build the kernel shape, so it must be a tuple.
        strides: forwarded to ``tf.nn.conv2d``.
        padding: either a string such as ``'same'``/``'valid'`` (upper-cased
            and passed to ``tf.nn.conv2d``), or any value accepted by
            ``ReflectPadding2D``. In the latter case the input is
            reflect-padded first and then convolved with ``'VALID'``.
    """
    def __init__(self, filters, kernel_size, strides, padding, **kwargs):
        super().__init__(**kwargs)
        self.filters = filters
        self.kernel_size = kernel_size
        self.strides = strides
        self.padding = padding.upper() if isinstance(padding, str) else ReflectPadding2D(padding)

    def build(self, input_shape):
        kernel_init = tf.keras.initializers.RandomNormal(mean = 0.0, stddev = 0.02)
        kernel_shape = self.kernel_size + (input_shape[-1], self.filters)
        self.weight = self.add_weight(name = 'weight', shape = kernel_shape,
                                      initializer = kernel_init, trainable = True)
        self.bias = self.add_weight(name = 'bias', shape = (1, 1, 1, self.filters),
                                    initializer = 'zeros', trainable = True)

    def call(self, inputs):
        if isinstance(self.padding, str):
            padded, mode = inputs, self.padding
        else:
            padded, mode = self.padding(inputs), 'VALID'
        return tf.nn.conv2d(padded, self.weight, self.strides, mode) + self.bias

In [5]:
class AdaIn(tf.keras.layers.Layer):
    """Adaptive instance normalization modulated by a style code.

    Called on ``[x, s]``:
      * ``x`` is an NHWC feature map. It is instance-normalized, i.e. per
        sample and per channel over the H and W axes.
      * ``s`` is a style code (presumably ``(batch, features)``; confirm
        against callers). Two ``Linear`` layers map it to a per-channel
        scale ``ys`` and shift ``yb``.

    The output is ``ys * normalized_x + yb``.
    """
    def __init__(self, epsilon = 1e-8, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon

    def build(self, input_shape):
        # input_shape is [shape_of_x, shape_of_s]; one scale/shift per channel of x.
        self.inp_filters = input_shape[0][-1]

        self.linear_1 = Linear(neurons = self.inp_filters)
        self.linear_2 = Linear(neurons = self.inp_filters)

    def call(self, inputs):
        x, s = inputs

        # Mean and variance must be reduced over the same (spatial) axes.
        # A variance over the whole tensor would mix statistics across
        # samples and channels.
        mean_x = tf.math.reduce_mean(x, axis = [1, 2], keepdims = True)
        var_x = tf.math.reduce_variance(x, axis = [1, 2], keepdims = True)
        inst_norm = (x - mean_x) * tf.math.rsqrt(var_x + self.epsilon)

        ys = tf.reshape(self.linear_1(s), (-1, 1, 1, self.inp_filters))
        yb = tf.reshape(self.linear_2(s), (-1, 1, 1, self.inp_filters))

        return ys * inst_norm + yb

In [6]:
class InstanceNormalization(tf.keras.layers.Layer):
    """Instance normalization for NHWC tensors.

    Each sample/channel is normalized with its own mean and variance over the
    spatial axes (H, W).

    Args:
        epsilon: added to the variance before the reciprocal square root.
        affine: when True, learn a per-channel scale ``gamma`` (initialized
            around 1) and shift ``beta`` (initialized to 0). When False (the
            default) no weights are created, so no trainable variables are
            left without gradients.
    """
    def __init__(self, epsilon = 1e-8, affine = False, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon
        self.affine = affine

    def build(self, input_shape):
        if not self.affine:
            return
        inp_chn = input_shape[-1]

        # Scale starts near 1 so the layer begins as a plain normalization.
        init = tf.keras.initializers.RandomNormal(mean = 1.0, stddev = 0.02)
        self.gamma = self.add_weight(shape = (1, 1, 1, inp_chn), initializer = init,
                                     trainable = True, name = 'gamma')
        self.beta = self.add_weight(shape = (1, 1, 1, inp_chn), initializer = 'zeros',
                                    trainable = True, name = 'beta')

    def call(self, inputs):
        # Per-instance, per-channel statistics: reduce mean AND variance over H, W.
        mean = tf.math.reduce_mean(inputs, axis = [1, 2], keepdims = True)
        var = tf.math.reduce_variance(inputs, axis = [1, 2], keepdims = True)
        inst_norm = (inputs - mean) * tf.math.rsqrt(var + self.epsilon)
        if self.affine:
            inst_norm = self.gamma * inst_norm + self.beta
        return inst_norm

In [7]:
class Normalization(tf.keras.layers.Layer):
    """Pick a normalization layer by name.

    Recognized names are 'inst_norm', 'adain', 'batch_norm' and 'layer_norm'.
    Any other value is stored unchanged and called directly. With 'adain'
    the layer must be called on a two-element list/tuple ``[x, s]``.
    """
    def __init__(self, norm, **kwargs):
        super().__init__(**kwargs)
        self.norm = norm
        factories = {
            'inst_norm': InstanceNormalization,
            'adain': AdaIn,
            'batch_norm': tf.keras.layers.BatchNormalization,
            'layer_norm': tf.keras.layers.LayerNormalization,
        }
        factory = factories.get(norm) if isinstance(norm, str) else None
        self.normalize = norm if factory is None else factory()

    def call(self, inputs):
        if self.norm == 'adain':
            # AdaIN needs both the feature map and the style code.
            assert isinstance(inputs, (list, tuple))
            assert len(inputs) == 2
        return self.normalize(inputs)
    
class Activation(tf.keras.layers.Layer):
    """Pick an activation by name: 'relu', 'leaky_relu' (slope 0.2) or 'tanh'.

    Any other value is stored unchanged and called directly.
    """
    def __init__(self, activation, **kwargs):
        super().__init__(**kwargs)
        builders = {
            'relu': lambda: tf.keras.layers.ReLU(),
            'leaky_relu': lambda: tf.keras.layers.LeakyReLU(alpha = 0.2),
            'tanh': lambda: tf.keras.layers.Activation('tanh'),
        }
        builder = builders.get(activation) if isinstance(activation, str) else None
        self.act = activation if builder is None else builder()

    def call(self, inputs):
        return self.act(inputs)

In [8]:
class LinearBlock(tf.keras.layers.Layer):
    """``Linear`` followed by an optional normalization and an optional activation.

    Args:
        neurons: output width of the linear layer.
        norm: value for ``Normalization``, or None to skip normalization.
        activation: value for ``Activation``, or None to skip the activation.
    """
    def __init__(self, neurons, norm, activation, **kwargs):
        super().__init__(**kwargs)
        self.linear = Linear(neurons = neurons)
        if norm is not None:
            self.normalize = Normalization(norm = norm)
        if activation is not None:
            self.activation = Activation(activation = activation)

    def call(self, inputs):
        out = self.linear(inputs)
        # The optional stages exist as attributes only when they were configured.
        for stage_name in ('normalize', 'activation'):
            stage = getattr(self, stage_name, None)
            if stage is not None:
                out = stage(out)
        return out

class Conv2DBlock(tf.keras.layers.Layer):
    """``Conv2D`` -> optional normalization -> optional activation.

    When ``norm == 'adain'`` the block is called on ``[x, s]``:
      * ``x`` goes through the convolution;
      * ``s`` is passed to the AdaIN normalization together with the conv
        output.

    Otherwise the block is called on a single tensor.
    """
    def __init__(self, filters, kernel_size, strides, padding, norm, activation, **kwargs):
        super().__init__(**kwargs)
        self.norm = norm
        self.conv = Conv2D(filters = filters, kernel_size = kernel_size, strides = strides, padding = padding)

        if norm is not None:
            self.normalize = Normalization(norm = norm)
        if activation is not None:
            self.activation = Activation(activation = activation)

    def call(self, inputs):
        uses_style = self.norm == 'adain'
        if uses_style:
            assert len(inputs) == 2
            x, style = inputs
        else:
            x = inputs

        out = self.conv(x)
        if hasattr(self, 'normalize'):
            out = self.normalize([out, style] if uses_style else out)
        if hasattr(self, 'activation'):
            out = self.activation(out)
        return out

In [9]:
class FusionModule(tf.keras.layers.Layer):
    """MLP applied to the style code before it reaches AdaIN.

    It stacks ``n_layers`` ``Linear`` layers of width ``dim``, with a ReLU
    between consecutive layers and none after the last one.
    """
    def __init__(self, dim = 256, n_layers = 3, **kwargs):
        super().__init__(**kwargs)

        stack = []
        for idx in range(n_layers):
            if idx > 0:
                stack.append(Activation(activation = 'relu'))
            stack.append(Linear(neurons = dim))

        self.model = tf.keras.models.Sequential(stack)

    def call(self, inputs):
        return self.model(inputs)

In [10]:
class ResidualStyleBlock(tf.keras.layers.Layer):
    """Residual block of two 3x3 ``Conv2DBlock``s with an identity skip.

    With ``norm == 'adain'`` the block is called on ``[x, s]``. The style
    code ``s`` is first passed through a ``FusionModule`` and then conditions
    both convolutions. With any other norm it is called on ``x`` alone.
    """
    def __init__(self, filters, norm, fusion_dim = 256, fusion_layers = 3, **kwargs):
        super().__init__(**kwargs)
        self.norm = norm
        if norm == 'adain':
            self.fm = FusionModule(fusion_dim, fusion_layers)

        conv_args = dict(filters = filters, kernel_size = (3, 3), strides = (1, 1), padding = (1, 1), norm = norm)
        self.conv_block_1 = Conv2DBlock(activation = 'relu', **conv_args)
        self.conv_block_2 = Conv2DBlock(activation = None, **conv_args)

    def call(self, inputs):
        if self.norm != 'adain':
            x = inputs
            return self.conv_block_2(self.conv_block_1(x)) + x

        assert isinstance(inputs, (list, tuple))
        assert len(inputs) == 2
        x, s = inputs
        style = self.fm(s)
        out = self.conv_block_1([x, style])
        out = self.conv_block_2([out, style])
        return out + x

In [11]:
class TextureEncoder(tf.keras.layers.Layer):
    """Encode an image into a texture code using frozen VGG features plus learnable convs.

    A frozen pretrained network (VGG19 by default) provides features at
    ``pt_layers``. A learnable conv stack runs in parallel:
      * stage 1 keeps the resolution;
      * stages 2-4 halve it each time.

    After every stage, the stage output is concatenated with the matching
    pretrained feature map. The result is then average-pooled and projected
    to ``dim`` channels by a 1x1 conv.

    Inputs are presumably in [-1, 1] (they are mapped with ``(x + 1) * 127.5``
    before VGG preprocessing); confirm against the data pipeline.
    """
    def __init__(self, dim, pt_model = None, pt_layers = None, **kwargs):
        super().__init__(**kwargs)
        backbone = tf.keras.applications.VGG19(include_top = False, weights = 'imagenet') if pt_model is None else pt_model
        backbone.trainable = False
        if pt_layers is None:
            pt_layers = ['block1_conv1', 'block2_conv1', 'block3_conv1', 'block4_conv1']

        taps = [backbone.get_layer(name).output for name in pt_layers]
        self.model = tf.keras.models.Model(backbone.inputs, taps)

        self.conv1 = Conv2DBlock(filters = dim, kernel_size = (7, 7), strides = (1, 1), padding = (3, 3),
                                 norm = None, activation = 'relu')
        # Three stride-2 stages doubling the width. Their resolutions must match
        # the pretrained taps they are concatenated with (block2..block4 for the
        # default VGG19 layers).
        self.conv2, self.conv3, self.conv4 = [
            Conv2DBlock(filters = dim * mult, kernel_size = (4, 4), strides = (2, 2), padding = (1, 1),
                        norm = None, activation = 'relu')
            for mult in (2, 4, 8)
        ]

        self.pool = tf.keras.layers.AveragePooling2D()
        self.conv = Conv2D(filters = dim, kernel_size = (1, 1), strides = (1, 1), padding = (0, 0))

    def pt_model_out(self, inputs):
        # Map pixels to 0-255 before the VGG-specific preprocessing.
        pixels = (inputs + 1) * 127.5
        return self.model(tf.keras.applications.vgg19.preprocess_input(pixels))

    def learnable_encoder(self, inputs, pt_outs):
        # pt_outs is ordered deepest-first and is consumed with pop(), so each
        # stage receives the next-shallowest pretrained feature map.
        x = inputs
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = tf.concat([stage(x), pt_outs.pop()], axis = -1)
        return self.conv(self.pool(x))

    def call(self, inputs):
        features = self.pt_model_out(inputs)
        return self.learnable_encoder(inputs, features[::-1])

In [12]:
class DecomposedComponentEncoding(tf.keras.layers.Layer):
    """Encode each semantic component of the source image and fuse them into one style code.

    Called on ``[src, semantics]``:
      * ``src`` has shape ``(b, h, w, 3)``;
      * ``semantics`` has shape ``(b, h, w, seg_labels)``.

    The source is multiplied by each label channel. Every masked image goes
    through one shared ``TextureEncoder``. The codes are concatenated,
    flattened and mapped to ``style_dim`` by a ReLU ``LinearBlock``.
    """
    def __init__(self, style_dim = 512, **kwargs):
        super().__init__(**kwargs)
        self.style_dim = style_dim
        self.texture_encoder = TextureEncoder(dim = style_dim//8)
        self.linear = LinearBlock(neurons = style_dim, norm = None, activation = 'relu')

    def split_img_masks(self, src, semantics):
        # (b, h, w, 3, 1) * (b, h, w, 1, L) -> (b, h, w, 3, L): one masked image per label.
        masked = tf.expand_dims(src, axis = -1) * tf.expand_dims(semantics, axis = -2)
        pieces = tf.split(masked, semantics.shape[-1], axis = -1)
        return [piece[:, :, :, :, 0] for piece in pieces]

    def call(self, inputs):
        assert len(inputs) == 2
        src, semantics = inputs
        codes = [self.texture_encoder(part) for part in self.split_img_masks(src, semantics)]
        codes = tf.concat(codes, axis = -1)
        # Flatten using the static (non-batch) shape of the concatenated codes.
        flat_size = codes.shape[1] * codes.shape[2] * codes.shape[3]
        return self.linear(tf.reshape(codes, (-1, flat_size)))

In [13]:
class DownSamplingBlock(tf.keras.layers.Layer):
    """4x4 stride-2 ``Conv2DBlock`` with reflect padding 1; halves even spatial sizes."""
    def __init__(self, filters, norm, activation, **kwargs):
        super().__init__(**kwargs)
        self.conv_block = Conv2DBlock(filters = filters, kernel_size = (4, 4), strides = (2, 2),
                                      padding = (1, 1), norm = norm, activation = activation)

    def call(self, inputs):
        downsampled = self.conv_block(inputs)
        return downsampled
    
class UpSamplingBlock(tf.keras.layers.Layer):
    """2x nearest-neighbour upsampling followed by a 5x5 ``Conv2DBlock`` (reflect padding 2)."""
    def __init__(self, filters, norm, activation, **kwargs):
        super().__init__(**kwargs)
        self.up_sample = tf.keras.layers.UpSampling2D(size = (2, 2), interpolation = 'nearest')
        self.conv_block = Conv2DBlock(filters = filters, kernel_size = (5, 5), strides = (1, 1),
                                      padding = (2, 2), norm = norm, activation = activation)

    def call(self, inputs):
        enlarged = self.up_sample(inputs)
        return self.conv_block(enlarged)

In [14]:
def generator(img_shape, img_pose_shape, img_semantic_shape, dim = 64, n_up_down = 2, n_res = 8):
    """Build the ADGAN generator as a Keras functional model.

    Args:
        img_shape: per-sample shape of the source image, e.g. ``(h, w, 3)``.
        img_pose_shape: per-sample shape of the target pose map, e.g.
            ``(h, w, points)``.
        img_semantic_shape: per-sample shape of the source semantic map, e.g.
            ``(h, w, labels)``.
        dim: width of the first encoder conv; doubled at every downsampling
            and halved at every upsampling.
        n_up_down: number of downsampling stages, and of matching upsampling
            stages.
        n_res: number of residual blocks in the pose encoder and in the decoder.

    Returns:
        A ``tf.keras.Model`` taking ``[pose, image, semantic]`` and producing a
        3-channel tanh output.
    """
    inp_pose = tf.keras.layers.Input(shape = img_pose_shape, dtype = tf.float32, name = 'pose_input') # b, h, w, points
    # The source image has its own shape; it must not reuse the pose shape.
    inp_img = tf.keras.layers.Input(shape = img_shape, dtype = tf.float32, name = 'image_input') # b, h, w, 3
    inp_semantic = tf.keras.layers.Input(shape = img_semantic_shape, dtype = tf.float32, name = 'semantic_input')#b,h,w,lbl

    ### Encoder (pose)
    x = Conv2DBlock(filters = dim, kernel_size = (7, 7), strides = (1, 1), padding = (3, 3), norm = 'inst_norm',
                    activation = 'relu')(inp_pose)

    for _ in range(n_up_down):
        dim *= 2
        x = DownSamplingBlock(filters = dim, norm = 'inst_norm', activation = 'relu')(x)

    for _ in range(n_res):
        x = ResidualStyleBlock(filters = dim, norm = 'inst_norm')(x)

    ### DCE (Decomposed Component Encoding): style code from the source image
    dce = DecomposedComponentEncoding()([inp_img, inp_semantic])

    ### Decoder: AdaIN residual blocks conditioned on the style code, then upsampling
    for _ in range(n_res):
        x = ResidualStyleBlock(filters = dim, norm = 'adain')([x, dce])

    for _ in range(n_up_down):
        dim //= 2
        x = UpSamplingBlock(filters = dim, norm = 'layer_norm', activation = 'relu')(x)

    x = Conv2DBlock(filters = 3, kernel_size = (7, 7), strides = (1, 1), padding = (3, 3), norm = None,
                    activation = 'tanh')(x)

    return tf.keras.models.Model([inp_pose, inp_img, inp_semantic], x, name = 'Generator')

In [17]:
class ResidualBlock(tf.keras.layers.Layer):
    """Two 3x3 ``Conv2DBlock``s (ReLU after the first only) plus an identity skip.

    If ``filters`` is None the input's channel count is reused. An explicit
    value should match the input width so the residual addition lines up.
    """
    def __init__(self, filters = None, norm = 'batch_norm', **kwargs):
        super().__init__(**kwargs)
        self.filters = filters
        self.norm = norm

    def build(self, input_shape):
        width = self.filters if self.filters is not None else input_shape[-1]
        common = dict(filters = width, kernel_size = (3, 3), strides = (1, 1), padding = (1, 1), norm = self.norm)
        self.conv_1 = Conv2DBlock(activation = 'relu', **common)
        self.conv_2 = Conv2DBlock(activation = None, **common)

    def call(self, inputs):
        residual = self.conv_1(inputs)
        residual = self.conv_2(residual)
        return residual + inputs

In [18]:
def discriminator(inpA_shape, inpB_shape, dim = 64, n_down = 2, n_blocks = 6, logits = False):
    """Build a ResNet-style discriminator on a pair of inputs.

    The two inputs are concatenated along channels and passed through:
      * a 7x7 stem;
      * ``n_down`` stride-2 3x3 convs (the width doubles for the first two
        only);
      * ``n_blocks`` residual blocks.

    The output is a spatial feature map, not a single score.

    Args:
        inpA_shape, inpB_shape: per-sample shapes of the two inputs; they must
            share height and width.
        dim: width of the stem conv.
        n_down: number of stride-2 downsampling convs.
        n_blocks: number of residual blocks.
        logits: NOTE the name is inverted. When True a sigmoid is applied, so
            the output is probabilities. When False (the default) raw scores
            are returned. The name is kept for compatibility.
    """
    inpA = tf.keras.layers.Input(shape = inpA_shape, dtype = tf.float32, name = 'inputA')
    inpB = tf.keras.layers.Input(shape = inpB_shape, dtype = tf.float32, name = 'inputB')

    x = tf.keras.layers.Concatenate()([inpA, inpB])

    # Reflect-pad by 3 so the 7x7 stem keeps the spatial size, like the other
    # 7x7 convs in this file.
    x = Conv2DBlock(filters = dim, kernel_size = (7, 7), strides = (1, 1), padding = (3, 3), norm = 'batch_norm',
                    activation = 'relu')(x)

    for i in range(n_down):
        if i < 2:
            dim *= 2
        x = Conv2DBlock(filters = dim, kernel_size = (3, 3), strides = (2, 2), padding = (1, 1), norm = 'batch_norm',
                        activation = 'relu')(x)

    for _ in range(n_blocks):
        x = ResidualBlock(filters = dim, norm = 'batch_norm')(x)

    if logits:
        x = tf.keras.layers.Activation('sigmoid')(x)

    return tf.keras.models.Model([inpA, inpB], x, name = 'Discriminator')

In [142]:
class PerceptualLoss(object):
    """Feature-matching loss between two images using a frozen pretrained network.

    Args:
        p_model: Keras backbone; defaults to ImageNet VGG19 without the top.
            It is frozen in place.
        p_layers: names of the backbone layers whose outputs are compared.
        loss_type: 'l1' (mean absolute difference), 'l2' (mean squared
            difference) or a callable ``f(real_feat, gen_feat)``. Per-layer
            terms are summed.

    Images are mapped with ``(x + 1) * 127.5`` before VGG preprocessing, so
    they are presumably in [-1, 1]; confirm against the data pipeline.
    """
    def __init__(self, p_model = None, p_layers = None, loss_type = 'l1'):
        self.loss_type = loss_type

        backbone = p_model if p_model is not None else tf.keras.applications.VGG19(include_top = False, weights = 'imagenet')
        backbone.trainable = False
        if p_layers is None:
            p_layers = ['block1_conv1', 'block2_conv1', 'block3_conv1', 'block4_conv1', 'block5_conv1']

        taps = [backbone.get_layer(name).output for name in p_layers]
        self.model = tf.keras.models.Model(backbone.inputs, taps)

        self.preprocess_input = lambda x: tf.keras.applications.vgg19.preprocess_input(x)

    def __call__(self, real, gen):
        to_backbone = lambda img: self.preprocess_input((img + 1) * 127.5)
        real_feats = self.model(to_backbone(real))
        gen_feats = self.model(to_backbone(gen))

        total = 0
        for r, g in zip(real_feats, gen_feats):
            if self.loss_type == 'l1':
                term = tf.math.reduce_mean(tf.math.abs(r - g))
            elif self.loss_type == 'l2':
                term = tf.math.reduce_mean(tf.math.square(r - g))
            else:
                term = self.loss_type(r, g)
            total += term
        return total
    
    
# References:
# https://github.com/menyifang/ADGAN/blob/master/losses/CX_style_loss.py
# https://github.com/roimehrez/contextualLoss/blob/master/CX/CSFlow.py
class ContextualLoss(object):
    """Contextual (CX) loss between two NHWC feature batches of equal batch size.

    Algorithm:
      1. Center both maps on the mean of ``real`` (over batch and spatial
         positions).
      2. L2-normalize every feature vector.
      3. For each sample, compute pairwise cosine similarities between all
         positions of ``gen`` and all positions of ``real``.
      4. Turn similarities into relative distances and then into softmax-like
         affinities with bandwidth ``sigma``.
      5. For every ``real`` feature, keep its best match in ``gen``.
      6. The loss is ``-log`` of the mean match strength, averaged over the
         batch.

    Args:
        b: offset in ``exp((b - relative_dist) / sigma)``.
        sigma: bandwidth of the affinity kernel.
        **kwargs: accepted for backward compatibility and ignored.

    The similarity tensor has ``(H*W)**2`` entries per sample, so inputs should
    be low-resolution feature maps.
    """
    def __init__(self, b = 1.0, sigma = 0.1, **kwargs):
        # object.__init__ accepts no extra arguments, so nothing is forwarded.
        super().__init__()
        self.b = b
        self.sigma = sigma

    def __call__(self, real, gen):
        tf.debugging.assert_equal(tf.shape(real)[0], tf.shape(gen)[0],
                                  message = 'real and gen must have the same batch size')

        ### centering by real
        mean_r = tf.math.reduce_mean(real, axis = [0, 1, 2], keepdims = True)
        real, gen = real - mean_r, gen - mean_r

        ### L2 normalization of each feature vector, so dot products are cosines
        real = tf.math.l2_normalize(real, axis = -1)
        gen = tf.math.l2_normalize(gen, axis = -1)

        ### per-sample cosine similarity: (N, Hg*Wg, Hr*Wr)
        batch = tf.shape(real)[0]
        channels = tf.shape(real)[-1]
        real_flat = tf.reshape(real, (batch, -1, channels))
        gen_flat = tf.reshape(gen, (batch, -1, channels))
        cos_sim = tf.linalg.matmul(gen_flat, real_flat, transpose_b = True)

        ### relative distance, normalized per gen position by its closest real feature
        raw_dist = (1.0 - cos_sim) / 2.0
        d_min = tf.math.reduce_min(raw_dist, axis = -1, keepdims = True)
        relative_dist = raw_dist / (d_min + 1e-5)

        ### CX affinities, normalized over real features
        cx_exp = tf.math.exp((self.b - relative_dist) / self.sigma)
        cx = cx_exp / tf.math.reduce_sum(cx_exp, axis = -1, keepdims = True)

        # Best gen match for every real feature, then average over real features.
        cx = tf.math.reduce_max(cx, axis = 1)
        cx = tf.math.reduce_mean(cx, axis = 1)
        return tf.math.reduce_mean(-tf.math.log(cx))
    
class ReconstructionLoss(object):
    """Mean absolute (``p=1``) or mean squared (``p=2``) difference between two tensors.

    Raises:
        ValueError: if ``p`` is neither 1 nor 2.
    """
    def __init__(self, p = 1):
        if p == 1:
            self.loss = lambda x, y: tf.math.reduce_mean(tf.math.abs(x - y))
        elif p == 2:
            self.loss = lambda x, y: tf.math.reduce_mean(tf.math.square(x - y))
        else:
            # Fail at construction instead of leaving self.loss undefined.
            raise ValueError(f'p must be 1 or 2, got {p!r}')

    def __call__(self, real, gen):
        return self.loss(real, gen)
    
class GANLoss(object):
    """Adversarial loss helpers built on a Keras loss object.

    Args:
        loss_type: 'lsgan' (mean squared error against 1/0 targets) or
            'adversarial' (binary cross-entropy).
        from_logits: used only for 'adversarial'; whether the discriminator
            outputs are raw logits.

    Raises:
        ValueError: for an unknown ``loss_type``.
    """
    def __init__(self, loss_type = 'lsgan', from_logits = True):
        if loss_type == 'lsgan':
            self.loss_object = tf.keras.losses.MeanSquaredError()
        elif loss_type == 'adversarial':
            self.loss_object = tf.keras.losses.BinaryCrossentropy(from_logits = from_logits)
        else:
            # Fail at construction instead of with an AttributeError on first use.
            raise ValueError(f"loss_type must be 'lsgan' or 'adversarial', got {loss_type!r}")

    def discriminator_loss(self, disc_real_out, disc_gen_out):
        """Push outputs on real samples towards 1 and on generated samples towards 0."""
        real_loss = self.loss_object(tf.ones_like(disc_real_out), disc_real_out)
        gen_loss = self.loss_object(tf.zeros_like(disc_gen_out), disc_gen_out)
        return real_loss + gen_loss

    def generator_loss(self, disc_gen_out):
        """Push discriminator outputs on generated samples towards 1."""
        return self.loss_object(tf.ones_like(disc_gen_out), disc_gen_out)

In [143]:
class Trainer(object):
    """Training loop for the ADGAN generator and its two discriminators.

    ``discriminator_P`` judges (pose, image) pairs and ``discriminator_T``
    judges (source image, image) pairs. Both discriminators are updated by a
    single Adam optimizer; the generator has its own.

    Args:
        img_shape, img_pose_shape, img_semantic_shape: per-sample input shapes
            forwarded to ``generator`` / ``discriminator``.
        learning_rate: Adam learning rate for both optimizers.
        loss_type, from_logits: forwarded to ``GANLoss``.
    """
    def __init__(self, img_shape, img_pose_shape, img_semantic_shape, learning_rate = 0.001, loss_type = 'lsgan',
                 from_logits = True):
        self.gen_optimizer = tf.keras.optimizers.Adam(learning_rate = learning_rate, beta_1 = 0.5, beta_2 = 0.999)
        self.disc_optimizer = tf.keras.optimizers.Adam(learning_rate = learning_rate, beta_1 = 0.5, beta_2 = 0.999)

        self.recon_loss = ReconstructionLoss()
        self.cx_loss = ContextualLoss()
        self.perceptual_loss = PerceptualLoss()
        self.gan_loss = GANLoss(loss_type = loss_type, from_logits = from_logits)

        self.generator = generator(img_shape, img_pose_shape, img_semantic_shape)
        self.discriminator_P = discriminator(img_pose_shape, img_shape)
        self.discriminator_T = discriminator(img_shape, img_shape)

    def discriminator_loss(self, disc_real_out_P, disc_gen_out_P, disc_real_out_T, disc_gen_out_T):
        """Sum of the GAN discriminator losses of both discriminators."""
        loss = self.gan_loss.discriminator_loss(disc_real_out_P, disc_gen_out_P)
        loss += self.gan_loss.discriminator_loss(disc_real_out_T, disc_gen_out_T)
        return loss

    def generator_loss(self, real, gen, disc_gen_out_P, disc_gen_out_T, lambda_rec = 2.0, lambda_per = 2.0,
                       lambda_cx = 0.02):
        """Weighted reconstruction, perceptual and contextual terms plus both adversarial terms."""
        loss = self.recon_loss(real, gen) * lambda_rec
        loss += self.perceptual_loss(real, gen) * lambda_per
        # NOTE(review): CX is applied to raw images here. Its similarity tensor
        # has (H*W)**2 entries per sample, which is very large at full image
        # resolution. The reference implementation appears to apply CX to VGG
        # feature maps; confirm before training at full size.
        loss += self.cx_loss(real, gen) * lambda_cx
        loss += self.gan_loss.generator_loss(disc_gen_out_P) + self.gan_loss.generator_loss(disc_gen_out_T)
        return loss

    def train_step(self, src_img, pose_img, semantic_img, target_img):
        """Run one optimization step for the generator and both discriminators.

        One forward pass is recorded on two tapes. The generator is updated
        from ``gen_loss``; both discriminators are updated from ``disc_loss``.

        Returns:
            ``(gen_loss, disc_loss)`` for this batch.
        """
        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
            gen_out = self.generator([pose_img, src_img, semantic_img], training = True)

            disc_real_out_P = self.discriminator_P([pose_img, target_img], training = True)
            disc_gen_out_P = self.discriminator_P([pose_img, gen_out], training = True)
            disc_real_out_T = self.discriminator_T([src_img, target_img], training = True)
            disc_gen_out_T = self.discriminator_T([src_img, gen_out], training = True)

            disc_loss = self.discriminator_loss(disc_real_out_P, disc_gen_out_P, disc_real_out_T, disc_gen_out_T)
            gen_loss = self.generator_loss(target_img, gen_out, disc_gen_out_P, disc_gen_out_T)

        gen_grads = gen_tape.gradient(gen_loss, self.generator.trainable_variables)
        self.gen_optimizer.apply_gradients(zip(gen_grads, self.generator.trainable_variables))

        disc_params = self.discriminator_P.trainable_variables + self.discriminator_T.trainable_variables
        disc_grads = disc_tape.gradient(disc_loss, disc_params)
        self.disc_optimizer.apply_gradients(zip(disc_grads, disc_params))

        return gen_loss, disc_loss

    def train(self, data, epochs = 1):
        """Train for ``epochs`` passes over ``data``.

        Args:
            data: iterable of ``(src_img, pose_img, semantic_img, target_img)``
                batches. It is re-iterated every epoch, so a one-shot Python
                generator only provides batches in the first epoch.
            epochs: number of passes.

        Returns:
            ``{'gen_losses': [...], 'disc_losses': [...]}`` with the last-batch
            losses of every epoch that produced at least one batch.
        """
        gen_losses, disc_losses = [], []
        for e in range(epochs):
            print(f'Epoch: {e} Starts.')
            # Reset per epoch so an empty epoch neither crashes nor re-records stale losses.
            gen_loss = disc_loss = None
            for src_img, pose_img, semantic_img, target_img in data:
                gen_loss, disc_loss = self.train_step(src_img, pose_img, semantic_img, target_img)
                print('.', end='')

            if gen_loss is None:
                print('\nNo batches in this epoch.')
            else:
                gen_losses.append(gen_loss)
                disc_losses.append(disc_loss)
                print(f'\nGenerator Loss: {gen_loss} \t Discriminator Loss: {disc_loss}')
            print(f'Epoch: {e} Ends.\n')

        return {'gen_losses': gen_losses, 'disc_losses': disc_losses}