In [16]:
from enum import Enum
import os

class Execution(Enum):
    """Environment the notebook is running in.

    The value is used further down to decide whether Google Drive gets
    mounted and which dataset / model directories are used.
    """
    Local = "Local"
    Colab = "Colab"

if "F:" in os.getcwd():
    execution = Execution.Local
else:
    execution = Execution.Colab

if execution == Execution.Colab:
    from google.colab import drive
    drive.mount("/content/gdrive")

!pip install wandb -qU
import wandb
wandb.login()


^C




True


# Hyperparams

In [None]:
# Run configuration. The same dictionary is handed to Weights & Biases and is
# later written to disk by save_hyperparams().
hyperparams = dict(
    epochs=100,
    batch_size=128,
    image_size=64,
    latent_dim=100,
    dataset_labels=7,
    clip_value=0.01,
    lr=0.001,
    dropout=0.5,
    beta_1=0.9,
    beta_2=0.99,
    n_critic=5,
    grad_weight=10,
    execution=execution,
)

# Open the W&B run with the hyperparameters attached as its config.
wandb.init(project="WCGAN", config=hyperparams)

# The rest of the notebook reads its settings through this config object.
config = wandb.config

# Custom layers

In [None]:
from keras.layers import Layer, InputSpec

try:
    from keras import initializations
except ImportError:
    from keras import initializers as initializations
import keras.backend as K


