In [None]:
import os

In [None]:
# GPU selection; must run before TensorFlow initializes CUDA.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",  # number devices in PCI bus order
    "CUDA_VISIBLE_DEVICES": "5",        # expose only device 5 to this process
})

# <font color='red'>**Data loader**</font>
## **Useful libraries**

In [None]:
import scipy
from glob import glob
import numpy as np
import pandas as pd 
import matplotlib.pyplot as plt
import cv2

In [None]:
#TALE SECOND PART:
class DataLoader():
    """Load an image dataset stored as a CSV of flattened pixel strings and serve batches.

    The CSV has three columns:
      - emotion: integer class label
      - pixels: space-separated pixel values of one image (reshaped to `img_res`)
      - Usage: split name; rows equal to 'Training' go to the training set,
        every other row goes to the test set

    Parameters
    -------------------
    dataset_name: string- dataset name; only 'fer2013' is accepted (selects the CSV loader)
    img_res: tuple- image shape (n_rows, n_cols, n_channels)
    path_csv: string- path to a CSV file with a header row. If None, the default
        '../ccycleGanCrc.csv' is read without a header and its columns are named
        emotion / pixels / Usage.
    use_test_in_batch: boolean- if True, the test split (plus the optional extra
        "leo" set, when present) is appended to the training arrays
    normalize: boolean- if True, pixels are mapped with x / 127.5 - 1
        (i.e. [0, 255] -> [-1, 1])
    """

    def __init__(self, dataset_name, img_res=(256, 256, 3), path_csv=None, use_test_in_batch=False,
                 normalize=True):
        self.dataset_name = dataset_name
        self.img_res = img_res
        # images and labels arrays for train and test (filled by _load_internally)
        self.img_vect_train = None
        self.img_vect_test = None
        self.lab_vect_train = None
        self.lab_vect_test = None
        self.path_csv = path_csv
        # label id -> class name
        self.lab_dict = {0: "Adenoma", 1: "Hyperplastic", 2: "Serrated"}
        self.use_test_in_batch = use_test_in_batch
        self.normalize = normalize
        # Optional extra labelled set ("leo"). Nothing in this class loads it; it is
        # initialized here so load_leo() and use_test_in_batch=True do not fail with
        # an unexplained AttributeError.
        self.leo = None
        self.leo_lab = None
        # load dataset
        self._load_internally()

    def _load_internally(self):
        """Read the CSV and fill the train/test image and label arrays.

        Raises
        -------------------
        Exception: if `dataset_name` is not supported.
        """
        print(">> loading "+str(self.dataset_name)+" ...")
        if self.dataset_name == 'fer2013':
            if self.path_csv is None:
                raw_data = pd.read_csv('../ccycleGanCrc.csv', header=None)
                raw_data.columns = ["emotion", "pixels", "Usage"]
            else:
                raw_data = pd.read_csv(self.path_csv)
        else:
            raise Exception("dataset not supported:"+str(self.dataset_name))

        # train / test split sizes
        n_train = np.sum(raw_data['Usage'] == 'Training')
        n_test = np.sum(raw_data['Usage'] != 'Training')
        assert n_train + n_test == len(raw_data)

        # (n_samples, rows, cols, channels) image arrays and (n_samples,) label arrays
        img_shape = (self.img_res[0], self.img_res[1], self.img_res[2])
        self.img_vect_train = np.zeros((n_train,) + img_shape, 'float32')
        self.img_vect_test = np.zeros((n_test,) + img_shape, 'float32')
        self.lab_vect_train = np.zeros(n_train, 'int32')
        self.lab_vect_test = np.zeros(n_test, 'int32')

        i_train, i_test = 0, 0
        print("passing throught all data...")
        for pixels, usage, label in zip(raw_data["pixels"], raw_data["Usage"], raw_data["emotion"]):
            x_pixels = np.array(pixels.split(" "), 'float32')
            if self.normalize:
                x_pixels = x_pixels/127.5 - 1.
            x_pixels = x_pixels.reshape(self.img_res)
            if usage == 'Training':
                self.img_vect_train[i_train] = x_pixels
                self.lab_vect_train[i_train] = int(label)
                i_train += 1
            else:
                self.img_vect_test[i_test] = x_pixels
                self.lab_vect_test[i_test] = int(label)
                i_test += 1

        # sanity checks: every preallocated slot was filled
        assert i_train == len(self.img_vect_train)
        assert i_train == len(self.lab_vect_train)
        assert i_test == len(self.lab_vect_test)
        assert i_test == len(self.img_vect_test)

        print("> loaded train:",len(self.img_vect_train),"   - test:",len(self.lab_vect_test) )

        print("info de use_test_in_batch: ", self.use_test_in_batch)
        if self.use_test_in_batch:
            lab_parts = [self.lab_vect_train, self.lab_vect_test]
            img_parts = [self.img_vect_train, self.img_vect_test]
            # The extra "leo" set is appended only when it has been provided.
            leo_lab = getattr(self, 'leo_lab', None)
            leo = getattr(self, 'leo', None)
            if leo_lab is not None and leo is not None:
                lab_parts.append(leo_lab)
                img_parts.append(leo)
            self.lab_vect_train = np.concatenate(lab_parts)
            self.img_vect_train = np.concatenate(img_parts)

    def load_leo(self):
        """Return (labels, images) of the extra "leo" set.

        Raises
        -------------------
        AttributeError: if the set has not been assigned (it is never read from the CSV).
        """
        leo_lab = getattr(self, 'leo_lab', None)
        leo = getattr(self, 'leo', None)
        if leo_lab is None or leo is None:
            raise AttributeError("leo data not set: assign DataLoader.leo and DataLoader.leo_lab first")
        return leo_lab, leo

    @staticmethod
    def _sample_indices(labels, domain, batch_size):
        """Draw `batch_size` random indices (with replacement) into `labels`.

        When `domain` is given, only positions whose label equals `domain` are
        eligible. The result is squeezed, so for batch_size=1 it is a 0-d index.
        """
        if domain is None:
            return np.random.choice(labels.shape[0], size=batch_size)
        assert domain in list(range(7))  # label must be in [0, 6]
        candidates = np.argwhere(labels == domain)  # (k, 1) positions with that label
        picks = np.random.choice(candidates.shape[0], size=batch_size)
        return np.squeeze(candidates[picks])

    def _prepare_batch(self, batch_images, batch_size, convertRGB, flip):
        """Reshape to (batch_size, rows, cols, channels), optionally convert to RGB
        and optionally flip each image left-right with probability 0.5.

        np.resize always returns a new array. Resizing first therefore makes sure
        the in-place flip can never write into the stored dataset, and it gives a
        4-D batch even when a single image was selected by a 0-d index.
        """
        batch_images = np.resize(batch_images, (batch_size, self.img_res[0], self.img_res[1],
                                                self.img_res[2]))
        if convertRGB:
            rgb_images = np.zeros((batch_size, self.img_res[0], self.img_res[1], 3))
            for k in range(batch_size):
                rgb_images[k] = cv2.cvtColor(batch_images[k], cv2.COLOR_GRAY2RGB)
            batch_images = rgb_images
        if flip:
            for k in range(batch_size):
                if np.random.random() > 0.5:
                    batch_images[k] = np.fliplr(batch_images[k])  # horizontal (column) flip
        return batch_images

    def load_data(self, domain=None, batch_size=1, is_testing=False, convertRGB=False):
        """Return a single random batch from the test or training split.

        Unlike load_batch (a generator over a whole pass of the training data),
        this returns one batch and can also read the test split.

        Parameters
        ------------
        domain: int- if given, only samples with this class label are drawn
        batch_size: int- number of samples (drawn with replacement)
        is_testing: boolean- draw from the test split (no flipping) instead of train
        convertRGB: boolean- convert single-channel images to 3-channel RGB

        Return
        ------------
        (labels, batch_images); training images are randomly flipped left-right.
        """
        if is_testing:
            labels_src, images_src = self.lab_vect_test, self.img_vect_test
        else:
            labels_src, images_src = self.lab_vect_train, self.img_vect_train
        idx = self._sample_indices(labels_src, domain, batch_size)
        labels = labels_src[idx]
        batch_images = self._prepare_batch(images_src[idx], batch_size, convertRGB,
                                           flip=not is_testing)
        return labels, batch_images

    def load_batch(self, domain=None, batch_size=1, is_testing=False, convertRGB=False):
        """Generator yielding `len(train) // batch_size` random training batches.

        Parameters:
        --------------
        domain: int- if given, only samples with this class label are drawn
        batch_size: int- number of images per batch (drawn with replacement)
        is_testing: boolean- not supported; raises on first iteration if True
        convertRGB: boolean- convert single-channel images to 3-channel RGB

        Return:
        --------------
        yields (labels, batch_images); images are randomly flipped left-right.
        Sets `self.n_batches` when iteration starts.
        """
        if is_testing:
            raise Exception("not supported")
        self.n_batches = int(len(self.img_vect_train) / batch_size)
        for _ in range(self.n_batches):
            idx = self._sample_indices(self.lab_vect_train, domain, batch_size)
            labels = self.lab_vect_train[idx]
            batch_images = self._prepare_batch(self.img_vect_train[idx], batch_size, convertRGB,
                                               flip=True)
            yield labels, batch_images

    def load_batch_AB(self, domain=None, batch_size=1, is_testing=False):
        """Generator yielding paired training batches from two classes A and B.

        Parameters:
        ---------------
        domain: list- [label_A, label_B], two different labels in [0, 6]
        batch_size: int- number of images per domain per batch
        is_testing: boolean- not supported; raises on first iteration if True

        Return:
        ---------------
        yields (labels_A, batch_images_A, labels_B, batch_images_B). No flipping
        is applied (the original flip threshold, > 10.5, could never be reached).
        Sets `self.n_batches` when iteration starts.
        """
        if is_testing:
            raise Exception("not supported")
        self.n_batches = int(len(self.img_vect_train) / batch_size)
        for _ in range(self.n_batches):
            assert domain is not None
            assert type(domain) is list
            assert domain[0] in list(range(7))
            assert domain[1] in list(range(7))
            assert domain[0] != domain[1]
            domain_A, domain_B = domain[0], domain[1]

            idx_A = self._sample_indices(self.lab_vect_train, domain_A, batch_size)
            labels_A = self.lab_vect_train[idx_A]
            batch_images_A = self._prepare_batch(self.img_vect_train[idx_A], batch_size,
                                                 convertRGB=False, flip=False)

            idx_B = self._sample_indices(self.lab_vect_train, domain_B, batch_size)
            labels_B = self.lab_vect_train[idx_B]
            batch_images_B = self._prepare_batch(self.img_vect_train[idx_B], batch_size,
                                                 convertRGB=False, flip=False)

            yield labels_A, batch_images_A, labels_B, batch_images_B

