### CapsuleGAN

This Jupyter notebook can be run directly on Google Colab. Training takes about 4–5 hours on Colab's T4 GPU. Run all cells in order.

***Before running, upload the dataset archive to the root of your Google Drive ("My Drive") as `datasets.zip`.***

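- The code mounts Google Drive and unzips the dataset.
- The code also creates four output directories (`images2` to `images5`) and a text file (`file6.txt`). The last cell zips and downloads all of them.
- The `images2` directory holds the main output: grids showing the original, translated, and reconstructed images for both domains.
- The remaining folders hold individual images saved alongside the main output: `images3` holds the translated image, `images4` the paired ground-truth image, and `images5` the input image.
- The text file records the per-image loss (the summed absolute difference between the translated image and its ground truth) logged at each sampling step.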


In [0]:
# keras-contrib provides the InstanceNormalization layer used by the generators.
!pip install git+https://www.github.com/keras-team/keras-contrib.git

In [0]:
# Mount Google Drive at /content/gdrive so the dataset archive can be read.
from google.colab import drive
drive.mount('/content/gdrive')

In [0]:
# Extract datasets.zip from the root of "My Drive" into the working directory.
# DataLoader later reads ./datasets/<dataset_name>/..., so the archive is
# presumably expected to contain a top-level datasets/ folder.
!unzip gdrive/My\ Drive/datasets.zip

In [0]:
import scipy
from glob import glob
import numpy as np

# Seed NumPy's global RNG so DataLoader's file sampling and random flips repeat
# across runs. Keras/TensorFlow weight initialisation is not seeded here.
np.random.seed(1234)


class DataLoader:
    """Sample and preprocess images from ``./datasets/<dataset_name>/``.

    Images are read as RGB, resized to ``img_res`` and scaled from [0, 255]
    to [-1, 1]. Training samples from ``load_data``/``load_batch`` are randomly
    mirrored left-right.

    NOTE(review): relies on ``scipy.misc.imread`` / ``scipy.misc.imresize``,
    which only exist in older SciPy releases.
    """

    def __init__(self, dataset_name, img_res=(128, 128)):
        self.dataset_name = dataset_name
        self.img_res = img_res

    @staticmethod
    def _partner_path(img_path, data_type):
        """Map a domain-A image path to the same file name in the matching B folder.

        ``data_type`` is the A folder name ("trainA" or "testA"); it is swapped
        for "trainB"/"testB" so test-time pairing works as well as training.
        """
        return img_path.replace(data_type, data_type[:-1] + "B")

    def _preprocess(self, img, flip=False):
        """Resize ``img`` to ``self.img_res`` and mirror it left-right if ``flip``."""
        img = scipy.misc.imresize(img, self.img_res)
        return np.fliplr(img) if flip else img

    def load_data(self, domain, batch_size=1, is_testing=False):
        """Randomly sample ``batch_size`` images (with replacement) from one domain.

        For ``domain == "A"`` every sampled image is immediately followed by its
        paired domain-B image (same file name in the B folder), so the result
        holds ``2 * batch_size`` images ordered A, B, A, B, ...

        Args:
            domain: "A" or "B", the suffix of the folder to read.
            batch_size: number of files to sample.
            is_testing: read from ``test<domain>`` and skip random flipping.

        Returns:
            Array of images scaled to [-1, 1].

        Raises:
            ValueError: if the domain folder contains no files.
        """
        data_type = ("test%s" if is_testing else "train%s") % domain
        path = glob('./datasets/%s/%s/*' % (self.dataset_name, data_type))
        if not path:
            raise ValueError("no images found in ./datasets/%s/%s/"
                             % (self.dataset_name, data_type))

        batch_images = np.random.choice(path, size=batch_size)

        imgs = []
        for img_path in batch_images:
            # One flip decision per sample: an A image and its paired B image
            # must be mirrored together or the pair no longer lines up.
            flip = not is_testing and np.random.random() > 0.5
            imgs.append(self._preprocess(self.imread(img_path), flip))

            if domain == "A":
                paired_path = self._partner_path(img_path, data_type)
                imgs.append(self._preprocess(self.imread(paired_path), flip))

        return np.array(imgs) / 127.5 - 1.

    def load_batch(self, batch_size=1, is_testing=False):
        """Yield ``(imgs_A, imgs_B)`` batches for one pass over both domains.

        A and B files are shuffled independently (unpaired). Sets
        ``self.n_batches`` to ``min(len(A), len(B)) // batch_size`` and yields
        exactly that many full batches. Leftover files that do not fill a
        batch are skipped.

        NOTE(review): ``is_testing`` reads the "valA"/"valB" folders, whereas
        ``load_data`` uses "testA"/"testB". Confirm which the dataset provides.
        """
        data_type = "val" if is_testing else "train"
        path_A = glob('./datasets/%s/%sA/*' % (self.dataset_name, data_type))
        path_B = glob('./datasets/%s/%sB/*' % (self.dataset_name, data_type))
        print("PathA images count: {} PathB images count: {}".format(len(path_A), len(path_B)))
        self.n_batches = min(len(path_A), len(path_B)) // batch_size
        total_samples = self.n_batches * batch_size

        # Sample n_batches * batch_size from each path list so that model sees all
        # samples from both domains
        path_A = np.random.choice(path_A, total_samples, replace=False)
        path_B = np.random.choice(path_B, total_samples, replace=False)

        # Iterate over every batch (previously the last one was silently dropped).
        for i in range(self.n_batches):
            batch_A = path_A[i * batch_size:(i + 1) * batch_size]
            batch_B = path_B[i * batch_size:(i + 1) * batch_size]
            imgs_A, imgs_B = [], []
            for file_A, file_B in zip(batch_A, batch_B):
                flip = not is_testing and np.random.random() > 0.5
                imgs_A.append(self._preprocess(self.imread(file_A), flip))
                imgs_B.append(self._preprocess(self.imread(file_B), flip))

            yield np.array(imgs_A) / 127.5 - 1., np.array(imgs_B) / 127.5 - 1.

    def load_img(self, path):
        """Load one image as a batch of size 1, resized and scaled to [-1, 1]."""
        img = self._preprocess(self.imread(path))
        img = img / 127.5 - 1.
        return img[np.newaxis, :, :, :]

    def imread(self, path):
        """Read ``path`` as an RGB float64 array."""
        # np.float was an alias of float64 and was removed in NumPy 1.24.
        return scipy.misc.imread(path, mode='RGB').astype(np.float64)


In [0]:
import datetime
import os
from PIL import Image, ImageDraw

import matplotlib.pyplot as plt
import numpy as np
from keras.layers import Input, Dropout, Concatenate, Reshape
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D, Conv1D
from keras.models import Model
from keras.optimizers import Adam
from keras_contrib.layers.normalization.instancenormalization import InstanceNormalization
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers import Dense, Reshape, Flatten, Lambda, Multiply
from keras import backend as K
from keras.backend import tf as ktf

import cv2
# Loss log written by CycleGAN.sample_images. It is closed and downloaded in
# the last cell; writes may stay buffered until then.
text_file = open("/content/file6.txt", "w")

# DataLoader is defined in an earlier cell rather than imported from a module.


class CycleGAN:
    """CycleGAN with U-Net generators and capsule-style discriminators.

    Learns two mappings between image domains A and B of ``dataset_name``:
    ``g_AB`` (A -> B) and ``g_BA`` (B -> A). The discriminators ``d_A``/``d_B``
    are trained on their own with an MSE loss. The generators are trained
    through ``combined`` with adversarial (MSE), cycle-consistency (MAE) and
    identity (MAE) losses.
    """

    def __init__(self):
        # Input shape
        self.img_rows = 128
        self.img_cols = 128
        self.channels = 3
        self.img_shape = (self.img_rows, self.img_cols, self.channels)

        # Configure data loader
        self.dataset_name = 'flower_images'
        self.data_loader = DataLoader(dataset_name=self.dataset_name,
                                      img_res=(self.img_rows, self.img_cols))

        # Output shape of a PatchGAN discriminator. Not used by the capsule
        # discriminator below, which emits one score per image.
        patch = int(self.img_rows / 2 ** 4)
        self.disc_patch = (patch, patch, 1)

        # Number of filters in the first layer of G and D
        self.gf = 32
        self.df = 64

        # Loss weights
        self.lambda_cycle = 10.0  # Cycle-consistency loss
        self.lambda_id = 0.1 * self.lambda_cycle  # Identity loss

        # Optimizer for the combined (generator) model. Each discriminator
        # gets its own Adam instance below.
        optimizer = Adam(0.0002, 0.5)

        # Build and compile the discriminators
        self.d_A = self.build_discriminator()
        self.d_B = self.build_discriminator()
        self.d_A.compile(loss='mse', optimizer=Adam(0.0002, 0.5), metrics=['accuracy'])
        self.d_B.compile(loss='mse', optimizer=Adam(0.0002, 0.5), metrics=['accuracy'])

        # Build the generators
        self.g_AB = self.build_generator()
        self.g_BA = self.build_generator()

        # Input images from both domains
        img_A = Input(shape=self.img_shape)
        img_B = Input(shape=self.img_shape)

        # Translate images to the other domain
        fake_B = self.g_AB(img_A)
        fake_A = self.g_BA(img_B)
        # Translate images back to original domain
        reconstr_A = self.g_BA(fake_B)
        reconstr_B = self.g_AB(fake_A)
        # Identity mapping: each generator applied to an image already in its
        # output domain should leave it unchanged
        img_A_id = self.g_BA(img_A)
        img_B_id = self.g_AB(img_B)

        # Freeze the discriminators inside the combined model only. d_A/d_B
        # were compiled above while trainable, so their own train_on_batch
        # calls keep updating them (Keras 2 fixes `trainable` at compile time).
        self.d_A.trainable = False
        self.d_B.trainable = False

        # Discriminators determine validity of translated images
        valid_A = self.d_A(fake_A)
        valid_B = self.d_B(fake_B)

        # Combined model trains generators to fool discriminators
        self.combined = Model(inputs=[img_A, img_B],
                              outputs=[valid_A, valid_B,
                                       reconstr_A, reconstr_B,
                                       img_A_id, img_B_id])
        self.combined.compile(loss=['mse', 'mse',
                                    'mae', 'mae',
                                    'mae', 'mae'],
                              loss_weights=[1, 1,
                                            self.lambda_cycle, self.lambda_cycle,
                                            self.lambda_id, self.lambda_id],
                              optimizer=optimizer)

    def build_generator(self):
        """U-Net generator.

        Four stride-2 convolutions downsample the input. Three upsample+conv
        stages, each concatenated with the matching encoder output, rebuild
        the image. A final upsample and tanh conv map back to ``self.channels``
        channels in [-1, 1].
        """

        def conv2d(layer_input, filters, f_size=4):
            """Layers used during downsampling"""
            d = Conv2D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input)
            d = LeakyReLU(alpha=0.2)(d)
            d = InstanceNormalization()(d)
            return d

        def deconv2d(layer_input, skip_input, filters, f_size=4, dropout_rate=0):
            """Layers used during upsampling"""
            u = UpSampling2D(size=2)(layer_input)
            u = Conv2D(filters, kernel_size=f_size, strides=1, padding='same', activation='relu')(u)
            if dropout_rate:
                u = Dropout(dropout_rate)(u)
            u = InstanceNormalization()(u)
            u = Concatenate()([u, skip_input])
            return u

        # Image input
        d0 = Input(shape=self.img_shape)

        # Downsampling
        d1 = conv2d(d0, self.gf)
        d2 = conv2d(d1, self.gf * 2)
        d3 = conv2d(d2, self.gf * 4)
        d4 = conv2d(d3, self.gf * 8)

        # Upsampling with skip connections
        u1 = deconv2d(d4, d3, self.gf * 4)
        u2 = deconv2d(u1, d2, self.gf * 2)
        u3 = deconv2d(u2, d1, self.gf)

        u4 = UpSampling2D(size=2)(u3)
        output_img = Conv2D(self.channels, kernel_size=4, strides=1, padding='same', activation='tanh')(u4)

        return Model(d0, output_img)

    def build_discriminator(self):
        """Capsule-inspired discriminator mapping an image to one sigmoid score."""

        def squash(vectors, axis=-1):
            """Capsule squashing: keep each vector's direction, map its norm into [0, 1)."""
            s_squared_norm = K.sum(K.square(vectors), axis, keepdims=True)
            scale = s_squared_norm / (1 + s_squared_norm) / K.sqrt(s_squared_norm + K.epsilon())
            return scale * vectors

        img = Input(shape=self.img_shape)
        # Downscale to 32x32 before the 9x9 capsule convolutions.
        try:
            x = Lambda(lambda image: ktf.image.resize_images(image, (32, 32)))(img)
        except TypeError:
            # Older TensorFlow signature: resize_images(images, new_height, new_width).
            x = Lambda(lambda image: ktf.image.resize_images(image, 32, 32))(img)

        x = Conv2D(filters=256, kernel_size=9, strides=1, padding='valid', name='conv1')(x)
        x = LeakyReLU()(x)
        x = BatchNormalization(momentum=0.8)(x)
        # Primary capsules: 32x32 -> 24x24 (conv1) -> 8x8x256, regrouped into
        # 8-dimensional capsule vectors and squashed.
        x = Conv2D(filters=8 * 32, kernel_size=9, strides=2, padding='valid', name='primarycap_conv2')(x)
        x = Reshape(target_shape=[-1, 8], name='primarycap_reshape')(x)
        x = Lambda(squash, name='primarycap_squash')(x)
        x = BatchNormalization(momentum=0.8)(x)
        x = Flatten()(x)

        uhat = Dense(160, kernel_initializer='he_normal', bias_initializer='zeros', name='uhat_digitcaps')(x)

        # Three routing-style refinement rounds. Each applies a softmax, a Dense
        # layer to get coupling weights, multiplies them with uhat, then applies
        # LeakyReLU. The first round starts from uhat itself.
        s_j = uhat
        for routing_round in (1, 2, 3):
            c = Activation('softmax', name='softmax_digitcaps%d' % routing_round)(s_j)
            c = Dense(160)(c)
            x = Multiply()([uhat, c])
            s_j = LeakyReLU()(x)

        pred = Dense(1, activation='sigmoid')(s_j)

        return Model(img, pred)

    @staticmethod
    def _summarize_g_loss(g_loss):
        """Reduce ``combined.train_on_batch`` output to (total, adv, recon, id).

        ``g_loss`` is [total, adv_A, adv_B, recon_A, recon_B, id_A, id_B],
        following the combined model's output order. Each pair is averaged.
        """
        return (g_loss[0],
                np.mean(g_loss[1:3]),
                np.mean(g_loss[3:5]),
                np.mean(g_loss[5:7]))  # both identity terms (was 5:6, dropping id_B)

    @staticmethod
    def _abs_error(target, prediction):
        """Sum of absolute element-wise differences between two equally scaled arrays."""
        return np.sum(np.absolute(np.subtract(target, prediction)))

    @staticmethod
    def _save_single_image(img, path):
        """Save ``img`` on its own, with axes hidden, to ``path``."""
        fig, ax = plt.subplots()
        ax.imshow(img)
        ax.axis('off')
        fig.savefig(path)
        plt.close(fig)

    def train(self, epochs, batch_size=1, sample_interval=50, save_interval=10):
        """Alternate discriminator and generator updates over the dataset.

        Args:
            epochs: number of passes over ``self.data_loader.load_batch``.
            batch_size: images per domain per step.
            sample_interval: print a progress line every this many batches.
            save_interval: call ``sample_images`` every this many batches
                (defaults to 10, the previously hard-coded value).
        """
        start_time = datetime.datetime.now()

        # Adversarial ground truths: one score per image
        fake = np.zeros((batch_size, 1))
        valid = np.ones((batch_size, 1))

        for epoch in range(epochs):
            for batch_i, (imgs_A, imgs_B) in enumerate(self.data_loader.load_batch(batch_size)):

                # ----------------------
                #  Train Discriminators
                # ----------------------

                # Translate images to opposite domain
                fake_B = self.g_AB.predict(imgs_A)
                fake_A = self.g_BA.predict(imgs_B)

                # Train the discriminators (original images = real / translated = Fake)
                dA_loss_real = self.d_A.train_on_batch(imgs_A, valid)
                dA_loss_fake = self.d_A.train_on_batch(fake_A, fake)
                dA_loss = 0.5 * np.add(dA_loss_real, dA_loss_fake)

                dB_loss_real = self.d_B.train_on_batch(imgs_B, valid)
                dB_loss_fake = self.d_B.train_on_batch(fake_B, fake)
                dB_loss = 0.5 * np.add(dB_loss_real, dB_loss_fake)

                # Total discriminator loss: [mse, accuracy]
                d_loss = 0.5 * np.add(dA_loss, dB_loss)

                # ------------------
                #  Train Generators
                # ------------------

                # Generators aim for "valid" scores, cycle reconstructions and
                # identity mappings equal to the inputs.
                g_loss = self.combined.train_on_batch([imgs_A, imgs_B],
                                                      [valid, valid,
                                                       imgs_A, imgs_B,
                                                       imgs_A, imgs_B])

                elapsed_time = datetime.datetime.now() - start_time

                # Plot the progress
                if batch_i % sample_interval == 0:
                    g_total, g_adv, g_recon, g_id = self._summarize_g_loss(g_loss)
                    print(
                        "[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %3d%%] [G loss: %05f, adv: %05f, recon: %05f, "
                        "id: %05f] time: %s "
                        % (epoch, epochs,
                           batch_i, self.data_loader.n_batches,
                           d_loss[0], 100 * d_loss[1],
                           g_total, g_adv, g_recon, g_id,
                           elapsed_time))

                # If at save interval => save generated image samples
                if batch_i % save_interval == 0:
                    self.sample_images(epoch, batch_i)

    def sample_images(self, epoch, batch_i):
        """Save translation samples and log the paired translation error.

        Writes ``<epoch>_<batch_i>.png`` into four folders:
        images2/ a 2x3 grid (row A, then row B: original / translated /
        reconstructed), images3/ the translated A->B image, images4/ the paired
        ground-truth B image, images5/ the input A image. The summed absolute
        difference between the translated image and its ground truth is
        appended to ``text_file`` and printed.
        """
        for folder in ('images2', 'images3', 'images4', 'images5'):
            os.makedirs('%s/%s' % (folder, self.dataset_name), exist_ok=True)

        r, c = 2, 3

        # With domain "A" the loader returns the A image followed by its paired B image.
        imgs_A, orig = self.data_loader.load_data(domain="A", batch_size=1, is_testing=False)
        imgs_A = np.expand_dims(imgs_A, axis=0)
        imgs_B = self.data_loader.load_data(domain="B", batch_size=1, is_testing=False)

        # Translate images to the other domain
        fake_B = self.g_AB.predict(imgs_A)
        fake_A = self.g_BA.predict(imgs_B)
        # Translate back to original domain
        reconstr_A = self.g_BA.predict(fake_B)
        reconstr_B = self.g_AB.predict(fake_A)

        gen_imgs = np.concatenate([imgs_A, fake_B, reconstr_A, imgs_B, fake_A, reconstr_B])

        # Rescale images from [-1, 1] to [0, 1]
        gen_imgs = 0.5 * gen_imgs + 0.5
        orig = 0.5 * orig + 0.5

        titles = ['Original', 'Translated', 'Reconstructed']
        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i, j].imshow(gen_imgs[cnt])
                axs[i, j].set_title(titles[j])
                axs[i, j].axis('off')
                cnt += 1
        fig.savefig("images2/%s/%d_%d.png" % (self.dataset_name, epoch, batch_i))
        plt.close(fig)

        self._save_single_image(gen_imgs[1], "images3/%s/%d_%d.png" % (self.dataset_name, epoch, batch_i))
        self._save_single_image(orig, "images4/%s/%d_%d.png" % (self.dataset_name, epoch, batch_i))
        self._save_single_image(gen_imgs[0], "images5/%s/%d_%d.png" % (self.dataset_name, epoch, batch_i))

        # Compare translated A->B with its ground truth on the same [0, 1] scale.
        # Previously `orig` was rescaled but the raw [-1, 1] `fake_B` was used.
        loss = self._abs_error(orig, gen_imgs[1])
        text_file.write("images5/%s/%d_%d.png" % (self.dataset_name, epoch, batch_i))
        text_file.write("\n")
        text_file.write("Losses: %d" % loss)
        text_file.write("\n")

        print("Losses per image: " + str(loss))

if __name__ == '__main__':
    # Build the models, then train for 200 epochs on 32-image batches,
    # printing progress every 10 batches.
    gan = CycleGAN()
    gan.train(
        epochs=200,
        batch_size=32,
        sample_interval=10,
    )

In [0]:
# Flush and close the loss log opened before training.
text_file.close()

# Archive each output folder so it can be downloaded as a single file.
!zip -r /content/file2.zip /content/images2
!zip -r /content/file3.zip /content/images3
!zip -r /content/file4.zip /content/images4
!zip -r /content/file5.zip /content/images5

# Trigger browser downloads of the four archives and the loss log.
from google.colab import files
files.download("/content/file2.zip")
files.download("/content/file3.zip")
files.download("/content/file4.zip")
files.download("/content/file5.zip")
files.download("/content/file6.txt")