class Scale(Layer):
    '''Custom layer for DenseNet that applies a learned affine transform.

    Learns a set of weights and biases used for scaling the input data.
    The output is an element-wise multiplication of the input by `gamma`
    followed by an element-wise shift by `beta`:

        out = in * gamma + beta,

    where 'gamma' and 'beta' are the learned weights and biases. Both are
    vectors of length `input_shape[axis]` and are broadcast over every other
    axis of the input.

    # Arguments
        weights: Initialization weights.
            List of 2 Numpy arrays, with shapes:
            `[(input_shape,), (input_shape,)]`
        axis: integer, axis of the input that `gamma` and `beta` are aligned
            with. For instance, if your input tensor has shape
            (samples, channels, rows, cols), set axis to 1 to scale per
            feature map (channels axis).
        momentum: stored on the layer and returned by `get_config`; it is not
            used by `build` or `call`.
        beta_init: name of initialization function for shift parameter
            (see [initializations](../initializations.md)), or alternatively,
            Theano/TensorFlow function to use for weights initialization.
            This parameter is only relevant if you don't pass a `weights` argument.
        gamma_init: name of initialization function for scale parameter (see
            [initializations](../initializations.md)), or alternatively,
            Theano/TensorFlow function to use for weights initialization.
            This parameter is only relevant if you don't pass a `weights` argument.
    '''

    def __init__(self, weights=None, axis=-1, momentum=0.9, beta_init='zero', gamma_init='one', **kwargs):
        """Store the configuration and resolve both initializers by name."""
        self.momentum = momentum
        self.axis = axis
        self.beta_init = initializations.get(beta_init)
        self.gamma_init = initializations.get(gamma_init)
        # Kept until build(), where the weights are applied and the attribute is deleted.
        self.initial_weights = weights
        super(Scale, self).__init__(**kwargs)

    def build(self, input_shape):
        """Create `gamma` and `beta`, one value per entry along `self.axis`."""
        self.input_spec = [InputSpec(shape=input_shape)]
        shape = (int(input_shape[self.axis]),)

        # Tensorflow >= 1.0.0 compatibility
        self.gamma = K.variable(self.gamma_init(shape), name='{}_gamma'.format(self.name))
        self.beta = K.variable(self.beta_init(shape), name='{}_beta'.format(self.name))
        # self.gamma = self.gamma_init(shape, name='{}_gamma'.format(self.name))
        # self.beta = self.beta_init(shape, name='{}_beta'.format(self.name))
        # The variables are registered by assigning the private list directly
        # rather than through add_weight().
        self._trainable_weights = [self.gamma, self.beta]

        if self.initial_weights is not None:
            self.set_weights(self.initial_weights)
            del self.initial_weights

    def call(self, x, mask=None):
        """Return `gamma * x + beta`; `mask` is accepted but not used."""
        input_shape = self.input_spec[0].shape
        # Shape with 1 everywhere except along `axis`, so gamma/beta broadcast against x.
        broadcast_shape = [1] * len(input_shape)
        broadcast_shape[self.axis] = input_shape[self.axis]

        out = K.reshape(self.gamma, broadcast_shape) * x + K.reshape(self.beta, broadcast_shape)
        return out

    def get_config(self):
        """Return the base layer config extended with `momentum` and `axis`.

        The initializers and initial weights are not part of the returned dict.
        """
        config = {"momentum": self.momentum, "axis": self.axis}
        base_config = super(Scale, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

# DenseNet

In [None]:
import keras.backend as K
from keras.layers import LayerNormalization
from keras.layers import ZeroPadding2D, Concatenate
from keras.layers.convolutional import Convolution2D
from keras.layers.core import Dense, Dropout, Activation
from keras.layers.pooling import AveragePooling2D, GlobalAveragePooling2D, MaxPooling2D
from keras.models import Model

def DenseNet(img_input_layer, img_n_label_input_layer, label_input_layer, nb_dense_block=4, growth_rate=32, nb_filter=64, reduction=0.0, dropout_rate=0.0, weight_decay=1e-4,
             classes=1000, weights_path=None, final_activation='sigmoid'):
    """Build a DenseNet-121 style network on top of externally created inputs.

        Each normalisation step is a `LayerNormalization` followed by a `Scale`
        layer. The per-block layer counts are fixed to [6, 12, 24, 16]; when
        `nb_dense_block` differs from 4 the last block still uses 16 layers.

        # Arguments
            img_input_layer: image `Input`; first input of the returned model
            img_n_label_input_layer: tensor fed to the first convolution
                (built by the caller from the image and label inputs)
            label_input_layer: label `Input`; second input of the returned model
            nb_dense_block: number of dense blocks, between 1 and 5
            growth_rate: number of filters added by every conv_block
            nb_filter: number of filters of the initial convolution
            reduction: reduction factor of transition blocks.
            dropout_rate: dropout rate
            weight_decay: accepted and forwarded to the block builders, which
                do not apply it
            classes: number of units of the final dense layer
            weights_path: path to pre-trained weights
            final_activation: activation applied to the final dense layer.
                NOTE(review): the default 'sigmoid' bounds the output to (0, 1);
                a Wasserstein critic normally needs an unbounded score
                ('linear') - confirm which one is intended.
        # Returns
            A Keras model instance.
        # Raises
            ValueError: if `nb_dense_block` is outside the supported range.
    """
    eps = 1.1e-5

    # compute compression factor
    compression = 1.0 - reduction

    # From architecture for ImageNet (Table 1 in the paper), DenseNet-121
    nb_layers = [6, 12, 24, 16]
    if not 1 <= nb_dense_block <= len(nb_layers) + 1:
        raise ValueError('nb_dense_block must be between 1 and %d, got %r'
                         % (len(nb_layers) + 1, nb_dense_block))

    # The block builders below read the channel axis from this module-level name.
    global concat_axis
    concat_axis = 3 if K.image_data_format() == 'channels_last' else 1

    # The network consumes the image-plus-label tensor for either data format.
    img_input = img_n_label_input_layer

    # Initial convolution
    x = ZeroPadding2D((3, 3), name='conv1_zeropadding')(img_input)
    x = Convolution2D(nb_filter, (7, 7), strides=(2, 2), name='conv1', use_bias=False)(x)
    x = LayerNormalization(epsilon=eps, axis=concat_axis, name='conv1_bn')(x)
    x = Scale(axis=concat_axis, name='conv1_scale')(x)
    x = Activation('relu', name='relu1')(x)
    x = ZeroPadding2D((1, 1), name='pool1_zeropadding')(x)
    x = MaxPooling2D((3, 3), strides=(2, 2), name='pool1')(x)

    # Add dense blocks, each followed by a transition block
    for block_idx in range(nb_dense_block - 1):
        stage = block_idx + 2
        x, nb_filter = dense_block(x, stage, nb_layers[block_idx], nb_filter, growth_rate, dropout_rate=dropout_rate,
                                   weight_decay=weight_decay)

        # Add transition_block
        x = transition_block(x, stage, nb_filter, compression=compression, dropout_rate=dropout_rate,
                             weight_decay=weight_decay)
        nb_filter = int(nb_filter * compression)

    # Stages are numbered from 2, so the last block is stage nb_dense_block + 1.
    # Computing it directly keeps it defined even when the loop above runs zero times.
    final_stage = nb_dense_block + 1
    x, nb_filter = dense_block(x, final_stage, nb_layers[-1], nb_filter, growth_rate, dropout_rate=dropout_rate,
                               weight_decay=weight_decay)

    x = LayerNormalization(epsilon=eps, axis=concat_axis, name='conv' + str(final_stage) + '_blk_bn')(x)
    x = Scale(axis=concat_axis, name='conv' + str(final_stage) + '_blk_scale')(x)
    x = Activation('relu', name='relu' + str(final_stage) + '_blk')(x)
    x = GlobalAveragePooling2D(name='pool' + str(final_stage))(x)

    x = Dense(classes, name='fc6')(x)
    x = Activation(final_activation, name='prob')(x)

    model = Model(inputs=[img_input_layer, label_input_layer], outputs=x, name='densenet')

    if weights_path is not None:
        model.load_weights(weights_path)

    return model


def conv_block(x, stage, branch, nb_filter, dropout_rate=None, weight_decay=1e-4):
    '''Normalise -> ReLU -> 1x1 bottleneck conv, then normalise -> ReLU -> 3x3 conv.

    Dropout is inserted after each convolution when `dropout_rate` is truthy.
    Normalisation is LayerNormalization followed by Scale along the
    module-level `concat_axis`.
        # Arguments
            x: input tensor
            stage: index for dense block (used in layer names)
            branch: layer index within each dense block (used in layer names)
            nb_filter: number of filters of the 3x3 convolution; the bottleneck uses 4x as many
            dropout_rate: dropout rate
            weight_decay: accepted but not used by this function
        # Returns
            Output tensor of the 3x3 convolution (after optional dropout).
    '''
    eps = 1.1e-5
    suffix = '{}_{}'.format(stage, branch)
    conv_prefix = 'conv' + suffix
    relu_prefix = 'relu' + suffix

    # Bottleneck: 1x1 convolution with four times the requested filters
    bottleneck_filters = 4 * nb_filter
    out = LayerNormalization(epsilon=eps, axis=concat_axis, name=conv_prefix + '_x1_bn')(x)
    out = Scale(axis=concat_axis, name=conv_prefix + '_x1_scale')(out)
    out = Activation('relu', name=relu_prefix + '_x1')(out)
    out = Convolution2D(bottleneck_filters, (1, 1), name=conv_prefix + '_x1', use_bias=False)(out)
    if dropout_rate:
        out = Dropout(dropout_rate)(out)

    # 3x3 convolution; explicit zero padding keeps the spatial size
    out = LayerNormalization(epsilon=eps, axis=concat_axis, name=conv_prefix + '_x2_bn')(out)
    out = Scale(axis=concat_axis, name=conv_prefix + '_x2_scale')(out)
    out = Activation('relu', name=relu_prefix + '_x2')(out)
    out = ZeroPadding2D((1, 1), name=conv_prefix + '_x2_zeropadding')(out)
    out = Convolution2D(nb_filter, (3, 3), name=conv_prefix + '_x2', use_bias=False)(out)
    if dropout_rate:
        out = Dropout(dropout_rate)(out)

    return out


def transition_block(x, stage, nb_filter, compression=1.0, dropout_rate=None, weight_decay=1E-4):
    '''Normalise -> ReLU -> 1x1 conv (optionally compressed) -> optional dropout -> 2x2 average pool.

        # Arguments
            x: input tensor
            stage: index for dense block (used in layer names)
            nb_filter: number of filters before compression
            compression: calculated as 1 - reduction. The 1x1 convolution
                outputs int(nb_filter * compression) feature maps.
            dropout_rate: dropout rate
            weight_decay: accepted but not used by this function
        # Returns
            The pooled output tensor.
    '''
    eps = 1.1e-5
    stage_tag = str(stage)
    conv_name = 'conv' + stage_tag + '_blk'
    relu_name = 'relu' + stage_tag + '_blk'
    pool_name = 'pool' + stage_tag
    compressed_filters = int(nb_filter * compression)

    out = LayerNormalization(epsilon=eps, axis=concat_axis, name=conv_name + '_bn')(x)
    out = Scale(axis=concat_axis, name=conv_name + '_scale')(out)
    out = Activation('relu', name=relu_name)(out)
    out = Convolution2D(compressed_filters, (1, 1), name=conv_name, use_bias=False)(out)
    if dropout_rate:
        out = Dropout(dropout_rate)(out)

    # Halve the spatial resolution
    return AveragePooling2D((2, 2), strides=(2, 2), name=pool_name)(out)


def dense_block(x, stage, nb_layers, nb_filter, growth_rate, dropout_rate=None, weight_decay=1e-4,
                grow_nb_filters=True):
    ''' Build a dense_block where the output of each conv_block is fed to subsequent ones
        # Arguments
            x: input tensor
            stage: index for dense block
            nb_layers: the number of layers of conv_block to append to the model.
            nb_filter: number of filters
            growth_rate: growth rate
            dropout_rate: dropout rate
            weight_decay: weight decay factor (forwarded to conv_block)
            grow_nb_filters: flag to decide to allow number of filters to grow
        # Returns
            Tuple `(tensor, nb_filter)`: the concatenation of the input with
            every conv_block output, and the updated filter count.
    '''
    concat_feat = x

    for branch in range(1, nb_layers + 1):
        x = conv_block(concat_feat, stage, branch, growth_rate, dropout_rate, weight_decay)
        # Concatenate along the channel axis selected by DenseNet(). The layer
        # default (-1) is only the channel axis for channels_last data.
        concat_feat = Concatenate(axis=concat_axis)([concat_feat, x])

        if grow_nb_filters:
            nb_filter += growth_rate

    return concat_feat, nb_filter

# Utils

In [None]:
import os

import numpy
import numpy as np
from keras.datasets.fashion_mnist import load_data
from keras_preprocessing.image import ImageDataGenerator
from numpy import expand_dims
from numpy.random import randint
from numpy.random import randn
from datetime import datetime


# Define datagen. Here we can define any transformations we want to apply to images
datagen = ImageDataGenerator()

# Training directory (one sub-folder per class) and root folder for saved models.
if execution == Execution.Colab:
    # Anchor the paths at the absolute mount point passed to drive.mount(), so
    # they resolve correctly whatever the current working directory is.
    drive_root = os.path.join("/content/gdrive", "My Drive")
    train_dir = os.path.join(drive_root, "Datasets", "HAM10000", "reorganized")
    save_dir_root = os.path.join(drive_root, "models")
elif execution == Execution.Local:
    train_dir = os.path.join("F:", os.sep, "backup", "Facultad", "Tesis", "DL", "datasets", "HAM10000", "data", "reorganized")
    save_dir_root = os.path.join("F:", os.sep, "backup", "Facultad", "Tesis", "DL", "U-Net CGAN", "models")

def create_save_dir():
    """Create (if needed) and return the output directory for the current run.

    The directory lives under the module-level `save_dir_root` and is named
    after the current date and time with minute resolution, so two calls
    within the same minute return the same path.

    Returns:
        Path of the run directory.
    """
    run_date = datetime.today().strftime('%Y-%m-%d %H-%M')
    save_dir_name = os.path.join(save_dir_root, run_date)
    # makedirs also creates a missing save_dir_root and tolerates an existing
    # directory. The default mode is kept on purpose: the previous 0o666 has
    # no execute bit, which on POSIX makes the directory impossible to enter.
    os.makedirs(save_dir_name, exist_ok=True)
    return save_dir_name


def save_models(epoch, g_model, d_model, dir_name):
    """Write both networks to `dir_name` as HDF5 files.

    The files are named `dis_model_NNN.h5` and `gen_model_NNN.h5`, where NNN
    is `epoch + 1` zero-padded to three digits. The discriminator is saved first.
    """
    number = epoch + 1

    critic_path = os.path.join(dir_name, 'dis_model_%03d.h5' % number)
    d_model.save(critic_path)
    print("Discriminator model saved")

    generator_path = os.path.join(dir_name, 'gen_model_%03d.h5' % number)
    g_model.save(generator_path)
    print("Generator model saved")
    
def save_hyperparams(dir_name):
    """Write the text form of the module-level `hyperparams` dict to `dir_name`/hyperparams.txt."""
    target = os.path.join(dir_name, "hyperparams.txt")
    with open(target, 'w') as out_file:
        out_file.write(str(hyperparams) + "\n")
    print("Hyperparams saved")

def load_real_data_old():
    """Load the training split returned by `load_data()` and rescale it.

    A trailing channel axis is added, the pixels are cast to float32 and
    mapped from [0, 255] to [-1, 1].

    Returns:
        `[images, labels]`.
    """
    (train_images, train_labels), (_, _) = load_data()
    with_channel = expand_dims(train_images, axis=-1).astype('float32')
    scaled = (with_channel - 127.5) / 127.5
    return [scaled, train_labels]


def load_real_data():
    """Return a zero-filled stand-in for the dataset plus the labels of one image.

    Only a single image is read from `train_dir`. It supplies the image side
    length, while the iterator's `samples` attribute supplies the image count.
    The returned array therefore has the dataset's size and shape but holds
    no pixel data.

    Returns:
        `[placeholder, labels]`, where `placeholder` is a zero array of shape
        (samples, side, side, 1) and `labels` are the labels of the one batch
        that was read.
    """
    iterator = datagen.flow_from_directory(directory=train_dir,
                                           class_mode='categorical',
                                           batch_size=1,  # a single image is enough to learn the shape
                                           target_size=(config.image_size, config.image_size),  # resize images
                                           color_mode='grayscale')
    first_batch, labels = next(iterator)
    n_images = iterator.samples
    # Number of elements in one image row, used for both spatial dimensions.
    side = first_batch[0][0].size
    placeholder = numpy.zeros([n_images, side, side, 1])
    return [placeholder, labels]


def get_images_and_labels(n_samples):
    """Read one batch of `n_samples` grayscale images with sparse integer labels.

    A new directory iterator over `train_dir` is created on every call and
    only its first batch is used. Images are cast to float32 and rescaled
    from [0, 255] to [-1, 1].

    Returns:
        Tuple `(images, labels)` with the labels cast to int.
    """
    iterator = datagen.flow_from_directory(directory=train_dir,
                                           class_mode='sparse',
                                           batch_size=n_samples,
                                           target_size=(config.image_size, config.image_size),  # resize images
                                           color_mode='grayscale')
    batch_images, batch_labels = next(iterator)
    # ints -> floats, then [0, 255] -> [-1, 1]
    scaled_images = (batch_images.astype('float32') - 127.5) / 127.5
    return scaled_images, batch_labels.astype(int)


def generate_real_samples(n_samples):
    """Fetch a batch of real images and labels together with their critic targets.

    Returns:
        Tuple `(images, labels, y)`, where `y` has shape (n_samples, 1) and is
        filled with -1, the target value this notebook uses for real samples.
    """
    images, labels = get_images_and_labels(n_samples)
    # Use the numpy alias imported in this cell; the bare `ones` name is only
    # imported by the later training cell.
    y = -np.ones((n_samples, 1))
    return images, labels, y


def get_noise(latent_dim, n_samples):
    """Draw standard-normal latent vectors as an array of shape (n_samples, latent_dim)."""
    return randn(n_samples, latent_dim)


# use the generator to generate n fake examples, with class labels
def generate_fake_samples(generator, latent_dim, n_samples):
    """Run the generator on a fresh batch of images, labels and noise.

    Args:
        generator: callable taking `[images, labels, noise]` and returning images.
        latent_dim: length of each noise vector.
        n_samples: number of samples to generate.

    Returns:
        Tuple `(images, labels, y)`: the generated images, the labels they
        were conditioned on, and an (n_samples, 1) array of +1, the target
        value this notebook uses for generated samples.
    """
    img_input, labels_input = get_images_and_labels(n_samples)
    z_input = get_noise(latent_dim, n_samples)
    # predict outputs
    images = generator([img_input, labels_input, z_input])
    # Use the numpy alias imported in this cell; the bare `ones` name is only
    # imported by the later training cell.
    y = np.ones((n_samples, 1))
    return images, labels_input, y


def save_plot(examples, epoch, n=10):
    """Save an n x n grid of images to `generated_plot_eNNN.png`.

    Args:
        examples: indexable batch of images; channel 0 of each image is drawn.
        epoch: zero-based epoch index; the file name uses `epoch + 1`.
        n: grid side length. When fewer than n * n images are supplied the
            remaining cells are left empty.

    The file is written to the current working directory.
    """
    # Imported here so the helper does not depend on a later notebook cell
    # having been run first.
    from matplotlib import pyplot

    # Do not index past the end of a batch smaller than the grid.
    n_tiles = min(n * n, len(examples))
    for i in range(n_tiles):
        # define subplot
        pyplot.subplot(n, n, 1 + i)
        # turn off axis
        pyplot.axis('off')
        # plot raw pixel data
        pyplot.imshow(examples[i, :, :, 0], cmap='gray_r')
    # save plot to file
    filename = 'generated_plot_e%03d.png' % (epoch + 1)
    pyplot.savefig(filename)
    pyplot.close()


def gradient_penalty(discriminator, batch_size, real_images, fake_images, labels):
    """Calculates the gradient penalty.

    The penalty is computed on images interpolated between the real and the
    generated batch and is meant to be added to the discriminator loss.

    Args:
        discriminator: callable taking `[images, labels]`.
        batch_size: number of samples in `real_images` / `fake_images`.
        real_images: batch of real images.
        fake_images: batch of generated images, same shape as `real_images`.
        labels: labels passed to the discriminator with the interpolated images.

    Returns:
        Scalar tensor: mean over the batch of (||gradient|| - 1)^2.
    """
    # One interpolation coefficient per sample, drawn uniformly from [0, 1) so
    # every interpolated image lies on the segment between its real and fake
    # counterpart. A normal draw would also produce points outside it.
    alpha = tf.random.uniform([batch_size, 1, 1, 1], 0.0, 1.0)
    diff = fake_images - real_images
    interpolated = real_images + alpha * diff

    with tf.GradientTape() as gp_tape:
        # `interpolated` is not a variable, so it has to be watched explicitly.
        gp_tape.watch(interpolated)
        # 1. Get the discriminator output for this interpolated image.
        pred = discriminator([interpolated, labels])

    # 2. Calculate the gradients w.r.t to this interpolated image.
    grads = gp_tape.gradient(pred, [interpolated])[0]
    # 3. Calculate the norm of the gradients (per sample, over H, W, C).
    norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]))
    gp = tf.reduce_mean((norm - 1.0) ** 2)
    return gp


