# Attention GAN

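This notebook is an attempt to re-implement AttnGAN in TensorFlow, following the paper "AttnGAN: Fine-Grained Text to Image Generation with Attentional Generative Adversarial Networks" (Xu et al., CVPR 2018).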

Paper: https://openaccess.thecvf.com/content_cvpr_2018/papers/Xu_AttnGAN_Fine-Grained_Text_CVPR_2018_paper.pdf

Other Resources: 
* https://github.com/taki0112/AttnGAN-Tensorflow
* https://github.com/taoxugit/AttnGAN

In [150]:
import numpy as np
import tensorflow as tf

In [3]:
class ReflectPadding2D(tf.keras.layers.Layer):
    '''
        Reflection-pads axes 1 and 2 of a 4-D tensor; axes 0 and 3 are left untouched.

        padding:
            int p                  -> ((p, p), (p, p))
            (a, b) ints            -> ((a, a), (b, b))
            ((a0, a1), (b0, b1))   -> used as given
        Lists are accepted wherever tuples are.

        Raises ValueError for any other form of padding.
    '''
    def __init__(self, padding, **kwargs):
        super().__init__(**kwargs)
        self.padding = self._normalize(padding)

    @staticmethod
    def _normalize(padding):
        '''
            Convert an accepted padding spec to ((before, after), (before, after)).
        '''
        if isinstance(padding, int):
            return (padding, padding), (padding, padding)

        if isinstance(padding, (tuple, list)) and len(padding) >= 2:
            first, second = padding[0], padding[1]
            if isinstance(first, (tuple, list)):
                return (first[0], first[1]), (second[0], second[1])
            if isinstance(first, int):
                return (first, first), (second, second)

        # ValueError is a subclass of Exception, so existing handlers keep working.
        raise ValueError(f'invalid padding: {padding!r}')

    def call(self, x):
        # No padding on the first and last axes; reflect on the two in between.
        paddings = ((0, 0), self.padding[0], self.padding[1], (0, 0))
        return tf.pad(x, paddings, mode = 'REFLECT')

In [4]:
class Linear(tf.keras.layers.Layer):
    '''
        Affine projection of the last axis: inputs @ W + B, with no activation.

        neurons: size of the output's last axis.
        W is drawn from a normal distribution (mean 0, stddev 1); B starts at zero.
    '''
    def __init__(self, neurons, **kwargs):
        super().__init__(**kwargs)
        self.neurons = neurons

    def build(self, input_shape):
        fan_in = input_shape[-1]
        normal_init = tf.keras.initializers.RandomNormal(mean = 0.0, stddev = 1.0)

        self.W = self.add_weight(name = 'weight', shape = (fan_in, self.neurons),
                                 initializer = normal_init, trainable = True)
        # Bias keeps a leading axis of size 1 so it broadcasts over the rows.
        self.B = self.add_weight(name = 'bias', shape = (1, self.neurons),
                                 initializer = 'zeros', trainable = True)

    def call(self, inputs):
        projected = tf.matmul(inputs, self.W)
        return tf.add(projected, self.B)

In [5]:
class Conv2D(tf.keras.layers.Layer):
    '''
        2-D convolution plus bias, with optional reflection padding.

        filters:     number of output channels.
        kernel_size: int k (treated as (k, k)) or a (height, width) tuple/list.
        strides:     passed straight to tf.nn.conv2d.
        padding:     'same' / 'valid' (case-insensitive) for the built-in modes, or an
                     int / tuple / list, in which case the input is reflection-padded by
                     ReflectPadding2D and then convolved with 'VALID'.

        Raises ValueError for an unrecognised padding.
    '''
    def __init__(self, filters, kernel_size, strides, padding, **kwargs):
        super().__init__(**kwargs)
        self.filters = filters

        # build() concatenates kernel_size with a tuple, so it must itself be a tuple.
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        self.kernel_size = tuple(kernel_size)
        self.strides = strides

        if isinstance(padding, str):
            mode = padding.upper()
            if mode not in ('SAME', 'VALID'):
                raise ValueError(f'Invalid padding: {padding!r}')
            self.padding = mode
        elif isinstance(padding, (int, tuple, list)):
            self.padding = ReflectPadding2D(padding)
        else:
            raise ValueError(f'Invalid padding: {padding!r}')

    def build(self, input_shape):
        inp_filters = input_shape[-1]

        init = tf.keras.initializers.RandomNormal(mean = 0.0, stddev = 1.0)
        # Kernel layout: (height, width, in_channels, out_channels).
        self.W = self.add_weight(shape = self.kernel_size + (inp_filters, self.filters), initializer = init,
                                 trainable = True, name = 'weight')
        self.B = self.add_weight(shape = (1, 1, 1, self.filters), initializer = 'zeros',
                                 trainable = True, name = 'bias')

    def call(self, inputs):
        if isinstance(self.padding, str):
            conv = tf.nn.conv2d(inputs, self.W, self.strides, self.padding)
        else:
            # Explicit reflection padding first, then an unpadded convolution.
            conv = tf.nn.conv2d(self.padding(inputs), self.W, self.strides, 'VALID')
        return tf.add(conv, self.B)

In [6]:
class GLU(tf.keras.layers.Layer):
    '''
        Gated linear unit over the last axis: the first half of the channels is
        multiplied by the sigmoid of the second half.

        note: channels size will be reduced to half
    '''
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, inputs):
        channels = inputs.shape[-1]
        assert channels % 2 == 0
        half = channels // 2

        # Split along the channel axis: values first, gates second.
        values = inputs[..., :half]
        gates = inputs[..., half:]
        return values * tf.nn.sigmoid(gates)

In [7]:
class ImageEncoder(tf.keras.layers.Layer):
    '''
        Frozen pretrained CNN followed by two trainable linear projections.

        dim:             output width of both projections.
        pt_model:        pretrained model to use; defaults to ImageNet InceptionV3 without
                         its top and with global average pooling.
        pt_layers:       names of the layers of pt_model whose outputs feed the regional
                         features; defaults to ['mixed7'] when None or empty.
        preprocess_func: preprocessing applied after resizing; defaults to the
                         InceptionV3 preprocess_input.

        call() returns (features, embed):
            features: the intermediate feature map flattened over its two spatial axes
                      and projected to dim.
            embed:    the pretrained model's own output projected to dim.
    '''
    def __init__(self, dim, pt_model = None, pt_layers = None, preprocess_func = None, **kwargs):
        super().__init__(**kwargs)

        if pt_model is None:
            pt_model = tf.keras.applications.InceptionV3(include_top = False, weights = 'imagenet', pooling = 'avg')
        self.embed_model = pt_model
        self.embed_model.trainable = False

        if preprocess_func is None:
            self.preprocess_input = tf.keras.applications.inception_v3.preprocess_input
        else:
            self.preprocess_input = preprocess_func

        # None (the default) and an empty sequence both select the default layer.
        layers = list(pt_layers) if pt_layers else ['mixed7']
        outs = [self.embed_model.get_layer(layer).output for layer in layers]

        self.feature_model = tf.keras.models.Model(self.embed_model.inputs, outs)
        self.feature_model.trainable = False

        #self.conv = Conv2D(filters = dim, kernel_size = (1, 1), strides = (1, 1), padding = (0, 0))
        self.linear_1 = Linear(neurons = dim)
        self.linear_2 = Linear(neurons = dim)

    def call(self, inputs):
        # (inputs + 1) * 127.5 maps [-1, 1] onto [0, 255]; presumably inputs are
        # tanh-range images - confirm against the caller.
        x = tf.image.resize((inputs+1)*127.5, [299, 299], tf.image.ResizeMethod.BILINEAR)
        x = self.preprocess_input(x)

        features = self.feature_model(x)
        # Flatten the two spatial axes into one axis of regions.
        features = tf.reshape(features, (-1, features.shape[1] * features.shape[2], features.shape[3]))
        features = self.linear_1(features)
        #features = self.conv(features)

        embed = self.embed_model(x)
        embed = self.linear_2(embed)

        return features, embed

In [8]:
class TextEncoder(tf.keras.layers.Layer):
    '''
        Embedding -> optional dropout -> bidirectional LSTM.

        inp_dim:       size of the embedding's input vocabulary.
        dim:           embedding width and number of LSTM units per direction.
        apply_dropout: dropout rate applied to the embeddings; a falsy value disables it.

        call() returns (word_embed, sent_embed, mask):
            word_embed: per-token outputs of the bidirectional LSTM.
            sent_embed: final forward and backward hidden states concatenated on the last axis.
            mask:       boolean tensor, True where inputs == 0.
    '''
    def __init__(self, inp_dim, dim, apply_dropout = 0.5, **kwargs):
        super().__init__(**kwargs)

        self.embedding = tf.keras.layers.Embedding(inp_dim, dim)
        if apply_dropout:
            self.dropout = tf.keras.layers.Dropout(apply_dropout)

        self.rnn = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(dim, return_state = True, return_sequences = True))

    def call(self, inputs):
        x = self.embedding(inputs)
        if hasattr(self, 'dropout'):
            x = self.dropout(x)

        # With return_state = True a bidirectional LSTM returns
        # [outputs, forward_h, forward_c, backward_h, backward_c].
        word_embed, forward_h, _forward_c, backward_h, _backward_c = self.rnn(x)

        # The sentence vector joins the hidden states of both directions; the
        # cell states are not part of it.
        sent_embed = tf.concat([forward_h, backward_h], axis = -1)

        mask = tf.math.equal(inputs, 0)
        return word_embed, sent_embed, mask

In [9]:
class ConditionalAugmentation(tf.keras.layers.Layer):
    '''
        Projects the input to 2 * dim values, applies ReLU, splits the result into
        mu and logvar, and draws a sample mu + exp(logvar / 2) * epsilon with
        epsilon ~ N(0, 1).

        call() returns (sample, mu, logvar).
    '''
    def __init__(self, dim, **kwargs):
        super().__init__(**kwargs)
        self.dim = dim
        self.linear = Linear(dim * 2)
        self.act = tf.keras.layers.ReLU()

    def call(self, inputs):
        stats = self.act(self.linear(inputs))

        # First dim columns are the mean, the remaining ones the log-variance.
        mu, logvar = stats[:, :self.dim], stats[:, self.dim:]

        noise = tf.random.normal(tf.shape(mu), mean=0.0, stddev=1.0)
        std = tf.exp(logvar * 0.5)
        sample = mu + std * noise

        return sample, mu, logvar

In [10]:
class Attention(tf.keras.layers.Layer):
    '''
        Attention of image regions over word embeddings.

        call() takes [h, e, mask]:
            h:    4-D feature map; its two middle axes are flattened into regions.
            e:    word embeddings, projected to dim by a Linear layer.
            mask: (batch, seq_len) tensor, non-zero/True at positions to be ignored.

        Returns (out, attention):
            out:       attention-weighted word context per region, reshaped to
                       (batch, height, width, dim).
            attention: the softmax weights reshaped to (batch, height, width, seq_len).
    '''
    def __init__(self, dim, **kwargs):
        super().__init__(**kwargs)
        self.linear = Linear(neurons = dim)

    def build(self, input_shape):
        # Remember the spatial size of the feature map (first input).
        _, self.hh, self.hw, _ = input_shape[0]

    def call(self, inputs):
        h, e, mask = inputs

        h = tf.reshape(h, (-1, h.shape[1]*h.shape[2], h.shape[3]))
        context = self.linear(e) #b x seq_len x dim

        attention = tf.matmul(h, context, transpose_b = True) # b x h*w x seq_len

        # Push masked positions towards -inf before the softmax. The penalty is
        # inserted as (b, 1, seq_len) so every region of a sample sees that
        # sample's own mask; tiling along the batch axis and reshaping would mix
        # masks of different samples.
        penalty = tf.cast(mask, tf.float32) * -1e12
        attention += tf.expand_dims(penalty, axis = 1)
        attention = tf.nn.softmax(attention, axis = -1)

        out = tf.matmul(attention, context)
        out = tf.reshape(out, (-1, self.hh, self.hw, context.shape[-1]))

        attention = tf.reshape(attention, (-1, self.hh, self.hw, context.shape[1]))
        return out, attention

In [11]:
class UpSampling(tf.keras.layers.Layer):
    '''
        2x nearest-neighbour upsampling -> 3x3 conv (2 * dim filters) -> batch norm -> GLU.
        The GLU halves the channels, so the output has dim channels.
    '''
    def __init__(self, dim, **kwargs):
        super().__init__(**kwargs)
        self.up_sample = tf.keras.layers.UpSampling2D(size = (2, 2), interpolation = 'nearest')
        self.conv = Conv2D(filters = dim * 2, kernel_size = (3, 3), strides = (1, 1), padding = (1, 1))
        self.norm = tf.keras.layers.BatchNormalization()
        self.glu = GLU() # channels gets reduced to half

    def call(self, inputs):
        x = self.up_sample(inputs)
        x = self.conv(x)
        x = self.norm(x)
        return self.glu(x)

In [12]:
class D_ConvBlock(tf.keras.layers.Layer):
    '''
        Convolution optionally followed by batch normalisation and LeakyReLU(0.2).

        norm / act switch the respective stage on or off; a disabled stage is
        simply not created as an attribute.
    '''
    def __init__(self, dim, kernel_size, strides, padding, norm = True, act = True, **kwargs):
        super().__init__(**kwargs)
        self.conv = Conv2D(filters = dim, kernel_size = kernel_size, strides = strides, padding = padding)
        if norm:
            self.norm = tf.keras.layers.BatchNormalization()
        if act:
            self.act = tf.keras.layers.LeakyReLU(alpha = 0.2)

    def call(self, inputs):
        x = self.conv(inputs)
        # Apply the optional stages in order, skipping the ones that were not built.
        for stage_name in ('norm', 'act'):
            stage = getattr(self, stage_name, None)
            if stage is not None:
                x = stage(x)
        return x

In [13]:
class ResidualBlock(tf.keras.layers.Layer):
    '''
        Residual block: conv (2 * dim) -> batch norm -> GLU -> conv (dim) -> batch norm,
        added back onto the input. The input therefore needs dim channels.
    '''
    def __init__(self, dim, **kwargs):
        super().__init__(**kwargs)
        self.conv_1 = Conv2D(filters = dim * 2, kernel_size = (3, 3), strides = (1, 1), padding = (1, 1))
        self.norm_1 = tf.keras.layers.BatchNormalization()
        self.act_1 = GLU()

        self.conv_2 = Conv2D(filters = dim, kernel_size = (3, 3), strides = (1, 1), padding = (1, 1))
        self.norm_2 = tf.keras.layers.BatchNormalization()

    def call(self, inputs):
        branch = self.conv_1(inputs)
        branch = self.norm_1(branch)
        branch = self.act_1(branch)    # GLU brings the channels back down to dim
        branch = self.conv_2(branch)
        branch = self.norm_2(branch)
        return branch + inputs

In [14]:
class IMAGE_GENERATOR(tf.keras.layers.Layer):
    '''
        Maps a feature map to a 3-channel image: 3x3 conv followed by tanh.
    '''
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.conv = Conv2D(filters = 3, kernel_size = (3, 3), strides = (1, 1), padding = (1, 1))
        self.act = tf.keras.layers.Activation('tanh')

    def call(self, inputs):
        rgb = self.conv(inputs)
        return self.act(rgb)

In [15]:
class F_INIT_STAGE(tf.keras.layers.Layer):
    '''
        First generator stage: a fully-connected stem producing a 4 x 4 x dim map,
        followed by up_block upsampling blocks that each double the resolution and
        halve the channels.

        With the default up_block = 4 this generator generates 64x64 images.

        call() returns (features, image).
    '''
    def __init__(self, dim, up_block = 4, **kwargs):
        super().__init__(**kwargs)

        stem_layers = [
            Linear(neurons = 2 * 4 * 4 * dim),    # twice the target size: the GLU halves it
            tf.keras.layers.BatchNormalization(),
            GLU(),
            tf.keras.layers.Reshape((4, 4, dim)),
        ]
        self.fc = tf.keras.models.Sequential(stem_layers)

        self.up_sample = tf.keras.models.Sequential()
        width = dim
        for _ in range(up_block):
            width = width // 2
            self.up_sample.add(UpSampling(width))

        self.gen_img = IMAGE_GENERATOR()

    def call(self, inputs):
        stem = self.fc(inputs)
        features = self.up_sample(stem)
        return features, self.gen_img(features)

In [16]:
class F_STAGE(tf.keras.layers.Layer):
    '''
        Refinement generator stage: attends over the word embeddings, concatenates
        the attended context to the incoming features along the channel axis, then
        runs two residual blocks (2 * dim channels) and one upsampling block (dim channels).

        call() takes [h, e, mask] and returns (features, image).
    '''
    def __init__(self, dim, **kwargs):
        super().__init__(**kwargs)

        self.attention = Attention(dim)

        self.model = tf.keras.models.Sequential()
        stages = (ResidualBlock(dim * 2), ResidualBlock(dim * 2), UpSampling(dim))
        for stage in stages:
            self.model.add(stage)

        self.gen_img = IMAGE_GENERATOR()

    def call(self, inputs):
        h, e, mask = inputs

        # Only the attended context is needed here; the attention map is discarded.
        word_context, _ = self.attention([h, e, mask])
        merged = tf.concat([h, word_context], axis = -1)

        refined = self.model(merged)
        return refined, self.gen_img(refined)

In [17]:
class DNET(tf.keras.layers.Layer):
    '''
        Discriminator for one image resolution.

        dim:            number of filters of the first convolution.
        n_regular_down: controls for how many strided blocks the width keeps doubling.
        n_extra_down:   additional strided blocks after the regular ones.
        n_final_block:  number of stride-1 3x3 blocks appended to the encoder.

        call() takes (x, sent_embed) and returns (unconditional_out, conditional_out).
        sent_embed is concatenated to the encoded image along the channel axis, so
        it must already have the same spatial size as the encoder output.
    '''
    def __init__(self, dim, n_regular_down, n_extra_down, n_final_block, **kwargs):
        super().__init__(**kwargs)

        self.conditional_conv = Conv2D(filters = 1, kernel_size = (4, 4), strides = (4, 4), padding = 'valid')
        self.unconditional_conv = Conv2D(filters = 1, kernel_size = (4, 4), strides = (4, 4), padding = 'valid')

        self.block_1 = tf.keras.models.Sequential()
        # The first block uses the base width `dim`. It was previously built with a
        # single filter, which squeezed the image through one channel before the
        # rest of the encoder.
        self.block_1.add(D_ConvBlock(dim = dim, kernel_size = (4, 4), strides = (2, 2), padding = (1, 1), norm = False))

        for i in range(n_regular_down + n_extra_down):
            # NOTE(review): `<=` lets the width double n_regular_down + 1 times when
            # extra blocks exist - confirm this is intended rather than `<`.
            if i <= n_regular_down:
                dim *= 2

            self.block_1.add(D_ConvBlock(dim = dim, kernel_size = (4, 4), strides = (2, 2), padding = (1, 1)))
        dim //= 2

        for _ in range(n_final_block):
            self.block_1.add(D_ConvBlock(dim = dim, kernel_size = (3, 3), strides = (1, 1), padding = (1, 1)))

        self.block_2 = D_ConvBlock(dim = dim, kernel_size = (3, 3), strides = (1, 1), padding = (1, 1))

    def call(self, inputs):
        x, sent_embed = inputs
        out_1 = self.block_1(x)

        # Image-only score.
        unconditional_out = self.unconditional_conv(out_1)

        # Image + sentence score.
        out_2 = tf.concat([out_1, sent_embed], axis = -1)
        out_2 = self.block_2(out_2)

        conditional_out = self.conditional_conv(out_2)
        return unconditional_out, conditional_out

In [25]:
def generator(text_shape, vocab_size, dim = 32, embed_dim = 256, z_dim = 100):
    '''
        Build the generator model.

        text_shape: shape of one text sample (without the batch axis).
        vocab_size: vocabulary size for the text encoder's embedding.
        dim:        base channel width of the image stages.
        embed_dim:  embedding width / LSTM units per direction of the text encoder.
        z_dim:      size of the noise vector and of the conditioning vector.

        Inputs:  [text, noise]
        Outputs: [word_embed, sent_embed, mu, logvar, img_0, img_1, img_2, feature, embed]
                 where feature / embed come from the image encoder applied to img_2.
    '''
    inp_text_encoder = tf.keras.layers.Input(shape = text_shape, dtype = tf.float32, name = 'text_encoder_input')
    inp_noise = tf.keras.layers.Input(shape = (z_dim, ), dtype = tf.float32, name = 'noise_z_input')


    word_embed, sent_embed, mask = TextEncoder(vocab_size, embed_dim)(inp_text_encoder)
    ca, mu, logvar = ConditionalAugmentation(z_dim)(sent_embed)

    ca_z = tf.keras.layers.Concatenate(axis = -1)([ca, inp_noise])

    h, img_0 = F_INIT_STAGE(dim = dim * 16)(ca_z)
    h, img_1 = F_STAGE(dim = dim)([h, word_embed, mask])
    _, img_2 = F_STAGE(dim = dim)([h, word_embed, mask])

    # The text encoder is bidirectional, so its word and sentence embeddings are
    # 2 * embed_dim wide. The image encoder has to project to the same width,
    # otherwise the image/text matmuls in word_loss and sent_loss have
    # mismatched inner dimensions.
    feature, embed = ImageEncoder(2 * embed_dim)(img_2)

    inputs = [inp_text_encoder, inp_noise]
    outputs = [word_embed, sent_embed, mu, logvar, img_0, img_1, img_2, feature, embed]
    return tf.keras.models.Model(inputs, outputs)

In [21]:
def discriminator(sent_embed_shape, img_shps = [64, 128, 256], dim = 64):
    '''
        Build the three-scale discriminator model.

        sent_embed_shape: shape of one sentence embedding (without the batch axis).
        img_shps:         side lengths of the three square RGB inputs.
        dim:              base width handed to every DNET.

        Inputs:  [image_0, image_1, image_2, sent_embed]
        Outputs: [unconditional_outs, conditional_outs], each a list with one
                 entry per image scale.
    '''
    def image_input(side):
        return tf.keras.layers.Input(shape = (side, side, 3), dtype = tf.float32,
                                     name = f'input_image_{side}x{side}x3')

    images = [image_input(img_shps[k]) for k in (0, 1, 2)]
    inp_sent_embed = tf.keras.layers.Input(shape = sent_embed_shape, dtype = tf.float32, name = 'input_sent_embed')

    # Spread the sentence vector over a 4 x 4 grid so it can be concatenated
    # to the encoded image inside each DNET.
    sent_embed = tf.reshape(inp_sent_embed, (-1, 1, 1, sent_embed_shape[-1]))
    sent_embed = tf.tile(sent_embed, (1, 4, 4, 1))

    # (n_extra_down, n_final_block) for each scale, smallest image first.
    scale_configs = ((0, 0), (1, 1), (2, 2))

    unconditional_outs, conditional_outs = [], []
    for image, (n_extra, n_final) in zip(images, scale_configs):
        net = DNET(dim = dim, n_regular_down = 3, n_extra_down = n_extra, n_final_block = n_final)
        unconditional_out, conditional_out = net([image, sent_embed])
        unconditional_outs.append(unconditional_out)
        conditional_outs.append(conditional_out)

    return tf.keras.models.Model(images + [inp_sent_embed], [unconditional_outs, conditional_outs])

In [163]:
def func_attention(query, context, gamma1 = 4.0):
    '''
        Two-step softmax attention between two rank-3 tensors.

        query:   (batch, n_query, dim)
        context: (batch, n_context, dim)
        gamma1:  factor applied to the transposed weights before the second softmax.

        Returns the query re-weighted per context position, shape (batch, dim, n_context).
    '''
    # The unpacking doubles as a rank-3 check on both arguments.
    _batch, _n_query, _dim = query.shape
    _, _n_context, _ = context.shape

    # (batch, n_query, n_context), normalised over the context axis.
    weights = tf.nn.softmax(tf.matmul(query, context, transpose_b = True))

    # (batch, n_context, n_query), sharpened and normalised over the query axis.
    weights = tf.transpose(weights, perm = [0, 2, 1]) * gamma1
    weights = tf.nn.softmax(weights)

    # query^T @ weights^T -> (batch, dim, n_context)
    return tf.matmul(query, weights, transpose_a = True, transpose_b = True)

def cosine_similarity(x, y, eps = 1e-8):
    '''
        Cosine similarity of x and y along the last axis.

        x, y: tensors of matching shape.
        eps:  lower bound on the product of the norms, guarding against division by zero.

        Returns a tensor whose last axis has size 1 (keepdims).
    '''
    # The numerator is the dot product, i.e. a sum; a mean would shrink the
    # result by the size of the last axis.
    dot = tf.math.reduce_sum(x * y, axis = -1, keepdims = True)
    norm_x = tf.math.sqrt(tf.math.reduce_sum(tf.square(x), axis = -1, keepdims = True))
    norm_y = tf.math.sqrt(tf.math.reduce_sum(tf.square(y), axis = -1, keepdims = True))
    return dot / tf.maximum(norm_x * norm_y, eps)
    
    
def word_loss(img_feature, word_embed, class_id, gamma2 = 5.0):
    '''
        Word-level image/text matching loss.

        img_feature: (batch, regions, dim) image region features.
        word_embed:  (batch, seq_len, dim) word embeddings; batch and seq_len must
                     be statically known because they drive a Python loop and reshapes.
        class_id:    one class id per sample; pairs of different samples sharing a
                     class are excluded from the negatives.
        gamma2:      scale used inside the exponent and on the final similarities.

        Returns the sum of the two cross-entropy losses over the similarity matrix
        and its transpose.
    '''
    batch_size, seq_len, _ = word_embed.shape

    label = tf.range(batch_size, dtype = tf.int32)

    # masks[i, j] = 1 when samples i and j (i != j) share a class. Built with
    # tensor ops so the function also works on symbolic tensors.
    ids = tf.reshape(class_id, [-1])
    same_class = tf.cast(tf.equal(ids[:, tf.newaxis], ids[tf.newaxis, :]), tf.float32)
    masks = same_class * (1.0 - tf.eye(batch_size))

    similarities = []
    for i in range(batch_size):
        # Compare the words of sentence i against every image in the batch.
        word = tf.tile(word_embed[i:i+1, :, :], multiples=[batch_size, 1, 1])

        # func_attention returns a single tensor of shape (batch, dim, seq_len).
        weiContext = func_attention(img_feature, word)
        weiContext = tf.transpose(weiContext, perm = [0, 2, 1])

        word = tf.reshape(word, shape=[batch_size * seq_len, -1])
        weiContext = tf.reshape(weiContext, shape = [batch_size * seq_len, -1])

        row_sim = cosine_similarity(word, weiContext)
        row_sim = tf.reshape(row_sim, shape = [batch_size, seq_len])

        # log-sum-exp over the words of the sentence
        row_sim = tf.exp(row_sim * gamma2)
        row_sim = tf.reduce_sum(row_sim, axis = -1, keepdims=True)
        row_sim = tf.math.log(row_sim)

        similarities.append(row_sim)

    # similarities[j, i]: image j against sentence i
    similarities = tf.concat(similarities, axis = -1)

    similarities = similarities * gamma2

    # Same-class pairs get a very negative logit so they are not used as negatives.
    similarities += masks * -1e12

    similarities_ = tf.transpose(similarities, perm = [1, 0])

    loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(label, similarities))
    loss += tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(label, similarities_))
    return loss


def sent_loss(img_feature, sent_emb, class_id, gamma3=10.0):
    '''
        Sentence-level image/text matching loss.

        img_feature: (batch, dim) global image embeddings.
        sent_emb:    (batch, dim) sentence embeddings; batch must be statically known.
        class_id:    one class id per sample; pairs of different samples sharing a
                     class are excluded from the negatives.
        gamma3:      scale applied to the cosine similarities.

        Returns the sum of the two cross-entropy losses over the score matrix and
        its transpose.
    '''
    # The batch size comes from the tensor's shape, not from the tensor itself.
    batch_size = sent_emb.shape[0]

    label = tf.range(batch_size, dtype = tf.int32)

    # masks[i, j] = 1 when samples i and j (i != j) share a class. Built with
    # tensor ops so the function also works on symbolic tensors.
    ids = tf.reshape(class_id, [-1])
    same_class = tf.cast(tf.equal(ids[:, tf.newaxis], ids[tf.newaxis, :]), tf.float32)
    masks = same_class * (1.0 - tf.eye(batch_size))

    cnn_code = tf.expand_dims(img_feature, axis = 0)
    rnn_code = tf.expand_dims(sent_emb, axis = 0)

    cnn_code_norm = tf.norm(cnn_code, axis = -1, keepdims = True)
    rnn_code_norm = tf.norm(rnn_code, axis = -1, keepdims = True)

    # Pairwise cosine similarity: dot products divided by the products of norms.
    scores0 = tf.matmul(cnn_code, rnn_code, transpose_b = True)
    norm0 = tf.matmul(cnn_code_norm, rnn_code_norm, transpose_b = True)
    scores0 = scores0 / tf.clip_by_value(norm0, clip_value_min = 1e-8, clip_value_max = float('inf')) * gamma3

    scores0 = tf.squeeze(scores0, axis = 0)
    # Same-class pairs get a very negative logit so they are not used as negatives.
    scores0 += masks * -1e12

    scores1 = tf.transpose(scores0, perm = [1, 0])

    loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(label, scores0))
    loss += tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(label, scores1))
    return loss


# Shared binary cross-entropy on raw logits, used by the adversarial losses.
bce_loss = tf.keras.losses.BinaryCrossentropy(from_logits = True)

def generator_loss(disc_gen_out):
    '''
        Adversarial generator loss: cross-entropy of the discriminator's logits
        for generated samples against all-ones targets.
    '''
    real_targets = tf.ones_like(disc_gen_out)
    return bce_loss(real_targets, disc_gen_out)

def discriminator_loss(disc_real_out, disc_gen_out):
    '''
        Adversarial discriminator loss.

        With disc_real_out given: half the sum of the "real -> 1" and "generated -> 0"
        cross-entropies. With disc_real_out = None: only the "generated -> 0" term.
    '''
    fake_term = bce_loss(tf.zeros_like(disc_gen_out), disc_gen_out)
    if disc_real_out is None:
        return fake_term

    real_term = bce_loss(tf.ones_like(disc_real_out), disc_real_out)
    return (real_term + fake_term) * 0.5
    
def kl_loss(mu, logvar):
    '''
        0.5 * sum(mu^2 + exp(logvar) - 1 - logvar) over the last axis,
        averaged over the remaining axes.
    '''
    per_dim = tf.square(mu) + tf.exp(logvar) - 1 - logvar
    per_sample = 0.5 * tf.reduce_sum(per_dim, axis=-1)
    return tf.math.reduce_mean(per_sample)

In [164]:
class Trainer(object):
    '''
        Alternating training of a generator and a discriminator.

        generator:     model called as generator([text, noise]) returning
                       [word_embed, sent_embed, mu, logvar, img_0, img_1, img_2, feature, embed].
        discriminator: model called as discriminator([img_0, img_1, img_2, sent_embed])
                       returning [unconditional_outs, conditional_outs].
        z_dim:         size of the noise vector fed to the generator.
        learning_rate: Adam learning rate for both optimizers.
    '''
    def __init__(self, generator, discriminator, z_dim, learning_rate = 2e-4):
        self.z_dim = z_dim

        self.gen_optimizer = tf.keras.optimizers.Adam(learning_rate = learning_rate, beta_1 = 0.5, beta_2 = 0.999)
        self.disc_optimizer = tf.keras.optimizers.Adam(learning_rate = learning_rate, beta_1 = 0.5, beta_2 = 0.999)

        self.generator = generator
        self.discriminator = discriminator

    @tf.function
    def train_step(self, text, img, class_id):
        '''
            Run one optimisation step for both networks.

            text:     captions; one caption per sample is picked at random along axis 1
                      (presumably (batch, n_captions, seq_len) - confirm against the data pipeline).
            img:      real images at the largest resolution.
            class_id: class ids forwarded to word_loss / sent_loss.

            Returns (gen_loss, disc_loss).

            NOTE(review): this method is traced by tf.function, so word_loss and
            sent_loss must be usable on symbolic tensors.
        '''
        batch_size = tf.shape(img)[0]

        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
            # Pick one of the available captions for the whole batch.
            n_captions = tf.shape(text)[1]
            sent_index = tf.random.uniform(shape = [], minval = 0, maxval = n_captions, dtype = tf.int32)
            txt = tf.gather(text, sent_index, axis = 1)
            noise_z = tf.random.normal(shape = (batch_size, self.z_dim))

            # Down-scaled copies of the real images for the two smaller discriminators.
            real_0 = tf.image.resize(img, [64, 64])
            real_1 = tf.image.resize(img, [128, 128])

            word_embed, sent_embed, mu, logvar, img_0, img_1, img_2, feature, embed = self.generator([txt, noise_z], training = True)

            unconditional_gen_out, conditional_gen_out = self.discriminator([img_0, img_1, img_2, sent_embed], training = True)
            unconditional_real_out, conditional_real_out = self.discriminator([real_0, real_1, img, sent_embed], training = True)
            # Mismatched pairs: each real image with the next sample's sentence.
            _, conditional_wrong_out = self.discriminator([real_0[:-1], real_1[:-1], img[:-1], sent_embed[1:]],
                                                          training = True)

            # gen_loss
            gen_loss = 0
            for ugo, cgo in zip(unconditional_gen_out, conditional_gen_out):
                gen_loss += generator_loss(ugo) + generator_loss(cgo)

            gen_loss += word_loss(feature, word_embed, class_id)*5.0
            gen_loss += sent_loss(embed, sent_embed, class_id)*5.0
            gen_loss += kl_loss(mu, logvar)

            # disc_loss
            disc_loss = 0
            for i in range(3):
                disc_l = discriminator_loss(unconditional_real_out[i], unconditional_gen_out[i])
                disc_l += discriminator_loss(conditional_real_out[i], conditional_gen_out[i])
                disc_l += discriminator_loss(None, conditional_wrong_out[i])

                disc_loss += disc_l/3

        gen_vars = self.generator.trainable_variables
        gen_grads = gen_tape.gradient(gen_loss, gen_vars)
        self.gen_optimizer.apply_gradients(zip(gen_grads, gen_vars))

        disc_vars = self.discriminator.trainable_variables
        disc_grads = disc_tape.gradient(disc_loss, disc_vars)
        self.disc_optimizer.apply_gradients(zip(disc_grads, disc_vars))

        return gen_loss, disc_loss

    def train(self, data, epochs = 1):
        '''
            Iterate over data for the given number of epochs.

            data: iterable yielding (text, img, class_id) batches.

            Returns {'gen_losses': [...], 'disc_losses': [...]} holding, per epoch,
            the losses of the last batch processed so far (None if no batch has run yet).
        '''
        gen_losses, disc_losses = [], []
        # Defined up front so an empty dataset does not raise NameError below.
        gen_loss = disc_loss = None
        for e in range(epochs):
            print(f'Epoch: {e} Start')
            for text, img, class_id in data:
                gen_loss, disc_loss = self.train_step(text, img, class_id)
                print('.', end = '')

            gen_losses.append(gen_loss)
            disc_losses.append(disc_loss)
            print(f'\nGenerator Loss: {gen_loss} \t Discriminator Loss: {disc_loss}')
            print(f'Epoch: {e} Ends\n')

        return {'gen_losses': gen_losses, 'disc_losses': disc_losses}