# <font color='red'>**Models**</font>
## **Useful libraries**

In [None]:
!pip install git+https://www.github.com/keras-team/keras-contrib.git

In [None]:
from __future__ import print_function, division
import scipy

from keras.datasets import mnist
from keras_contrib.layers.normalization.instancenormalization import InstanceNormalization
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, Concatenate
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.layers import BatchNormalization, Activation, Embedding, ZeroPadding2D
from keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, multiply
from keras.layers import Reshape
import datetime
import sys
from keras.layers import Concatenate, Dense, LSTM, Input, concatenate
import numpy as np
import os
import random 
from keras.layers import Conv2DTranspose, BatchNormalization
import tensorflow as tf 

from tensorflow.keras.utils import to_categorical

### Useful methods

In [None]:
def get_dim_conv(dim, f, p, s):
    """Spatial output size of a convolution: floor((dim + 2*p - f) / s) + 1.

    Parameters:
    -----------
    dim: int- input spatial size (rows or cols)
    f: int- kernel (filter) size
    p: int- padding added on each side
    s: int- stride

    Return:
    -----------
    int- output spatial size
    """
    # The stride must be `s`; a hard-coded 2 only gave correct sizes for stride-2 convs.
    return (dim + 2 * p - f) // s + 1

In [None]:
def build_generator_enc_dec(img_shape, gf, num_classes, channels, num_layers=6, f_size=4,
                            tranform_layer=False):
    """U-Net-style conditional generator, split into an encoder and a decoder model.

    The encoder applies `num_layers` strided 'valid' convolutions and outputs every
    activation (shallowest first; the deepest one is 2x2 max-pooled). The decoder
    takes those activations in the same order, followed by a class vector, and
    upsamples back to an image with `channels` channels and a tanh output.

    Parameters:
    ------------
    img_shape: tuple- input image shape (rows, cols, channels)
    gf: int- number of filters in the first encoder layer (doubled at every layer)
    num_classes: int- length of the conditioning label vector
    channels: int- number of channels of the generated image
    num_layers: int- number of downsampling layers
    f_size: int- convolution kernel size
    tranform_layer: boolean- if True, pass bottleneck+label through a Dense layer
        before decoding

    Return: (G_enc, G_dec) Keras models
    """
    # Local import: the cell that defines this function does not import MaxPool2D.
    from keras.layers import MaxPool2D

    def conv2d(layer_input, filters, f_size=f_size):
        """Downsampling layer: strided 'valid' conv -> LeakyReLU -> InstanceNorm."""
        d = Conv2D(filters, kernel_size=f_size, strides=2, padding='valid')(layer_input)
        d = LeakyReLU(alpha=0.2)(d)
        d = InstanceNormalization()(d)
        return d

    def deconv2d(layer_input, skip_input, filters, f_size=f_size, dropout_rate=0, output_padding=None):
        """Upsampling layer: strided transposed conv (ReLU) -> optional dropout
        -> InstanceNorm -> concatenation with the matching encoder activation."""
        u = Conv2DTranspose(filters=filters, kernel_size=f_size, strides=2, activation='relu',
                            output_padding=output_padding)(layer_input)
        if dropout_rate:
            u = Dropout(dropout_rate)(u)
        u = InstanceNormalization()(u)
        u = Concatenate()([u, skip_input])
        return u

    # ---------------- Encoder ----------------
    img = Input(shape=img_shape)
    d = img
    zs = []    # encoder activations, shallowest first (used as skip connections)
    dims = []  # (spatial size, channels) of each activation, used to size decoder inputs
    _dim = img_shape[0]
    for i in range(num_layers):
        d = conv2d(d, gf*2**i)  # filters double at every level
        zs.append(d)
        _dim = get_dim_conv(_dim, f_size, 0, 2)
        dims.append((_dim, gf*2**i))
        print("D:", _dim, gf*2**i)

    # Replace the deepest activation by its 2x2 max-pooled version (2x2 -> 1x1 for
    # 256x256 inputs) so it can be concatenated with the (1, 1, num_classes) label.
    zs.pop()
    d = MaxPool2D(pool_size=(2, 2))(d)
    zs.append(d)
    G_enc = Model(img, zs)
    print("*** generator enconder ok***!")

    # ---------------- Decoder ----------------
    _zs = []  # decoder inputs, collected deepest first and reversed at the end
    d_, c_ = dims.pop()
    # The bottleneck must be a genuine Input whose shape matches the pooled encoder
    # output (2x2 'valid' max-pooling floors the spatial size by 2). Max-pooling an
    # Input here would register a non-Input tensor as a model input.
    i_ = Input(shape=(d_ // 2, d_ // 2, c_))
    _zs.append(i_)
    label = Input(shape=(num_classes,), dtype='float32')
    label_r = Reshape((1, 1, num_classes))(label)

    u = concatenate([i_, label_r], axis=-1)

    # Optional dense mixing of the bottleneck features with the label.
    if tranform_layer:
        tr = Flatten()(u)
        tr = Dense(c_+num_classes)(tr)
        tr = LeakyReLU(alpha=0.2)(tr)
        u = Reshape((1, 1, c_+num_classes))(tr)
    u = Conv2D(c_, kernel_size=1, strides=1, padding='valid')(u)  # 1x1 conv back to c_ channels

    # Upsampling. The special kernel sizes make the shapes line up with the encoder
    # levels for 256x256 inputs and num_layers=6 (1 -> 6 and 62 -> 127).
    for i in range(num_layers-1):
        _ch = gf*2**((num_layers-2)-i)
        d_, c_ = dims.pop()
        print(i, d_, c_)
        i_ = Input(shape=(d_, d_, c_))
        _zs.append(i_)
        if i == 4:
            u = Conv2DTranspose(filters=_ch, kernel_size=5, strides=2, activation='relu',
                                output_padding=None)(u)
            u = InstanceNormalization()(u)
            u = Concatenate()([u, i_])
        elif i == 0:
            u = Conv2DTranspose(filters=_ch, kernel_size=6, strides=2, activation='relu',
                                output_padding=None)(u)
            u = InstanceNormalization()(u)
            u = Concatenate()([u, i_])
        else:
            u = deconv2d(u, i_, _ch)

    u = Conv2DTranspose(filters=channels, kernel_size=f_size, strides=2, activation='tanh', output_padding=None)(u)

    # Decoder inputs in encoder-output order (shallowest first), then the label.
    _zs.reverse()
    _zs.append(label)
    G_dec = Model(_zs, u)

    return G_enc, G_dec

In [None]:
def build_discriminator(img_shape, df, num_classes, num_layers=6, act_multi_label='softmax'):
    """Build the discriminator/classifier network.

    A stack of `num_layers` strided 4x4 'valid' convolutions (df * 2**i filters,
    LeakyReLU, instance normalization on every layer except the first), a 2x2
    max-pool and a flatten feed two dense heads:
      - gan_prob: sigmoid probability that the image is real
      - class_prob: `num_classes` class scores with activation `act_multi_label`

    Parameters:
    -----------
    img_shape: tuple- input image shape (rows, cols, channels)
    df: int- number of filters in the first conv layer (doubled at every layer)
    num_classes: int- number of outputs of the classification head
    num_layers: int- number of strided conv layers
    act_multi_label: string- activation of the classification head

    Return: Keras Model mapping an image to [gan_prob, class_prob]
    """
    # Local import: the cell that defines this function does not import MaxPool2D.
    from keras.layers import MaxPool2D

    def d_layer(layer_input, filters, f_size=4, normalization=True):
        """Strided conv -> LeakyReLU -> optional InstanceNorm."""
        d = Conv2D(filters, kernel_size=f_size, strides=2, padding='valid')(layer_input)
        d = LeakyReLU(alpha=0.2)(d)
        if normalization:
            d = InstanceNormalization()(d)
        return d

    img = Input(shape=img_shape)

    d = img
    for i in range(num_layers):
        # normalize all layers except the first one
        d = d_layer(d, df*2**i, normalization=(i != 0))

    d = MaxPool2D(pool_size=(2, 2))(d)
    flat_repr = Flatten()(d)  # flat representation of the last layer

    print("flat_repr.get_shape().as_list():",flat_repr.get_shape().as_list())
    print("flat_repr.get_shape().as_list()[1:]:",flat_repr.get_shape().as_list()[1:])

    # Real/fake head
    gan_logit = Dense(df*2**(num_layers-1))(flat_repr)
    gan_logit = LeakyReLU(alpha=0.2)(gan_logit)
    gan_prob = Dense(1, activation='sigmoid')(gan_logit)

    # Class head
    class_logit = Dense(df*2**(num_layers-1))(flat_repr)
    class_logit = LeakyReLU(alpha=0.2)(class_logit)
    class_prob = Dense(num_classes, activation=act_multi_label)(class_logit)

    return Model(img, [gan_prob, class_prob])

# <font color='red'>**Conditional cycleGan network**</font>
## **Useful libraries**

In [None]:
from __future__ import print_function, division
import scipy

from tensorflow.keras.datasets import mnist
from keras_contrib.layers.normalization.instancenormalization import InstanceNormalization
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, Concatenate, MaxPool2D
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.layers import BatchNormalization, Activation, Embedding, ZeroPadding2D
from keras.models import Sequential, Model
#from keras.optimizers import Adam
from tensorflow.keras.optimizers import Adam
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, multiply
from keras.layers import Reshape
import datetime
import matplotlib.pyplot as plt
import sys
#from data_loader import DataLoader
from keras.layers import Concatenate, Dense, LSTM, Input, concatenate
import numpy as np
import pandas as pd 
import os
import random 

import tensorflow as tf 

from tensorflow.keras.utils import to_categorical
import argparse
from sklearn.metrics import accuracy_score

#from  models import *

In [None]:
class CCycleGAN():
    """ Conditional cycleGan: model initialization (generator and discriminator nets) and training,
    receive image shape,
    amount of classes to be taken into account, 
    weight losses for generator and discriminator nets,
    load the dataset.

    Parameters
    ------------
    img_rows and img_cols: int- rows and cols for image to work with
    channels: int- amount of image channels
    num_classes: int- amount of classes to be taken into account
    d_gan_loss_w: int- discriminator loss weight
    d_cl_loss_w: int- discriminator loss weight for class tag
    g_gan_loss_w: int- generator loss weight
    g_cl_loss_w: int- generator loss weight for class tag
    ---> rec_loss_w: int- cycle consistency loss weight (check)
    adam_lr: float- learning rate
    adam_beta_1: float- parameters for adam rule
    adam_beta_2: float- parameters for adam rule 
    """

    #values assignment
    def __init__(self,img_rows = 256, img_cols = 256, channels = 3, num_classes=3, d_gan_loss_w=1,
      d_cl_loss_w=1, g_gan_loss_w=1, g_cl_loss_w=1, rec_loss_w=1, adam_lr=0.0002, adam_beta_1=0.5,
      adam_beta_2=0.999):
        """Store hyper-parameters, load the data and build/compile all models.

        Builds the discriminator/classifier `self.d`, the generator encoder
        `self.g_enc` and decoder `self.g_dec`, and `self.combined`, which trains the
        generators through the frozen discriminator. Parameters are described in
        the class docstring.
        """
        # Input shape
        self.img_rows = img_rows
        self.img_cols = img_cols
        self.channels = channels
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.num_classes = num_classes

        # Loss weights
        self.d_gan_loss_w = d_gan_loss_w
        self.d_cl_loss_w = d_cl_loss_w
        self.g_gan_loss_w = g_gan_loss_w
        self.g_cl_loss_w = g_cl_loss_w
        self.rec_loss_w = rec_loss_w

        # Optimizer params
        self.adam_lr = adam_lr
        self.adam_beta_1 = adam_beta_1
        self.adam_beta_2 = adam_beta_2

        # Configure data loader ('fer2013' only selects the CSV loading branch)
        self.dataset_name = 'fer2013'
        self.data_loader = DataLoader(dataset_name=self.dataset_name, img_res=self.img_shape,
                                      use_test_in_batch=False)
        # label dict
        self.lab_dict = {0: "Ade", 1: "Hyp" , 2: "Ser"}

        # Number of filters in the first layer of the Generator and Discriminator.
        # gf=64 is the value the generator has always actually been built with.
        self.gf = 64
        self.df = 64

        # One optimizer per compiled model. A single shared Adam instance would be
        # asked to update two different variable sets (newer TF/Keras optimizers
        # reject this), and the two models would share its iteration counter.
        d_optimizer = Adam(self.adam_lr, self.adam_beta_1, self.adam_beta_2)
        g_optimizer = Adam(self.adam_lr, self.adam_beta_1, self.adam_beta_2)

        # Build and compile the discriminator/classifier
        self.d = build_discriminator(img_shape=self.img_shape, df=self.df, num_classes=self.num_classes,
                                     act_multi_label='softmax')
        print("******** Discriminator/Classifier ********")
        self.d.summary()
        # NOTE(review): the class head uses binary_crossentropy here but
        # categorical_crossentropy in the combined model below; kept as-is, but
        # please confirm this is intended.
        self.d.compile(loss=['binary_crossentropy',  # gan
                             'binary_crossentropy'   # class
                             ],
                       optimizer=d_optimizer,
                       metrics=['accuracy'],
                       loss_weights=[
                           self.d_gan_loss_w,  # gan
                           self.d_cl_loss_w    # class
                       ])

        #-------------------------
        # Construct Computational
        #   Graph of Generators
        #-------------------------

        # Build the generators
        self.g_enc, self.g_dec = build_generator_enc_dec(img_shape=self.img_shape, gf=self.gf,
                                                         num_classes=self.num_classes,
                                                         channels=self.channels, tranform_layer=True)
        print("******** Generator_ENC ********")
        self.g_enc.summary()
        print("******** Generator_DEC ********")
        self.g_dec.summary()

        # Inputs: an image, its original label (label0) and a target label (label1)
        print("***** images from domains *****")
        img = Input(shape=self.img_shape)
        label0 = Input(shape=(self.num_classes,))
        label1 = Input(shape=(self.num_classes,))

        # Encode once; decode towards the target label...
        zs = list(self.g_enc(img))
        fake = self.g_dec(zs + [label1])
        # ...and back with the original label (reconstruction of the input image)
        reconstr = self.g_dec(zs + [label0])

        # Freeze the discriminator inside the combined model so only the generators
        # are updated by it (self.d was compiled above while still trainable).
        self.d.trainable = False

        # Discriminator judges realism and class of the translated image
        gan_valid, class_valid = self.d(fake)

        # Combined model trains the generators to fool the discriminator
        self.combined = Model(inputs=[img, label0, label1], outputs=[gan_valid, class_valid, reconstr])
        self.combined.compile(loss=['binary_crossentropy', 'categorical_crossentropy', 'mae'],
                              loss_weights=[
                                  self.g_gan_loss_w,  # g_loss gan
                                  self.g_cl_loss_w,   # g_loss class
                                  self.rec_loss_w     # reconstruction loss
                              ],
                              optimizer=g_optimizer)

        print("******** Combined model ********")
        self.combined.summary()

    def generate_new_labels(self,labels0):
        labels1 = [] 
        for i in range(len(labels0)):
            allowed_values = list(range(0, self.num_classes))
            allowed_values.remove(labels0[i])
            labels1.append(random.choice(allowed_values))
        return np.array(labels1,'int32')

    def generate_new_labels_all(self, labels0):
        #called from training procedure check***
        """Function for keep label values different from original labels
        Parameter:
        labels0: array- real label class list
        Return: array with all labels different from original label class
        """
        labels_all = [] 
        for i in range(len(labels0)):
            allowed_values = list(range(0, self.num_classes))
            allowed_values.remove(labels0[i])
            labels_all.append(np.array(allowed_values,'int32'))
        return np.array(labels_all,'int32')

    def train(self, epochs, batch_size=1, sample_interval=50 , d_g_ratio=5,
              log_path=None, checkpoint_dir='../checkPoints'):
        """Train the conditional CycleGAN.

        Every batch trains the discriminator on the real images (with their class)
        and on one translation of each image to each of the two other classes
        (validity 0, all-zero class target). Every `d_g_ratio` batches the
        generators are trained through the combined model. Every
        `sample_interval` batches, test samples are evaluated, the training log is
        written to CSV and the combined model is saved.

        Parameters:
        ------------
        epochs: int- number of passes over the training data
        batch_size: int- number of real images per update step
        sample_interval: int- evaluate/log/checkpoint every this many batches
        d_g_ratio: int- number of discriminator updates per generator update
        log_path: string- CSV path for the training log; defaults to
            '<script base name>_train_log.csv' in the working directory
        checkpoint_dir: string- directory for 'epoch<E>_batch<B>.h5' checkpoints
            (created if missing)
        """
        if log_path is None:
            # basename + splitext (instead of split('.') on the full path) so that
            # dotted directories such as '.local' cannot truncate the file name.
            # NOTE: in Jupyter sys.argv[0] is typically the kernel launcher; pass
            # log_path explicitly there.
            script_name = os.path.splitext(os.path.basename(str(sys.argv[0])))[0]
            log_path = script_name + '_train_log.csv'

        start_time = datetime.datetime.now()
        # logs (column order of the CSV)
        history = {
            'epoch': [], 'batch': [],
            'd_gan_loss': [], 'd_gan_accuracy': [],
            'd_cl_loss': [], 'd_cl_accuracy': [],
            'g_gan_loss': [], 'g_cl_loss': [],
            'reconstr_loss': [],
        }

        # Adversarial loss ground truths
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))
        # class target for translated images: no class is asserted
        null_labels = np.zeros((batch_size, self.num_classes))
        # generator targets: both translations of every image should look real
        valid_g = np.concatenate([valid, valid])

        for epoch in range(epochs):
            for batch_i, (labels0, imgs) in enumerate(self.data_loader.load_batch(batch_size=batch_size)):
                # Two target classes per image (assumes num_classes == 3: every
                # other class is used as a translation target).
                labels1_all = self.generate_new_labels_all(labels0)
                labels0_cat = to_categorical(labels0, num_classes=self.num_classes)
                labels1_all_1 = to_categorical(labels1_all[:, 0], num_classes=self.num_classes)
                labels1_all_2 = to_categorical(labels1_all[:, 1], num_classes=self.num_classes)

                # ----------------------
                #  Train Discriminator
                # ----------------------
                # Encode once, decode towards each target class
                zs = list(self.g_enc.predict(imgs))
                fakes_1 = self.g_dec.predict(zs + [labels1_all_1])
                fakes_2 = self.g_dec.predict(zs + [labels1_all_2])

                # Real images are valid with their class; translations are fake
                d_imgs = np.concatenate([imgs, fakes_1, fakes_2])
                d_labels_cat = np.concatenate([labels0_cat, null_labels, null_labels])
                d_vf = np.concatenate([valid, fake, fake])
                perm = np.random.permutation(len(d_imgs))
                d_loss = self.d.train_on_batch(d_imgs[perm], [d_vf[perm], d_labels_cat[perm]])
                # assumes d_loss = [total, gan_loss, class_loss, gan_acc, class_acc]

                if batch_i % d_g_ratio == 0:
                    # ------------------
                    #  Train Generators
                    # ------------------
                    g_imgs = np.concatenate([imgs, imgs])
                    g_labels0_cat = np.concatenate([labels0_cat, labels0_cat])
                    g_labels1_cat = np.concatenate([labels1_all_1, labels1_all_2])
                    perm = np.random.permutation(len(g_imgs))
                    g_imgs = g_imgs[perm]
                    g_labels0_cat = g_labels0_cat[perm]
                    g_labels1_cat = g_labels1_cat[perm]

                    # Targets: translation looks real, is classified as the target
                    # class, and decoding with the original label reconstructs the input.
                    g_loss = self.combined.train_on_batch([g_imgs, g_labels0_cat, g_labels1_cat],
                                                          [valid_g, g_labels1_cat, g_imgs])

                    elapsed_time = datetime.datetime.now() - start_time

                    print("[Epoch %d/%d] [Batch %d/%d] [D_gan loss: %f, acc_gan: %3d%%] [D_cl loss: %f, acc_cl: %3d%%] [G_gan loss: %05f, G_cl: %05f, recon: %05f] time: %s " \
                      % ( epoch, epochs,
                          batch_i, self.data_loader.n_batches,
                          d_loss[1],100*d_loss[3],d_loss[2],100*d_loss[4],
                          g_loss[1],g_loss[2],g_loss[3],
                          elapsed_time))

                    history['epoch'].append(epoch)
                    history['batch'].append(batch_i)
                    history['d_gan_loss'].append(d_loss[1])
                    history['d_gan_accuracy'].append(100*d_loss[3])
                    history['d_cl_loss'].append(d_loss[2])
                    history['d_cl_accuracy'].append(100*d_loss[4])
                    history['g_gan_loss'].append(g_loss[1])
                    history['g_cl_loss'].append(g_loss[2])
                    history['reconstr_loss'].append(g_loss[3])

                # At save interval: evaluate samples, write the log, checkpoint
                if batch_i % sample_interval == 0:
                    self.sample_images(epoch, batch_i)

                    pd.DataFrame(history).to_csv(log_path, index=False)

                    # Epoch is part of the name so later epochs do not overwrite
                    # checkpoints taken at the same batch index.
                    os.makedirs(checkpoint_dir, exist_ok=True)
                    file_name = os.path.join(checkpoint_dir,
                                             'epoch' + str(epoch) + '_batch' + str(batch_i) + '.h5')
                    self.combined.save(file_name)
                    print("model saved!")
                    

    def sample_images(self, epoch, batch_i, use_leo=False, n_disc_samples=64):
        """Evaluate the discriminator on test data and save a translation figure.

        Two things happen here:
          1. A test batch of real images is fed to the discriminator ``self.d``.
             Its real/fake accuracy (every input is real, so the target is all
             ones) and its class accuracy are printed.
          2. Test image(s) are encoded with ``self.g_enc`` and decoded with
             ``self.g_dec`` into the two labels returned by
             ``self.generate_new_labels_all`` plus the image's own label
             (reconstruction). For the FIRST sample, the original, the two
             translations and the reconstruction are plotted side by side and
             saved under ``images/<dataset_name>/``.

        Parameters:
        ------------
        epoch: int- epoch where the interval save is done (used in the file name)
        batch_i: int- number of batch where we want to save (used in the file name)
        use_leo: bool- if True, the images to translate come from
            ``self.data_loader.load_leo()`` and the file name gets a ``_leo``
            suffix; otherwise one test image is drawn with ``load_data``
        n_disc_samples: int- number of test images requested for the
            discriminator evaluation (default 64, the previously hard-coded value)
        """
        ## Discriminator evaluation on real test images
        labels0_d, imgs_d = self.data_loader.load_data(batch_size=n_disc_samples, is_testing=True)
        gan_pred_prob, class_pred_prob = self.d.predict(imgs_d)

        # Threshold the real/fake head at 0.5. Flatten rather than reshape to a
        # fixed size so a test batch shorter than requested does not crash.
        gan_pred = (gan_pred_prob > 0.5).astype(float).ravel()
        class_pred = np.argmax(class_pred_prob, axis=1)

        # All images here are real, so the ideal real/fake prediction is all ones.
        gan_test_accuracy = accuracy_score(y_true=np.ones(len(gan_pred)), y_pred=gan_pred)
        class_test_accuracy = accuracy_score(y_true=labels0_d, y_pred=class_pred)

        print("*** TEST *** [D_gan accuracy :",gan_test_accuracy,"] [D_cl accuracy :", class_test_accuracy,"]")

        ## Generator: translate test image(s) into the other labels
        if use_leo:
            # NOTE(review): load_leo() is defined on the DataLoader outside this
            # view; its batch size is not visible here.
            labels0_, imgs_ = self.data_loader.load_leo()
        else:
            labels0_, imgs_ = self.data_loader.load_data(batch_size=1, is_testing=True)
        # Two target labels per sample, read as columns 0 and 1 below.
        labels1_all = self.generate_new_labels_all(labels0_)

        labels0_cat = to_categorical(labels0_, num_classes=self.num_classes)
        labels1_all_1 = to_categorical(labels1_all[:, 0], num_classes=self.num_classes)
        labels1_all_2 = to_categorical(labels1_all[:, 1], num_classes=self.num_classes)

        # Encode once; the encoder outputs are reused for every decoding and the
        # target label is appended as the decoder's last input.
        latents = list(self.g_enc.predict(imgs_))
        fake_1 = self.g_dec.predict(latents + [labels1_all_1])
        fake_2 = self.g_dec.predict(latents + [labels1_all_2])
        # Decoding with the image's own label should reproduce the input.
        reconstr_ = self.g_dec.predict(latents + [labels0_cat])

        # Plot only the first sample, which the titles below describe. This stays
        # consistent even if more than one image was loaded; concatenating
        # whole batches would then show several different originals.
        gen_imgs = np.stack([imgs_[0], fake_1[0], fake_2[0], reconstr_[0]])

        # Map 0.5*x + 0.5 back to [0, 1] for imshow. This assumes the images live
        # in [-1, 1] (loader normalization / generator output) -- TODO confirm.
        gen_imgs = 0.5 * gen_imgs + 0.5

        titles = ['Orig:'+str(self.lab_dict[labels0_.item(0)]),
                  'Trans:'+str(self.lab_dict[labels1_all[:,0].item(0)]),
                  'Trans:'+str(self.lab_dict[labels1_all[:,1].item(0)]),
                  'Reconstr.']

        fig, axs = plt.subplots(1, len(titles), figsize=(16,16), squeeze=False)
        fig.subplots_adjust(hspace=0)

        os.makedirs("images/%s/" % (self.dataset_name), exist_ok=True)

        for ax, imagen, title in zip(axs.flat, gen_imgs, titles):
            ax.imshow(imagen)
            ax.set_title(title)
            ax.axis('off')

        suffix = "_leo" if use_leo else ""
        fig.savefig("images/%s/%d_%d%s.png" % (self.dataset_name, epoch, batch_i, suffix))
        # Close this specific figure so repeated sampling does not leak figures.
        plt.close(fig)