from enum import IntEnum


class FashionLabel(IntEnum):
    """Integer class ids named after clothing categories.

    NOTE(review): presumably these are the label ids of the Fashion-MNIST data
    loaded by `load_real_data_old` - confirm against that dataset.
    """
    Tshirt = 0
    Trouser = 1
    Pullover = 2
    Dress = 3
    Coat = 4
    Sandal = 5
    Shirt = 6
    Sneaker = 7
    Bag = 8
    Ankle_boot = 9

class CancerLabel(IntEnum):
    """Integer class ids for the seven lesion categories.

    NOTE(review): presumably the ids follow the alphabetical order of the
    class sub-folders under `train_dir`, which is how the directory iterator
    would number them - confirm against the folder names.
    """
    akiec = 0
    bcc = 1
    bkl = 2
    df = 3
    mel = 4
    nv = 5
    vasc = 6

# Nets

In [None]:
import tensorflow as tf
from keras import backend
from keras.constraints import Constraint
from keras.layers import Input, Dense, Concatenate, ReLU, Conv2D, Conv2DTranspose, Reshape, LayerNormalization, \
    Activation, Embedding, LeakyReLU, Dropout
from keras.models import Model
from tensorflow import Tensor, keras


# Kernel initializer (truncated normal, mean 0, stddev 0.02) used by the
# layers built in representation_layer().
w_initializer = tf.keras.initializers.TruncatedNormal(mean=0., stddev=0.02)


def wasserstein_loss(y_true, y_pred):
    """Return the mean of the element-wise product of targets and predictions.

    With targets of -1 / +1 this is the mean prediction with its sign flipped
    or kept, respectively.
    """
    signed_scores = y_true * y_pred
    return backend.mean(signed_scores)


class ClipConstraint(Constraint):
    """Weight constraint that clips every weight into [-clip_value, clip_value]."""

    def __init__(self, clip_value):
        # Half-width of the allowed interval.
        self.clip_value = clip_value

    def __call__(self, weights):
        # Element-wise clipping to the symmetric interval.
        bound = self.clip_value
        return backend.clip(weights, -bound, bound)

    def get_config(self):
        # Constructor arguments as a dict.
        return dict(clip_value=self.clip_value)


def up_scaling_layer(x, n_filters):
    """Upsample `x` with a stride-2, 1x1 transposed convolution.

    The convolution is followed by LeakyReLU(0.2), layer normalisation and
    dropout at the rate stored in `config.dropout`.
    """
    upsampled = Conv2DTranspose(n_filters, kernel_size=1, strides=2, padding='same')(x)
    activated = LeakyReLU(alpha=0.2)(upsampled)
    normalised = LayerNormalization()(activated)
    return Dropout(config.dropout)(normalised)