# <font color='red'>Main</font>

In [None]:
# --- Loss weights, forwarded to CCycleGAN under the same keyword names ---
# Presumably they scale the terms printed by train() as D_gan / D_cl (discriminator)
# and G_gan / G_cl / recon (generator) -- confirm in CCycleGAN.__init__.
d_gan_loss_w, d_cl_loss_w = 1, 1
g_gan_loss_w, g_cl_loss_w = 2, 2
rec_loss_w = 1

# --- Adam optimizer settings ---
adam_lr = 2e-4
adam_beta_1, adam_beta_2 = 0.5, 0.999

# --- Training schedule ---
epochs = 170
batch_size = 8
# Every `sample_interval` batches train() saves sample images, the CSV log
# and a model checkpoint.
sample_interval = 200

In [None]:
# Build the CCycleGAN model from the hyper-parameters of the configuration cell.
gan = CCycleGAN(
    # loss weights
    d_gan_loss_w=d_gan_loss_w,
    d_cl_loss_w=d_cl_loss_w,
    g_gan_loss_w=g_gan_loss_w,
    g_cl_loss_w=g_cl_loss_w,
    rec_loss_w=rec_loss_w,
    # Adam optimizer
    adam_lr=adam_lr,
    adam_beta_1=adam_beta_1,
    adam_beta_2=adam_beta_2,
)

In [None]:
# Run training. Every `sample_interval` batches, sample images, the CSV log
# and a checkpoint are written.
gan.train(
    epochs=epochs,
    batch_size=batch_size,
    sample_interval=sample_interval,
)