def down_scaling_layer(x, n_filters):
    """Downsample `x` with a stride-2, 1x1 convolution.

    The convolution is followed by LeakyReLU(0.2), layer normalisation and
    dropout at the rate stored in `config.dropout`.
    """
    downsampled = Conv2D(n_filters, kernel_size=1, strides=2, padding='same')(x)
    activated = LeakyReLU(alpha=0.2)(downsampled)
    normalised = LayerNormalization()(activated)
    return Dropout(config.dropout)(normalised)


def resNet_block(x, n_filters, scaling):
    """Apply four stacked 3x3 convolutions, layer normalisation and a rescaling layer.

    Args:
        x: input tensor.
        n_filters: filters used by every convolution and by the rescaling layer.
        scaling: 'up' selects `up_scaling_layer`; any other value selects
            `down_scaling_layer`.
    """
    # Four stride-1 'same' convolutions in sequence.
    for _ in range(4):
        x = Conv2D(n_filters, kernel_size=3, strides=1, padding='same')(x)
    x = LayerNormalization()(x)

    rescale = up_scaling_layer if scaling == 'up' else down_scaling_layer
    return rescale(x, n_filters)


def representation_layer(layer, noise_input):
    """Turn the noise vector into a feature map and concatenate it with `layer`.

    The noise is projected by a dense layer to a map of half the spatial size
    of `layer` (taken from `layer.shape[1]`) with 3 channels. A stride-2
    transposed convolution then brings it to 6 channels before it is
    concatenated with `layer`.
    """
    base_channels = 3
    half_side = int(layer.shape[1] / 2)

    projected = Dense(half_side * half_side * base_channels,
                      kernel_initializer=w_initializer)(noise_input)
    noise_map = Reshape((half_side, half_side, base_channels))(projected)

    # The stride of 2 undoes the halving above.
    noise_features = Conv2DTranspose(kernel_size=4,
                                     strides=2,
                                     filters=2 * base_channels,
                                     padding="same",
                                     kernel_initializer=w_initializer)(noise_map)
    return Concatenate()([layer, noise_features])


def img_n_label_layer(input_shape, n_classes):
    """Create the image and label inputs and merge the label in as an extra channel.

    The label is embedded into 50 dimensions, projected by a dense layer to
    `input_shape[0] * input_shape[1]` values and reshaped into a one-channel
    map that is concatenated to the image.

    Returns:
        Tuple `(img_input_layer, img_n_label_input_layer, label_input_layer)`.
    """
    height, width = input_shape[0], input_shape[1]

    # Label branch: id -> embedding -> dense -> (height, width, 1) map
    label_input_layer = Input(shape=(1,))
    label_map = Embedding(n_classes, 50)(label_input_layer)
    label_map = Dense(height * width)(label_map)
    label_map = Reshape((height, width, 1))(label_map)

    # Image branch, then append the label map as an additional channel
    img_input_layer = Input(shape=input_shape)
    img_n_label_input_layer = Concatenate()([img_input_layer, label_map])
    return img_input_layer, img_n_label_input_layer, label_input_layer


def define_generator(latent_dim, input_shape, n_classes):
    """Build the generator: an encoder/decoder with skip connections and a noise input.

    Args:
        latent_dim: length of the noise vector (an int), or an already formed
            shape tuple.
        input_shape: shape of the conditioning image, without the batch axis.
        n_classes: number of distinct labels for the label embedding.

    Returns:
        Keras model with inputs `[image, label, noise]` and a tanh image output.
    """
    n_filters = 3

    img_input_layer, img_n_label_input_layer, label_input_layer = img_n_label_layer(input_shape, n_classes)

    # Encoder: four down-scaling blocks
    x1 = resNet_block(img_n_label_input_layer, n_filters, 'down')
    x2 = resNet_block(x1, 2 * n_filters, 'down')
    x3 = resNet_block(x2, 4 * n_filters, 'down')
    x4 = resNet_block(x3, 8 * n_filters, 'down')

    # `Input` is documented to take a shape tuple, so wrap a bare int rather
    # than relying on version-specific coercion.
    noise_shape = (latent_dim,) if isinstance(latent_dim, int) else tuple(latent_dim)
    noise_input_layer = Input(shape=noise_shape)
    decoded_img_and_noise = representation_layer(x4, noise_input_layer)

    # Decoder: four up-scaling blocks, each fed the matching encoder output
    x5 = resNet_block(decoded_img_and_noise, 8 * n_filters, 'up')
    skip_connection_3_5 = Concatenate(axis=-1)([x3, x5])
    x6 = resNet_block(skip_connection_3_5, 4 * n_filters, 'up')
    skip_connection_2_6 = Concatenate(axis=-1)([x2, x6])
    x7 = resNet_block(skip_connection_2_6, 2 * n_filters, 'up')
    skip_connection_1_7 = Concatenate(axis=-1)([x1, x7])
    x8 = resNet_block(skip_connection_1_7, 1, 'up')

    gen_output = Activation('tanh')(x8)

    model = Model(inputs=[img_input_layer, label_input_layer, noise_input_layer], outputs=gen_output)
    return model


def define_discriminator(input_shape, n_classes):
    """Build and compile the discriminator.

    It is a DenseNet with a single output unit and reduction 0.5 on top of the
    image-plus-label input. The model is compiled with Adam (settings taken
    from `config`), the Wasserstein loss and an accuracy metric.
    """
    image_in, merged_in, label_in = img_n_label_layer(input_shape, n_classes)
    critic = DenseNet(image_in, merged_in, label_in, reduction=0.5, classes=1)

    adam = keras.optimizers.Adam(learning_rate=config.lr, beta_1=config.beta_1, beta_2=config.beta_2)
    critic.compile(optimizer=adam, loss=wasserstein_loss, metrics=['accuracy'])
    return critic


# Adam optimizer the training loop uses to apply the discriminator's gradients;
# learning rate and betas come from the run config.
d_optimizer = keras.optimizers.Adam(learning_rate=config.lr, beta_1=config.beta_1, beta_2=config.beta_2)


# define the combined generator and discriminator model, for updating the generator
def define_gan(g_model, d_model):
    """Chain generator and discriminator into one compiled model.

    Every discriminator layer except the LayerNormalization layers is marked
    non-trainable first. The combined model takes the generator's
    `[image, label, noise]` inputs and returns the discriminator's output for
    the generated image and the same label. It is compiled with the
    Wasserstein loss and Adam configured from `config`.
    """
    # Freeze the discriminator, skipping its LayerNormalization layers
    for critic_layer in d_model.layers:
        if isinstance(critic_layer, LayerNormalization):
            continue
        critic_layer.trainable = False

    # Generator inputs are reused as the inputs of the combined model
    gen_img, gen_label, gen_noise = g_model.input
    # The generated image and the generator's label input feed the discriminator
    gan_output = d_model([g_model.output, gen_label])
    model = Model([gen_img, gen_label, gen_noise], gan_output)

    adam = keras.optimizers.Adam(learning_rate=config.lr, beta_1=config.beta_1, beta_2=config.beta_2)
    model.compile(loss=wasserstein_loss, optimizer=adam)
    return model

# Train

In [None]:
# Train the conditional Wasserstein GAN (critic with gradient penalty) on the images under train_dir
from matplotlib import pyplot
from numpy import ones, mean
import tensorflow as tf


# Create net
dataset = load_real_data()
image_shape = dataset[0][0].shape
discriminator = define_discriminator(image_shape, config.dataset_labels)
generator = define_generator(config.latent_dim, image_shape, config.dataset_labels)
# Take the critic's variables BEFORE define_gan() marks its layers as
# non-trainable. Afterwards `discriminator.trainable_variables` would only
# contain the LayerNormalization weights, and the critic updates below would
# leave every convolution and dense weight untouched.
critic_variables = list(discriminator.trainable_variables)
gan = define_gan(generator, discriminator)

save_dir = create_save_dir()
save_hyperparams(save_dir)

# Start training
bat_per_epo = int(dataset[0].shape[0] / config.batch_size)
half_batch = int(config.batch_size / 2)

# per-step history of the (averaged) losses
critic_loss_hist, gan_hist, gp_hist = list(), list(), list()
# calculate the number of training iterations
n_steps = bat_per_epo * config.epochs
real_images, real_labels, real_y = [], [], []
for i in range(n_steps):
    # update the critic more than the generator
    critic_loss_tmp, gp_tmp = list(), list()
    for _ in range(config.n_critic):
        with tf.GradientTape() as tape:

            real_images, real_labels, real_y = generate_real_samples(half_batch)
            real_logits = discriminator([real_images, real_labels])

            fake_images, fake_labels, fake_y = generate_fake_samples(generator, config.latent_dim, half_batch)
            fake_logits = discriminator([fake_images, fake_labels])

            # Wasserstein critic loss: score generated samples low and real
            # samples high. This matches the -1 (real) / +1 (fake) targets
            # used with wasserstein_loss elsewhere in the notebook.
            d_cost = tf.reduce_mean(fake_logits) - tf.reduce_mean(real_logits)

            gp = gradient_penalty(discriminator, half_batch, real_images, fake_images, real_labels)

            d_loss = d_cost + gp * config.grad_weight

        # Store plain floats so the values can be averaged, printed and logged.
        critic_loss_tmp.append(float(d_loss))
        gp_tmp.append(float(gp))

        d_gradient = tape.gradient(d_loss, critic_variables)

        d_optimizer.apply_gradients(zip(d_gradient, critic_variables))

    # One number per generator step: the mean over the n_critic updates.
    critic_loss_hist.append(float(mean(critic_loss_tmp)))
    gp_hist.append(float(mean(gp_tmp)))
    # reuse the last real batch (and its -1 targets) as generator input
    img_input, labels_input, y_gan = real_images, real_labels, real_y
    z_input = get_noise(config.latent_dim, half_batch)
    # update the generator via the critic's error
    g_loss = gan.train_on_batch([img_input, labels_input, z_input], y_gan)
    gan_hist.append(g_loss)
    # summarize loss on this batch
    print('>%d/%d, critic_loss=%.3f, gradient_penalty=%.3f ,gan_loss=%.3f' % (i + 1, n_steps, critic_loss_hist[-1], gp_hist[-1], g_loss))

    metrics = {
        "gradient_penalty": gp_hist[-1],
        "critc_loss": critic_loss_hist[-1],
        "generator_loss": g_loss
    }

    # Log train metrics to wandb
    wandb.log(metrics)
    # save both models at the end of every 'epoch'
    if (i + 1) % bat_per_epo == 0:
        save_models(i, generator, discriminator, save_dir)
wandb.finish()