# Implementation

## Header

### Importing modules

In [None]:
#--Keras
from keras import backend as K
from keras import datasets, layers, models, optimizers, utils, initializers
from keras import callbacks as cbks
from keras.optimizers import Adam, SGD, RMSprop
from keras.utils import multi_gpu_model
from keras.utils.vis_utils import plot_model
from keras.preprocessing.image import ImageDataGenerator
#--Tensorflow
import tensorflow as tf
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.client import device_lib
#dataset modules
from datasets.smallNORB import smallNORB #internal dependency
#--Misc
import os, sys
from datetime import datetime
import numpy as np
import scipy
import cv2
from sklearn.feature_extraction.image import extract_patches_2d
from tqdm import tqdm
from __future__ import division
from matplotlib import pyplot as plt, gridspec
import matplotlib
matplotlib.use('Agg')
%matplotlib inline
import functools
from functools import partial
import math

#--CPU + GPU use
# Enumerate CUDA devices in PCI bus order so the index used below refers to a fixed physical card.
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0" # comma-separated GPU indices visible to this process (e.g. "0,1" for the first two GPUs); the CPU is always available
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.4) #40% of GPU memory
sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) #set GPU
# NOTE(review): `sess` is not registered with Keras (no K.set_session call) in the visible code — confirm the memory cap applies to the Keras models.

#--List of available devices
print('Devices:', device_lib.list_local_devices())
#--Number of available devices
# Counts every local device TensorFlow reports (the list printed above), not only GPUs.
n_devices= len(device_lib.list_local_devices())

#--Fixing random state for reproducibility
# Seeds NumPy only; TensorFlow's own random ops are not seeded here.
np.random.seed(10) 

#plt.ioff()

### Parameters

In [None]:
#--Definition of the model to train - Setting of the hyperparameters
DATASET={"name": "CIFAR10", # one of: MNIST, CIFAR10, smallNORB
         "train": None, #Training dataset container
         "test": None, #Validation dataset container
         "set": "test",
         "param":{"target_scale": [-1, 1],
                  "target_shape": None, #None: will be automatically detected, (for mnist 28x28x1 cifar10 32x32x3 smallNORB 48x48x1)
                  # do not combine extracting patches and resize
                  "target_patch": None,# extract patches (mnist: None, cifar: 28,28 , smallnorb: 48,48) [H,W,number of patches]
                  "target_size": None, # resize
                  "one_hot": True # code the labels (if used)
                 }
        }

# NOTE(review): "topology" is "ConvNet" but the only per-topology sub-dictionary is keyed "DeConvNet" — confirm how the consumer looks it up.
GENERATOR= {"name":"generator",
            "train": None, #Generator's model container
            "param":{"topology":"ConvNet",
                     "inputs_shape": (100,), #noise dim
                     "output_shape": None, #used later
                     "DeConvNet":{
                         "optimizer": Adam(lr=0.0002, beta_1=0.5),
                         "iters": 1, # number of training iterations before/after training the discriminator
                     }
            }
}
DISCRIMINATOR={"name":"discriminator",
            "train": None, #Discriminator's model container
            "param":{"topology":"ConvNet", #possible values: "ConvNet", "Critic", "VCapsNet", "MCapsNet"
                     "inputs_shape":None, # will be defined after calling dataset 
                     "output_shape": None, # will be defined after calling dataset
                     
                     "ConvNet":{"decoder": None, # always None to be consistent with vcapsnet topology
                                "optimizer": Adam(lr=0.0002, beta_1=0.5), 
                                "iters": 1, # number of training iterations before/after training the generator
                               },
                     
                     "Critic":{"decoder": None, # always None to be consistent with vcapsnet topology
                                "optimizer": RMSprop(lr=5e-5),
                                "iters": 5, # number of training iterations before/after training the generator
                               },
                     
                     "VCapsNet":{
                                "routing_iters":3,
                                "decoder":False, # turn on the decoder
                                "decoder_factor": 0.0005,  #0.392
                                "L1_n": 256,
                                "L2_n": 32, #32 for mnist, 64 for cifar
                                "L2_dim": 8,
                                "L3_dim": 16,
                                "L4_n": 512,
                                "L5_n": 1024,
                                "optimizer": Adam(lr=0.0002, beta_1=0.5),
                                "iters": 1, # number of training iterations before/after training the generator
                               },
                     
                     "MCapsNet":{"routing_iters":3,
                                       "decoder": False,  # decoder disabled; NOTE(review): the earlier comment said "always None" but the value is False — confirm which is intended
                                       "L1_n": 8,
                                       "L2_n": 8, #8
                                       "L3_n": 8, #16
                                       "L4_n": 8, #16
                                       "pose_shape": [4,4],
                                       "optimizer": Adam(lr=0.0002, beta_1=0.5),
                                       "iters": 1, # number of training iterations before/after training the generator
                                      },
                    }
           }

COMBINED={"name":"gan",
          "train":None, #GAN's model container
          "param":{"topology": "VCapsGAN", # DCGAN, WGAN_GP, VCapsGAN, MCapsGAN
                  },
         }

TRAIN={
    #To train the discriminator only (as classifier), comment out the line GENERATOR['name']:GENERATOR['train']
    # The number of entries here (2 = GAN, otherwise classifier) is what the code below keys on; the values are still None at this point.
    "models_to_train": {DISCRIMINATOR['name'] : DISCRIMINATOR['train'],
                           GENERATOR['name']:GENERATOR['train'], 
                          },
    # load weights from checkpoints
    "trained_models":{'G':None, #'./ConvNet_GAN_CIFAR10_09-12_00-18/models/gen_ConvNet_Up.h5',
                         'D':None, #'./ConvNet_GAN_CIFAR10_09-12_00-18/models/disc_VCapsNet.h5',
                        },
    # activate histograms in tensorboard
    "debug": False,
    "param":{"batch_size": 100, # the number of samples must be divisible by the batch size , mnist 60k cifar10 50k smallNORB 24.3k
             "epochs": 100,
             "train_samples": None, # will be defined after calling dataset
             "checkpoint": {"interval":None,
                            "logdir": "./",
                            "models":{"save": True,
                                     },
                           }
               }
      }

# Update the checkpoint log directory
# Resulting pattern: ./<discriminator topology>_<GAN|classifier>_<dataset name>_<MM-DD_HH-MM>/
TRAIN["param"]["checkpoint"]["logdir"] += DISCRIMINATOR["param"]["topology"]+"_"
if len(TRAIN["models_to_train"]) == 2:
    TRAIN["param"]["checkpoint"]["logdir"] += "GAN_"
else:
    TRAIN["param"]["checkpoint"]["logdir"] += "classifier_"
TRAIN["param"]["checkpoint"]["logdir"] += DATASET["name"] +"_"+ datetime.now().strftime('%m-%d_%H-%M')+"/"


### Dataset Loader

In [None]:
class DataLoader ():
    """Load one split of an image dataset and prepare it for training.

    Supported datasets are 'MNIST', 'CIFAR10' (both via ``keras.datasets``)
    and 'smallNORB' (via the project's ``smallNORB`` reader, cached as
    ``.npy`` files under ``datasets/``).

    After construction the instance exposes:
      imgs        -- images as a 4-D array (samples, H, W, channels)
      labels      -- labels, one-hot encoded when ``one_hot`` is true
      n_samples   -- number of images
      num_classes -- number of distinct labels found in the split
      img_shape   -- (H, W, channels) of a single image
    """

    def __init__(self, set, name, target_shape=None, target_scale=None, target_patch=None, target_size=None, one_hot=False):
        """Load and preprocess the requested split.

        Args:
            set: split name; 'train' or 'test' for the Keras datasets.
            name: 'MNIST', 'CIFAR10' or 'smallNORB'.
            target_shape: optional per-image shape to reshape every image to.
            target_scale: optional [low, high] range to map pixel values
                from [0, 255] into.
            target_patch: optional [height, width, n_patches] for random
                patch extraction.
            target_size: optional pair passed to ``cv2.resize`` as ``dsize``.
            one_hot: when true, labels are one-hot encoded.

        Raises:
            ValueError: if ``name`` is not a supported dataset.
        """
        # Initialize attributes
        self.name = name
        self.target_shape = target_shape
        self.target_scale = target_scale
        self.target_patch = target_patch
        self.target_size = target_size
        self.one_hot = one_hot
        self.set = set
        # Load dataset; fail loudly on an unknown name instead of hitting an
        # AttributeError on self.imgs further down.
        loaders = {'MNIST': self.load_MNIST,
                   'CIFAR10': self.load_CIFAR10,
                   'smallNORB': self.load_smallNORB}
        if self.name not in loaders:
            raise ValueError("Unknown dataset '{}'; expected one of {}".format(self.name, sorted(loaders)))
        loaders[self.name]()
        self.imgs = np.array(self.imgs)
        self.labels = np.array(self.labels)
        self.n_samples = self.imgs.shape[0]
        # Encode labels
        if self.one_hot:
            self.labels = utils.to_categorical(self.labels, self.num_classes)
        # Reshape dataset: grayscale stacks (N, H, W) get an explicit channel axis
        if len(self.imgs.shape) == 3:
            self.reshape((self.imgs.shape[1], self.imgs.shape[2], 1))
        if self.target_shape:
            self.reshape(tuple(self.target_shape))
        self.img_shape = (self.imgs.shape[1], self.imgs.shape[2], self.imgs.shape[3])
        # Scale values of images
        if self.target_scale:
            self.rescale(self.target_scale)

        print("{} {} dataset has been uploaded successfully \n {} samples - shape: {}".format(self.name, self.set, self.n_samples, self.img_shape))

    def reshape(self, target_shape):
        """Reshape ``self.imgs`` to (n_samples,) + target_shape in place."""
        self.imgs = self.imgs.reshape((self.n_samples,) + tuple(target_shape))

    def rescale(self, target_scale):
        """Map pixel values linearly from [0, 255] to [target_scale[0], target_scale[1]] as float32."""
        low, high = target_scale[0], target_scale[1]
        # 255.0 keeps this a true division even without `from __future__ import division`.
        self.imgs = self.imgs.astype('float32') / (255.0 / (high - low)) + low

    def patch(self, target_patch):
        """Extract random patches from every image.

        Args:
            target_patch: [patch_height, patch_width, max_patches_per_image].

        Returns:
            (patches, labels) as arrays; each patch carries the label of the
            image it was cut from.
        """
        print("Extracting patches ...")
        patch_size = (target_patch[0], target_patch[1])
        imgs = []
        labels = []
        # Iterating image by image also works when self.imgs is a list.
        for img, label in zip(self.imgs, self.labels):
            img = np.asarray(img)
            if img.ndim == 3 and img.shape[2] == 1:
                img = img[:, :, 0]  # drop the singleton channel axis
            x = extract_patches_2d(img, patch_size, max_patches=target_patch[2])
            # One label per patch actually returned, so images and labels
            # cannot drift apart if fewer patches than requested come back.
            y = np.full(len(x), label, dtype=int)
            imgs.extend(x)
            labels.extend(y)
        print("Extraction finished!")
        return np.array(imgs), np.array(labels)

    def resize(self, target_size):
        """Resize every image with ``cv2.resize(dsize=(target_size[0], target_size[1]))``."""
        # NOTE(review): cv2 interprets dsize as (width, height) — confirm the order callers intend.
        dsize = (target_size[0], target_size[1])
        self.imgs = np.array([cv2.resize(x, dsize=dsize) for x in self.imgs])

    def _load_keras_dataset(self, load_data):
        """Shared loader for the Keras datasets (MNIST, CIFAR10).

        ``load_data`` must return ((train_imgs, train_labels), (test_imgs, test_labels)).
        """
        # Compare with ==, not `is`: string identity is an implementation detail.
        if self.set not in ('train', 'test'):
            raise ValueError("Unknown split '{}'; expected 'train' or 'test'".format(self.set))
        train_split, test_split = load_data()
        self.imgs, self.labels = train_split if self.set == 'train' else test_split
        self.num_classes = len(np.unique(self.labels, axis=0))
        if self.target_size:
            self.resize(self.target_size)
        if self.target_patch:
            self.imgs, self.labels = self.patch(self.target_patch)

    def load_MNIST(self):
        """Load the requested MNIST split; images in [0, 255], integer labels (not one-hot)."""
        self._load_keras_dataset(datasets.mnist.load_data)

    def load_CIFAR10(self):
        """Load the requested CIFAR-10 split; images in [0, 255], integer labels (not one-hot)."""
        self._load_keras_dataset(datasets.cifar10.load_data)

    def load_smallNORB(self):
        """Load smallNORB, using ``.npy`` caches under ``datasets/``.

        The raw arrays are cached as ``smallNorb_<set>_imgs.npy`` /
        ``smallNorb_<set>_labels.npy``; resized or patched variants are cached
        under names that embed the target size / patch size.
        """
        prefix = 'smallNorb_' + self.set
        if self.target_patch:
            suffix = '_patches_' + str(self.target_patch[0]) + '_' + str(self.target_patch[1])
        elif self.target_size:
            suffix = '_resized_' + str(self.target_size[0]) + '_' + str(self.target_size[1])
        else:
            suffix = ''
        imgs_path = os.path.join('datasets', prefix + '_imgs' + suffix + '.npy')
        labels_path = os.path.join('datasets', prefix + '_labels' + suffix + '.npy')
        raw_imgs_path = os.path.join('datasets', prefix + '_imgs.npy')
        raw_labels_path = os.path.join('datasets', prefix + '_labels.npy')

        if os.path.exists(imgs_path) and os.path.exists(labels_path):
            # Processed variant is fully cached.
            self.imgs = np.load(imgs_path)
            self.labels = np.load(labels_path)
        else:
            if os.path.exists(raw_imgs_path) and os.path.exists(raw_labels_path):
                self.imgs = np.load(raw_imgs_path)
                self.labels = np.load(raw_labels_path)
            else:
                # Parse the original dataset files once and cache the raw arrays.
                (self.imgs, _, self.labels, _) = smallNORB(dataset_dir='datasets', set=self.set).load_data()
                np.save(raw_imgs_path, self.imgs)
                np.save(raw_labels_path, self.labels)
            if self.target_size:
                self.resize(self.target_size)
            if self.target_patch:
                self.imgs, self.labels = self.patch(self.target_patch)
            # Write whichever processed file is missing; an existing one wins
            # over the freshly computed array (same behaviour as before).
            if not os.path.exists(imgs_path):
                np.save(imgs_path, self.imgs)
            else:
                self.imgs = np.load(imgs_path)
            if not os.path.exists(labels_path):
                np.save(labels_path, self.labels)
            else:
                self.labels = np.load(labels_path)
        self.num_classes = len(np.unique(self.labels, axis=0))

    def plot_samples(self, grid=(10, 10), imgsize=(8, 8), logdir=None):
        """Show the first grid[0]*grid[1] images in a grid and optionally save the figure.

        Args:
            grid: (rows, columns) of the image grid.
            imgsize: figure size passed to ``plt.figure(figsize=...)``.
            logdir: when given, the figure is saved there as ``<name>.png``.
        """
        # Without a target scale the pixels were never rescaled, i.e. they are in [0, 255].
        img_range = self.target_scale if self.target_scale else [0, 255]
        cmap = None if self.imgs.shape[-1] == 3 else 'gray'
        n_cells = grid[0] * grid[1]
        imgs = self.imgs[0:n_cells]
        if cmap == 'gray':
            imgs = np.squeeze(imgs, -1)
        # Map back to displayable 8-bit values.
        imgs = ((imgs - img_range[0]) * 255 / (img_range[1] - img_range[0])).astype(np.uint8)
        # Create a figure object
        fig = plt.figure(figsize=(imgsize[0], imgsize[1]))
        # Show images; stop early when the dataset has fewer images than cells.
        for i, img in enumerate(imgs):
            fig.add_subplot(grid[0], grid[1], i + 1)
            plt.imshow(img, cmap=cmap)
            plt.axis('off')
        plt.show()
        if logdir:
            fig.savefig(os.path.join(logdir, self.name + '.png'))

    def inception_score(self, splits=10):
        """Compute the Inception score of ``self.imgs``.

        Images are expected to be scaled to [-1, 1] (checked on the first
        image). Grayscale images are replicated to three channels. Only full
        batches of 64 images are evaluated; the remainder is dropped.

        Args:
            splits: number of splits the predictions are divided into.

        Returns:
            (mean, std) of the per-split scores.

        Raises:
            ValueError: if there are fewer than 64 images.
        """
        imgs = self.imgs
        if (len(np.shape(imgs)) == 3 or np.shape(imgs)[-1] == 1):
            imgs = np.squeeze(imgs)
            imgs = np.stack((imgs,) * 3, -1)
        imgs = np.rollaxis(imgs, 3, 1)  # NHWC -> NCHW
        BATCH_SIZE = 64
        session = tf.InteractiveSession()
        try:
            # Run images through Inception.
            inception_images = tf.placeholder(tf.float32, [BATCH_SIZE, 3, None, None])

            def inception_logits(images=inception_images, num_splits=1):
                images = tf.transpose(images, [0, 2, 3, 1])
                size = 299
                images = tf.image.resize_bilinear(images, [size, size])
                generated_images_list = array_ops.split(
                    images, num_or_size_splits=num_splits)
                logits = functional_ops.map_fn(
                    fn=functools.partial(tf.contrib.gan.eval.run_inception, output_tensor='logits:0'),
                    elems=array_ops.stack(generated_images_list),
                    parallel_iterations=1,
                    back_prop=False,
                    swap_memory=True,
                    name='RunClassifier')
                logits = array_ops.concat(array_ops.unstack(logits), 0)
                return logits

            logits = inception_logits()

            def get_inception_probs(inps):
                n_batches = len(inps) // BATCH_SIZE
                if n_batches == 0:
                    raise ValueError("inception_score needs at least %i images, got %i" % (BATCH_SIZE, len(inps)))
                preds = []
                for i in range(n_batches):
                    inp = inps[i * BATCH_SIZE:(i + 1) * BATCH_SIZE]
                    pred = logits.eval({inception_images: inp})[:, :1000]
                    preds.append(pred)
                preds = np.concatenate(preds, 0)
                # Softmax with the row maximum subtracted so exp() cannot overflow.
                preds = np.exp(preds - np.max(preds, 1, keepdims=True))
                return preds / np.sum(preds, 1, keepdims=True)

            def preds2score(preds, splits):
                scores = []
                for i in range(splits):
                    part = preds[(i * preds.shape[0] // splits):((i + 1) * preds.shape[0] // splits), :]
                    kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0)))
                    kl = np.mean(np.sum(kl, 1))
                    scores.append(np.exp(kl))
                return np.mean(scores), np.std(scores)

            def get_inception_score(images, splits):
                assert isinstance(images, np.ndarray)
                assert len(images.shape) == 4
                assert images.shape[1] == 3
                assert np.max(images[0]) <= 1
                assert np.min(images[0]) >= -1
                preds = get_inception_probs(images)
                print ('Inception Score for %i samples in %i splits'% (preds.shape[0],splits))
                # Reference values: 11.34 for 49984 CIFAR-10 training set images, or mean=11.31, std=0.08 if in 10 splits (default).
                return preds2score(preds, splits)

            return get_inception_score(imgs, splits)
        finally:
            # Release the session even when scoring fails.
            session.close()

In [None]:
# Get dataset: both splits share the same preprocessing parameters
loader_kwargs = dict(name=DATASET["name"],
                     target_shape=DATASET["param"]["target_shape"],
                     target_scale=DATASET["param"]["target_scale"],
                     target_patch=DATASET["param"]["target_patch"],
                     target_size=DATASET["param"]["target_size"],
                     one_hot=DATASET["param"]["one_hot"])
DATASET["train"] = DataLoader('train', **loader_kwargs)
DATASET["test"] = DataLoader('test', **loader_kwargs)

# Fail fast: check the batch size before the expensive plotting / Inception-score steps
if DATASET["train"].n_samples % TRAIN["param"]["batch_size"] != 0:
    raise ValueError("number of samples must be divisible by the batch size")

# Update the containers defined in the dictionary
TRAIN["param"]["train_samples"]= DATASET["train"].n_samples
if DATASET["param"]["target_scale"] is None: DATASET["param"]["target_scale"]= [0, 255]
if DISCRIMINATOR["param"]["inputs_shape"] is None: DISCRIMINATOR["param"]["inputs_shape"]= DATASET["train"].img_shape
if DISCRIMINATOR["param"]["output_shape"] is None:
    if len(TRAIN["models_to_train"]) == 2:
        # GAN: the discriminator has two outputs
        DISCRIMINATOR["param"]["output_shape"]= (2,)
    else:
        # Classifier: one output per class
        DISCRIMINATOR["param"]["output_shape"]= (DATASET["train"].num_classes,)

if GENERATOR["param"]["output_shape"] is None: GENERATOR["param"]["output_shape"]= DATASET["train"].img_shape

# Create checkpoint log directory
if not os.path.exists(TRAIN["param"]["checkpoint"]["logdir"]): os.makedirs(TRAIN["param"]["checkpoint"]["logdir"])

# Plot the first samples of the training set
DATASET["train"].plot_samples(grid=[10,10], imgsize=[10,10], logdir=TRAIN["param"]["checkpoint"]["logdir"])
# Compute the Inception score of the dataset
dataset_inception_score = DATASET["train"].inception_score()
print("IS: mean {}, stdv {}".format(dataset_inception_score[0], dataset_inception_score[1]))

### Helper functions

#### Visualization functions

In [None]:
if len(TRAIN["models_to_train"]) == 2: # if training a GAN
    def generator_sampler(latent_dim, generator, n_images):
        """Return a zero-argument function that samples a grid of generated images.

        Each call draws n_images[0]*n_images[1] normal latent vectors of size
        ``latent_dim``, runs them through ``generator`` and reshapes the result
        to (n_images[0], n_images[1]) + image shape. The image shape comes from
        GENERATOR["param"]["output_shape"]; a trailing single channel is dropped.

        NOTE(review): ``imgs_utils`` is not defined in the visible part of the
        file — confirm it is imported before the sampler is called.
        """
        rows, cols = n_images[0], n_images[1]

        def sampler():
            out_shape = GENERATOR["param"]["output_shape"]
            latent = np.random.normal(size=(rows * cols, latent_dim))
            generated = imgs_utils.dim_ordering_unfix(generator.predict(latent))
            generated = generated.transpose((0, 2, 3, 1))
            # (H, W, 1) is plotted as a plain (H, W) image; anything else is kept as is.
            single_channel = len(out_shape) == 3 and out_shape[2] == 1
            img_shape = out_shape[:2] if single_channel else out_shape
            return generated.reshape((rows, cols,) + img_shape)

        return sampler

#### Squash function

$$squash(\mathbf{s}) = \frac{\| \mathbf{s} \|^2}{1+ \| \mathbf{s} \|^2} \frac{\mathbf{s}}{\| \mathbf{s} \|}$$

In [None]:
def squash(inputs, axis=-1):
    """Squash vectors along ``axis``: scale by ||s||^2 / (1 + ||s||^2) and divide by ||s||.

    ``K.epsilon()`` is added under the square root so a zero vector does not
    cause a division by zero.
    """
    squared_norm = K.sum(K.square(inputs), axis, keepdims=True)
    shrink = squared_norm / (1 + squared_norm)
    safe_norm = K.sqrt(squared_norm + K.epsilon())
    return shrink * inputs / safe_norm

#### Margin loss

$$ \mathcal{L} = \sum_i y_{true, i} \max(0, m^+ - y_{pred, i})^2 + \lambda (1-y_{true, i}) \max (0, y_{pred, i} - m^-)^2 $$
where $\mathbf{y}_{pred, i} = \| \mathbf{v}_i \| \, , \, m^+ = 0.9 \, , \, m^- = 0.1 \, \text{and} \, \lambda = 0.5$

#### Spread loss

$$\mathcal{L}=  \sum_i max(0, margin - (y_{pred\_true, i}-y_{pred\_false, i}))^2$$

#### Crossentropy loss

The generator tries to minimize the cross entropy (information distance) between the prediction of Discriminator given a generated sample and a valid prediction $\Rightarrow$ goal: fool the Discriminator  <br> 
\begin{align} G =&\underset{G}{\operatorname{argmin}} \mathcal{H}(y_{true} = 1, y_{pred} = D(G(\textbf{z}^{(i)})))\\=&\underset{G}{\operatorname{argmax}} \mathcal{H}(y_{true} = 0, y_{pred} = D(G(\textbf{z}^{(i)})))\\=&\underset{G}{\operatorname{argmax}} \mathop{\mathbb{E}}_{\textbf{z}^{(i)} \sim p_z(\textbf{z}^{(i)})} [-\log(1-D(G(\textbf{z}^{(i)})))] \\=&\underset{G}{\operatorname{argmin}} \mathop{\mathbb{E}}_{\textbf{z}^{(i)} \sim p_z(\textbf{z}^{(i)})} [\log(1-D(G(\textbf{z}^{(i)})))] \\=&\underset{G}{\operatorname{argmin}}  \sum_{i}\log(1- D(G(\textbf{z}^{(i)}))) \end{align}

The margin loss $\mathcal{L}_M$ is an empirical loss function that worked very well in CapsNet and is also used in the CapsGAN: <br>

\begin{align} \mathcal{L}_M(y_{true}, y_{pred}) = \sum_{k=1}^K y_{true,k} \max (0,m^+-y_{pred,k})^2+\lambda(1-y_{true,k})\max(0,y_{pred,k}-m^-)^2\end{align}<br><br>
$K$: number  of classes (here $K$=2)

$\mathcal{L}_M$ can also be interpreted as a distance measure between probability densities, thus:<br>
\begin{align}G=&\underset{G}{\operatorname{argmin}} \mathop{\mathbb{E}}_{\textbf{z}^{(i)} \sim p_z(\textbf{z}^{(i)})} [\mathcal{L}_M(y_{true}=1,y_{pred}=D(G(\textbf{z})))]\\ =&\underset{G}{\operatorname{argmin}}  \sum_i  max (0,m^+-D(G(\textbf{z}^{(i)})))^2\end{align}

The discriminator tries to minimize the cross entropy between its prediction given a generated sample and a fake prediction. 	
At the same time, it tries to minimize the cross entropy between its prediction given a real sample and a valid prediction.<br>
$\Rightarrow$ Improve the classification

\begin{align}D =&\underset{D}{\operatorname{argmin}} \mathcal{H}(y_{true} = [1,0], y_{pred} = [D(\textbf{x}^{(i)}), D(G(\textbf{z}^{(i)}))]) \\=&\underset{D}{\operatorname{argmin}}  - \mathop{\mathbb{E}}_{\textbf{x}^{(i)} \sim p_{data}(\textbf{x}^{(i)})} [\log(D(\textbf{x}^{(i)})) ] - \mathop{\mathbb{E}}_{\textbf{z}^{(i)} \sim p_z(\textbf{z}^{(i)})} [\log (1-D(G(\textbf{z}^{(i)})))]\\=&\underset{D}{\operatorname{argmax}}\mathop{\mathbb{E}}_{\textbf{x}^{(i)} \sim p_{data}(\textbf{x}^{(i)})} [\log(D(\textbf{x}^{(i)})) ] + \mathop{\mathbb{E}}_{\textbf{z}^{(i)} \sim p_z(\textbf{z}^{(i)})} [\log (1-D(G(\textbf{z}^{(i)})))] \\=&\underset{D}{\operatorname{argmax}} \sum_{i}\log (D(\textbf{x}^{(i)})) + \log(1-D(G(\textbf{z}^{(i)})))\end{align}
<br>
The margin loss $\mathcal{L}_M$ is an empirical loss function that worked very well in CapsNet and is also used in the CapsGAN: <br>

\begin{align} \mathcal{L}_M(y_{true}, y_{pred}) = \sum_{k=1}^K y_{true,k} \max (0,m^+-y_{pred,k})^2+\lambda(1-y_{true,k})\max(0,y_{pred,k}-m^-)^2\end{align}<br><br>
$K$: number  of classes (here $K$=2)

$\mathcal{L}_M$ can also be interpreted as a distance measure between probability densities, thus:<br>
\begin{align}D=&\underset{D}{\operatorname{argmin}} \mathop{\mathbb{E}}_{\textbf{x}^{(i)} \sim p_{data}(\textbf{x}^{(i)}),\, \textbf{z}^{(i)} \sim p_z(\textbf{z}^{(i)})} [\mathcal{L}_M(y_{true}= [1,0] , y_{pred}=[D(\textbf{x}^{(i)}),D(G(\textbf{z}^{(i)}))])] \\=&\underset{D}{\operatorname{argmin}}  \sum_i \max (0,m^+-D(\textbf{x}^{(i)}))^2 + \lambda \max (0,D(G(\textbf{z}^{(i)}))-m^-)^2\end{align}

#### Gradient penalty loss

In [None]:
def gradient_penalty_loss(self, y_true, y_pred, averaged_samples):
    """Return the mean of (1 - ||grad||_2)^2 over the batch.

    The gradient of ``y_pred`` is taken with respect to ``averaged_samples``
    and its L2 norm is computed over every axis except the first.
    ``self`` and ``y_true`` are not used by the computation.
    """
    grads = K.gradients(y_pred, averaged_samples)[0]
    squared = K.square(grads)
    # Reduce over all non-batch axes.
    non_batch_axes = np.arange(1, len(squared.shape))
    l2_norm = K.sqrt(K.sum(squared, axis=non_batch_axes))
    penalty = K.square(1 - l2_norm)
    return K.mean(penalty)

#### Wasserstein loss

In [None]:
def wasserstein_loss(self, y_true, y_pred):
    """Return the mean of the element-wise product of ``y_true`` and ``y_pred``.

    ``self`` is accepted but not used.
    """
    signed_scores = y_true * y_pred
    return K.mean(signed_scores)

## Defining the models

### Generator

In [None]:
class Generator:
    """Base class for generators.

    Subclasses provide ``build_sequential(input_shape, output_shape)``, which
    returns the actual network; ``build`` wraps it in a multi-input /
    multi-output Keras model.
    """

    def __init__(self, name='generator', **kwargs):
        """Store the name; the model and shapes stay unset until ``build``. Extra keyword arguments are ignored."""
        self.name = name
        self.model = self.input_shape = self.output_shape = None

    def build(self, input_shape, output_shape, len_io):
        """Build the network and wrap it in a model with ``len_io`` inputs and outputs.

        Every input is fed through the same network (``self.sequential``), so
        the weights are shared between all input/output pairs.

        Returns:
            The wrapping Keras model, also stored as ``self.model``.
        """
        self.input_shape, self.output_shape = input_shape, output_shape
        self.sequential = self.build_sequential(input_shape, output_shape)
        self.sequential.name = 'generator'

        net_inputs, net_outputs = [], []
        for index in range(1, len_io + 1):
            tensor = layers.Input(shape=input_shape, name='img' + str(index))
            net_inputs.append(tensor)
            net_outputs.append(self.sequential(tensor))

        self.model = models.Model(inputs=net_inputs, outputs=net_outputs)
        self.model.name = 'generator_in_gan'
        print("************************************GENERATOR*************************************")
        self.sequential.summary()
        return self.model

    def compile(self, **kwargs):
        """Compile the wrapping model, forwarding all keyword arguments to Keras."""
        self.model.compile(**kwargs)
        print(self.name, "compiled")

In [None]:
class ConvNet_Up(Generator):
    """Convolutional generator that upsamples a dense projection of the noise vector by 4x."""

    def __init__ (self, name='ConvNet_Up', **kwargs):
        super(ConvNet_Up, self).__init__(name=name, **kwargs)

    def build_sequential(self, input_shape, output_shape):
        """Build the generator network.

        Args:
            input_shape: shape of the latent input.
            output_shape: (height, width, channels) of the generated image.
                The network starts at (height // 4, width // 4) and upsamples
                twice by 2, so height and width should be multiples of 4.

        Returns:
            A Keras model mapping the latent input to a tanh-activated image.
        """
        # Compute the base resolution once. The previous inline expression
        # `256 * H//4 * W//4` parsed as ((256*H)//4 * W)//4, which disagrees
        # with the Reshape target whenever H or W is not a multiple of 4.
        base_h = output_shape[0] // 4
        base_w = output_shape[1] // 4

        def weight_init():
            return initializers.RandomNormal(0, 0.02)

        def bn_relu(tensor):
            # training=1 keeps batch normalization in training mode at all times
            tensor = layers.BatchNormalization(momentum=0.9, gamma_initializer=initializers.RandomNormal(1, 0.02))(tensor, training=1)
            return layers.Activation("relu")(tensor)

        # input layer
        inputs = layers.Input(shape=input_shape)
        x = layers.Dense(256 * base_h * base_w, use_bias=False, kernel_initializer=weight_init())(inputs)
        x = layers.Reshape((base_h, base_w, 256))(x)
        x = bn_relu(x)
        x = layers.UpSampling2D()(x)
        x = layers.Conv2D(128, kernel_size=4, padding="same", use_bias=False, kernel_initializer=weight_init())(x)
        x = bn_relu(x)
        x = layers.UpSampling2D()(x)
        x = layers.Conv2D(64, kernel_size=4, padding="same", use_bias=False, kernel_initializer=weight_init())(x)
        x = bn_relu(x)
        x = layers.Conv2D(output_shape[-1], kernel_size=4, padding="same", use_bias=False, kernel_initializer=weight_init())(x)
        gen_out = layers.Activation("tanh")(x)

        return models.Model(inputs, gen_out)

    @staticmethod
    def cross_entropy_loss(y_true, y_pred):
        """Margin-style loss with m+ = 0.9, m- = 0.1 and lambda = 0.5.

        Despite its name this computes the margin loss, summed over axis 1 and
        averaged over the batch. Declared static because it takes no ``self``;
        it can be called on the class or on an instance.
        """
        loss = y_true * K.square(K.maximum(0., 0.9 - y_pred)) + \
            0.5 * (1 - y_true) * K.square(K.maximum(0., y_pred - 0.1))
        return K.mean(K.sum(loss, 1))

    def compile(self, optimizer=None, **kwargs):
        """Compile the generator model.

        When no optimizer is given, fall back to Adam(lr=0.0002, beta_1=0.5)
        and, unless the caller passed one, to the 'binary_crossentropy' loss.
        Previously the loss name was assigned to ``optimizer``, which Keras
        rejects as an unknown optimizer.
        """
        if optimizer is None:
            optimizer = Adam(lr=0.0002, beta_1=0.5)
            kwargs.setdefault('loss', 'binary_crossentropy')
        super(ConvNet_Up, self).compile(optimizer=optimizer, **kwargs)

### Discriminator

#### Discriminative model

In [None]:
class Discriminator:
    """Base class for discriminators / classifiers.

    Subclasses provide ``build_sequential(input_shape, output_shape, **kwargs)``
    returning the actual network; ``build`` wraps it in a multi-input /
    multi-output Keras model and ``fit`` trains it with shifted-image
    augmentation.
    """

    def __init__(self, name='Discriminator', **kwargs):
        """Store the name; model, shapes and decoder stay unset. Extra keyword arguments are ignored."""
        self.name = name
        self.model = None
        self.input_shape = None
        self.output_shape = None
        self.decoder = None

    def build(self, input_shape, output_shape, len_io=1, **kwargs):
        """Build the network and wrap it in a model with ``len_io`` inputs and outputs.

        Args:
            input_shape: shape of one input image.
            output_shape: output shape; a sequence is collapsed to its product
                before being handed to ``build_sequential``.
            len_io: number of input/output pairs sharing the same network.
            **kwargs: forwarded to ``build_sequential``.

        Returns:
            The wrapping Keras model, also stored as ``self.model``.
        """
        self.input_shape = input_shape
        self.output_shape = output_shape
        if hasattr(output_shape, "__len__"): output_shape = np.prod(output_shape)

        self.sequential = self.build_sequential(input_shape, output_shape, **kwargs)
        self.sequential.name = 'discriminator'
        print("************************************DISCRIMINATOR*************************************")
        self.sequential.summary()
        # Query the device list once; three or more local devices triggers
        # data-parallel replication over (devices - 1) GPUs.
        n_local_devices = len(device_lib.list_local_devices())
        if n_local_devices >= 3:
            self.sequential = multi_gpu_model(self.sequential, gpus=n_local_devices - 1)
        inputs = []
        outputs = []
        for i in range(len_io):
            tensor = layers.Input(shape=input_shape, name='img' + str(i + 1))
            inputs.append(tensor)
            outputs.append(self.sequential(tensor))

        self.model = models.Model(inputs=inputs,
                            outputs=outputs)
        self.model.name = 'discriminator_in_gan'
        print("************************************DISCRIMINATOR_IN_GAN*************************************")
        self.model.summary()

        return self.model

    def compile(self, **kwargs):
        """Compile the wrapping model, forwarding all keyword arguments to Keras."""
        self.model.compile(**kwargs)
        print(self.name, "compiled")

    def train_generator(self, x, y, shift_fraction=0.):
        """Yield augmented training batches forever.

        Images are randomly shifted by up to ``shift_fraction`` of their width
        and height. Uses ``self.batch_size``, which ``fit`` sets. When a
        decoder is configured, batches are yielded as
        ([x, y], [y, x]); otherwise as (x, y).
        """
        train_datagen = ImageDataGenerator(width_shift_range=shift_fraction,
                                           height_shift_range=shift_fraction)  # shift up to 2 pixel for MNIST
        generator = train_datagen.flow(x, y, batch_size=self.batch_size)
        while 1:
            x_batch, y_batch = generator.next()
            if self.decoder : yield ([x_batch, y_batch], [y_batch, x_batch])
            else: yield (x_batch,y_batch)

    def test_generator(self, x, y):
        """Arrange validation data the same way ``train_generator`` arranges a batch."""
        if self.decoder : return ([x, y], [y, x])
        else: return (x, y)

    def fit(self, x, y, batch_size, epochs, callbacks=None, load_weights=None, validation_data=None, PlotModel=True, TensorBoard=True, debug=False, ModelCheckpoint= True, CSVLogger=True, logdir='./', **kwargs):
        """Train the model with ``fit_generator``.

        Args:
            x, y: training images and labels.
            batch_size: batch size; also stored as ``self.batch_size``.
            epochs: number of epochs.
            callbacks: optional extra Keras callbacks.
            load_weights: optional path of weights to load before training.
            validation_data: optional (x_val, y_val); validation is skipped
                when None.
            PlotModel: save a diagram of the model as ``<name>.svg``.
            TensorBoard: log to ``<logdir>/tb``; ``debug`` is used as the
                histogram frequency.
            ModelCheckpoint: save weights to ``<logdir>/models/discriminator.h5``
                after every epoch.
            CSVLogger: write ``<logdir>/history.csv``.
            logdir: directory for all of the above.
            **kwargs: forwarded to ``fit_generator``.
        """
        # Copy so the caller's list is never modified.
        cb = list(callbacks) if callbacks else []
        self.batch_size= batch_size
        if CSVLogger: cb.append(cbks.CSVLogger(os.path.join(logdir, 'history.csv')))
        if TensorBoard: cb.append(cbks.TensorBoard(log_dir=os.path.join(logdir, 'tb'),
                                   batch_size=self.batch_size, histogram_freq=int(debug)))
        if ModelCheckpoint:
            if not os.path.exists(os.path.join(logdir, 'models')): os.makedirs(os.path.join(logdir, 'models'))
            cb.append(cbks.ModelCheckpoint(os.path.join(logdir, 'models/discriminator.h5'),
                                                        save_best_only=False, save_weights_only=True, verbose=1))
        if PlotModel: plot_model(self.model, to_file=os.path.join(logdir, self.name+'.svg'), show_shapes=True)
        if load_weights : self.model.load_weights(load_weights)
        # validation_data defaults to None; indexing it unconditionally raised a TypeError.
        if validation_data is None:
            val_data = None
        else:
            val_data = self.test_generator(validation_data[0], validation_data[1])
        self.model.fit_generator(generator=self.train_generator(x, y, 0.1),
                                  steps_per_epoch=x.shape[0] // batch_size,
                                  epochs=epochs,
                                  validation_data=val_data,
                                  callbacks= cb,
                                   **kwargs)

#### ConvNet

In [None]:
class ConvNet(Discriminator):
    """Convolutional discriminator trained with a margin loss."""

    def __init__ (self, name='ConvNet', **kwargs):
        super(ConvNet, self).__init__(name=name, **kwargs)

    def build_sequential(self, input_shape, output_shape):
        """Build the convolutional network as a Keras ``Model``.

        Three stride-2 4x4 convolutions (64, 128, 256 filters) with LeakyReLU,
        batch normalisation on the last two, then a convolution whose kernel
        covers the whole remaining feature map, flattened and passed through
        a sigmoid.

        ``input_shape`` and ``output_shape`` are stored on the instance.
        ``output_shape`` does not influence the architecture: the final
        convolution is fixed at 2 filters.
        """
        self.input_shape = input_shape
        self.output_shape = output_shape

        # . -> D(.)
        inputs = layers.Input(shape=input_shape)

        # conv layer (no batch norm on the first block)
        seq = layers.Conv2D(64, kernel_size=4, strides=2, padding='same', use_bias=False,
                            kernel_initializer=initializers.RandomNormal(0, 0.02))(inputs)
        seq = layers.LeakyReLU(alpha=0.2)(seq)

        # conv layer; batch norm is called with training=1, i.e. always in training mode
        seq = layers.Conv2D(128, kernel_size=4, strides=2, padding='same', use_bias=False,
                            kernel_initializer=initializers.RandomNormal(0, 0.02))(seq)
        seq = layers.BatchNormalization(momentum=0.9, epsilon=1.01e-5,
                                        gamma_initializer=initializers.RandomNormal(1, 0.02))(seq, training=1)
        seq = layers.LeakyReLU(alpha=0.2)(seq)

        seq = layers.Conv2D(256, kernel_size=4, strides=2, padding='same', use_bias=False,
                            kernel_initializer=initializers.RandomNormal(0, 0.02))(seq)
        seq = layers.BatchNormalization(momentum=0.9, epsilon=1.01e-5,
                                        gamma_initializer=initializers.RandomNormal(1, 0.02))(seq, training=1)
        seq = layers.LeakyReLU(alpha=0.2)(seq)

        # Kernel size equals the remaining spatial extent, so the map collapses to 1x1.
        seq = layers.Conv2D(2, kernel_size=seq.get_shape().as_list()[1], strides=1, use_bias=False,
                            kernel_initializer=initializers.RandomNormal(0, 0.02))(seq)

        seq = layers.Flatten()(seq)
        output = layers.Activation('sigmoid')(seq)
        return models.Model(inputs, output)

    def loss_fn(self, y_true, y_pred):
        """Margin loss: penalise positives below 0.9 and (half-weighted) negatives above 0.1."""
        loss = y_true * K.square(K.maximum(0., 0.9 - y_pred)) + \
            0.5 * (1 - y_true) * K.square(K.maximum(0., y_pred - 0.1))
        return K.mean(K.sum(loss, 1))

    def compile(self, optimizer=None, **kwargs):
        """Compile through the parent class.

        When ``optimizer`` is given it is forwarded unchanged together with
        ``kwargs``. When it is omitted, the margin loss ``self.loss_fn`` is
        supplied as the default ``loss`` (unless the caller passed one) and
        the optimizer choice is left to the parent. Previously the loss
        function was passed *as the optimizer* in that case.

        NOTE(review): assumes ``Discriminator.compile`` accepts a ``loss``
        keyword and can be called without ``optimizer`` - confirm against
        the base class.
        """
        if optimizer is None:
            kwargs.setdefault('loss', self.loss_fn)
            super(ConvNet, self).compile(**kwargs)
        else:
            super(ConvNet, self).compile(optimizer=optimizer, **kwargs)

#### ConvNet Critic

In [None]:
class ConvNet_Critic(Discriminator):
    """Convolutional critic with a gradient-penalty output.

    ``build`` wires one shared network to a real input, a generated input and
    a random interpolation of the two; ``compile`` appends the gradient
    penalty as the loss of the third output.
    """

    def __init__ (self, batch_size, name='ConvNet', **kwargs):
        super(ConvNet_Critic, self).__init__(name=name, **kwargs)
        # The interpolation layer draws exactly `batch_size` random weights.
        self.batch_size = batch_size
        print("batch_size", self.batch_size)

    def build_sequential(self, input_shape, output_shape):
        """Build the shared critic network (no batch norm, no final activation).

        Three stride-2 4x4 convolutions (64, 128, 256 filters) with LeakyReLU,
        then a 2-filter convolution spanning the remaining feature map,
        flattened. ``output_shape`` is stored but does not change the
        architecture.
        """
        self.input_shape = input_shape
        self.output_shape = output_shape

        # . -> D(.)
        inputs = layers.Input(shape=input_shape)

        # conv layer
        seq = layers.Conv2D(64, kernel_size=4, strides=2, padding='same', use_bias=False,
                            kernel_initializer=initializers.RandomNormal(0, 0.02))(inputs)
        seq = layers.LeakyReLU(alpha=0.2)(seq)

        # conv layer
        seq = layers.Conv2D(128, kernel_size=4, strides=2, padding='same', use_bias=False,
                            kernel_initializer=initializers.RandomNormal(0, 0.02))(seq)
        seq = layers.LeakyReLU(alpha=0.2)(seq)

        seq = layers.Conv2D(256, kernel_size=4, strides=2, padding='same', use_bias=False,
                            kernel_initializer=initializers.RandomNormal(0, 0.02))(seq)
        seq = layers.LeakyReLU(alpha=0.2)(seq)

        # Kernel size equals the remaining spatial extent.
        seq = layers.Conv2D(2, kernel_size=seq.get_shape().as_list()[1], strides=1, use_bias=False,
                            kernel_initializer=initializers.RandomNormal(0, 0.02))(seq)

        output = layers.Flatten()(seq)
        return models.Model(inputs, output)

    def build(self, input_shape, output_shape, len_io=1, **kwargs):
        """Build the three-output training model and store it on ``self.model``.

        Inputs are ``[x, G_z]`` (real and generated samples); outputs are
        ``[D(x), D(G_z), D(interpolated)]``. Also prepares
        ``self.partial_gp_loss``, the gradient penalty bound to the
        interpolated samples. ``len_io`` is accepted but not used here.
        """

        class RandomWeightedAverage(layers.merge._Merge):
            """Provides a (random) weighted average between real and generated image samples"""
            def __init__(self, batch_size, **kwargs):
                self.batch_size = batch_size
                super(RandomWeightedAverage, self).__init__(**kwargs)

            def _merge_function(self, inputs):
                # One weight per sample, broadcast over the remaining axes.
                # The leading dimension is the fixed batch size.
                alpha = K.random_uniform((self.batch_size, 1, 1, 1))
                return (alpha * inputs[0]) + ((1 - alpha) * inputs[1])

        self.input_shape = input_shape
        self.output_shape = output_shape
        if hasattr(output_shape, "__len__"): output_shape = np.prod(output_shape)

        self.sequential = self.build_sequential(input_shape, output_shape, **kwargs)

        x = layers.Input(shape=input_shape)
        G_z = layers.Input(shape=input_shape)
        D_x = self.sequential(x)
        D_G_z = self.sequential(G_z)

        interpolated_img = RandomWeightedAverage(self.batch_size)([x, G_z])
        validity_interpolated = self.sequential(interpolated_img)
        self.partial_gp_loss = partial(self.gradient_penalty_loss,
                                       averaged_samples=interpolated_img)
        self.partial_gp_loss.__name__ = 'gradient_penalty' # Keras requires function names

        self.model = models.Model(inputs=[x, G_z],
                                  outputs=[D_x, D_G_z, validity_interpolated])

        self.model.summary()
        return self.model

    def loss_fn(self, y_true, y_pred):
        """Mean of ``y_true * y_pred``; the sign of ``y_true`` selects the direction."""
        return K.mean(y_true * y_pred)

    def gradient_penalty_loss(self, y_true, y_pred, averaged_samples):
        """
        Computes gradient penalty based on prediction and weighted real / fake samples

        ``y_true`` is ignored. Returns the batch mean of ``(1 - ||grad||_2)^2``
        where the gradient of ``y_pred`` is taken w.r.t. ``averaged_samples``.
        No penalty coefficient is applied here.
        """
        gradients = K.gradients(y_pred, averaged_samples)[0]
        # compute the euclidean norm by squaring ...
        gradients_sqr = K.square(gradients)
        #   ... summing over every axis except the batch axis ...
        gradients_sqr_sum = K.sum(gradients_sqr,
                                  axis=np.arange(1, len(gradients_sqr.shape)))
        #   ... and sqrt
        gradient_l2_norm = K.sqrt(gradients_sqr_sum)
        # compute (1 - ||grad||)^2 still for each single sample
        gradient_penalty = K.square(1 - gradient_l2_norm)
        # return the mean as loss over all the batch samples
        return K.mean(gradient_penalty)

    def compile(self, loss, **kwargs):
        """Compile with ``loss`` extended by the gradient-penalty loss.

        ``loss`` is the sequence of losses for the first two outputs. It is
        copied rather than appended to in place, so the caller's list is
        left untouched and repeated calls do not accumulate penalties.
        """
        losses = list(loss) + [self.partial_gp_loss]
        super(ConvNet_Critic, self).compile(loss=losses, **kwargs)

#### VCapsNet

The VCapsNet implementation is inspired by https://github.com/XifengGuo/CapsNet-Keras

In [None]:
class ParentCaps(layers.Layer):
    """Capsule layer that routes input capsules to ``n_caps`` output capsules.

    Args:
        n_caps: number of output capsules.
        dim_caps: dimension of each output capsule vector.
        routing_iters: number of routing iterations (must be >= 1).
        kernel_initializer: initializer for the transform weights ``W``.
    """

    def __init__(self, n_caps, dim_caps, routing_iters,
                 kernel_initializer='glorot_uniform',
                 **kwargs):
        super(ParentCaps, self).__init__(**kwargs)
        if routing_iters < 1:
            # With zero iterations `routing` would never assign its result.
            raise ValueError("routing_iters must be >= 1, got %r" % (routing_iters,))
        self.n_caps = n_caps
        self.dim = dim_caps
        self.routing_iters = routing_iters
        self.kernel_initializer = initializers.get(kernel_initializer)

    def build(self, input_shape):
        assert len(input_shape) >= 3, "The input Tensor should have shape=[None, n_caps_in, dim_im]"
        self.n_caps_in = input_shape[1]
        self.dim_im = input_shape[2]

        # Transform matrix: one [dim, dim_im] matrix per (output, input) capsule pair
        self.W = self.add_weight(shape=[self.n_caps, self.n_caps_in,
                                        self.dim, self.dim_im],
                                 initializer=self.kernel_initializer,
                                 name='W')

        self.built = True

    def call(self, inputs, training=None):
        """Transform the input capsules, then route them to the output capsules."""
        inputs_hat = self.spatial_transform(inputs)
        outputs = self.routing(inputs_hat)
        return outputs

    def spatial_transform(self, inputs):
        """Replicate the inputs once per output capsule and apply ``W`` per sample."""
        # [None, n_caps_in, dim_im] -> [None, 1, n_caps_in, dim_im]
        inputs_expand = K.expand_dims(inputs, 1)
        # -> [None, n_caps, n_caps_in, dim_im]
        inputs_tiled = K.tile(inputs_expand, [1, self.n_caps, 1, 1])
        # map_fn iterates over the batch axis; each element is multiplied with W
        inputs_hat = K.map_fn(lambda x: K.batch_dot(x, self.W, [2, 3]), elems=inputs_tiled)
        return inputs_hat

    def routing(self, inputs_hat):
        """Iteratively refine the coupling between input and output capsules."""
        # Routing logits, one per (sample, output capsule, input capsule)
        b = tf.zeros(shape=[K.shape(inputs_hat)[0], self.n_caps, self.n_caps_in])
        for i in range(self.routing_iters):
            # Normalise over the output-capsule axis
            c = tf.nn.softmax(b, dim=1)
            outputs = squash(K.batch_dot(c, inputs_hat, [2, 2]))
            # The last iteration only needs the outputs, not updated logits
            if i < self.routing_iters - 1:
                b += K.batch_dot(outputs, inputs_hat, [2, 3])
        return outputs

    def compute_output_shape(self, input_shape):
        return tuple([None, self.n_caps, self.dim])

    def get_config(self):
        """Return a config whose keys match the ``__init__`` arguments.

        The keys must be the constructor's keyword names, otherwise
        ``from_config`` (used when reloading a saved model) cannot rebuild
        the layer.
        """
        config = {
            'n_caps': self.n_caps,
            'dim_caps': self.dim,
            'routing_iters': self.routing_iters,
            'kernel_initializer': initializers.serialize(self.kernel_initializer),
        }
        base_config = super(ParentCaps, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

class Mask(layers.Layer):
    """Zero out all capsules but one and flatten the result.

    Called with a two-element list ``[capsules, mask]`` the given mask is
    applied. Called with a single tensor, the mask is a one-hot vector
    selecting the capsule with the largest length.
    """

    def call(self, inputs, **kwargs):
        if type(inputs) is list:
            # Mask supplied by the caller
            assert len(inputs) == 2
            capsules, mask = inputs
        else:
            # Derive the mask from the capsule lengths
            capsules = inputs
            lengths = K.sqrt(K.sum(K.square(capsules), -1))
            winner = K.argmax(lengths, 1)
            mask = K.one_hot(indices=winner, num_classes=lengths.get_shape().as_list()[1])
        return K.batch_flatten(capsules * K.expand_dims(mask, -1))

    def compute_output_shape(self, input_shape):
        # A tuple as first element means a list of shapes was passed (mask provided).
        caps_shape = input_shape[0] if type(input_shape[0]) is tuple else input_shape
        return tuple([None, caps_shape[1] * caps_shape[2]])

    def get_config(self):
        return super(Mask, self).get_config()

class Length(layers.Layer):
    """Replace each vector along the last axis by its Euclidean length."""

    def call(self, inputs, **kwargs):
        squared_norm = K.sum(K.square(inputs), -1)
        return K.sqrt(squared_norm)

    def compute_output_shape(self, input_shape):
        # The last axis is reduced away.
        return input_shape[:-1]

In [None]:
class VCapsNet(Discriminator):
    """Vector-capsule discriminator with an optional reconstruction decoder."""

    def __init__ (self, name='VCapsNet', **kwargs):
        super(VCapsNet, self).__init__(name=name, **kwargs)

    def build_sequential(self, input_shape, output_shape, L1_n, L2_n, L2_dim, L3_dim, routing=3, decoder=False, L4_n=512, L5_n=1024): #L3_n is the same as output_shape
        """Build the capsule network as a Keras ``Model``.

        Args:
            input_shape: shape of one input image.
            output_shape: number of output capsules (used as an integer).
            L1_n: filters of the first convolution.
            L2_n, L2_dim: channels and capsule dimension of the primary capsules.
            L3_dim: dimension of the output capsules.
            routing: number of routing iterations.
            decoder: when true, add a label input and a dense decoder that
                reconstructs the input from the label-masked capsules. The
                flag is also stored on ``self.decoder``.
            L4_n, L5_n: widths of the two hidden decoder layers.

        Returns:
            ``Model(imgs, pred)`` without decoder, otherwise
            ``Model([imgs, labels], [pred, reconstruction])``.
        """
        self.decoder = decoder
        # Input
        imgs = layers.Input(shape=input_shape)
        # Conv layer
        x = layers.Conv2D(filters=L1_n, kernel_size=9, strides=1, kernel_initializer=initializers.RandomNormal(0,0.02), padding='valid', name='conv1')(imgs)
        x = layers.BatchNormalization(momentum=0.9, epsilon=1.01e-5, gamma_initializer=initializers.RandomNormal(1,0.02))(x)
        x = layers.LeakyReLU(alpha=0.2)(x)
        # [None, num_capsule, dim_capsule]
        x = self.PrimaryCaps(x, dim_caps=L2_dim, n_channels=L2_n, kernel_size=9, strides=2, padding='valid')
        # routing algo
        digitcaps = ParentCaps(n_caps=output_shape, dim_caps=L3_dim, routing_iters=routing, kernel_initializer=initializers.RandomNormal(0,0.02), name="ParentCaps_")(x)
        # This is an auxiliary layer to replace each capsule with its length to match the true label's shape.
        pred = Length(name="capsnet")(digitcaps)
        if not decoder:
            return models.Model(imgs, pred)

        # Decoder network
        labels = layers.Input(shape=(output_shape,))
        # The true label is used to mask the output of capsule layer. For training
        masked_by_labels = Mask()([digitcaps, labels])
        # Mask using the capsule with maximal length. For prediction
        masked = Mask()(digitcaps)
        decoder_net = models.Sequential(name='decoder')
        decoder_net.add(layers.Dense(L4_n, activation='relu', input_dim=output_shape*L3_dim))
        decoder_net.add(layers.Dense(L5_n, activation='relu'))
        # The reconstruction is sized from this method's own input_shape
        # argument rather than a notebook-level global.
        decoder_net.add(layers.Dense(int(np.prod(input_shape)), activation='tanh', name="Disc_L5_Decoder_FC3"))
        decoder_net.add(layers.Reshape(target_shape=input_shape, name='decoder'))
        # Model for training; an evaluation model would be
        # models.Model(imgs, [pred, decoder_net(masked)])
        return models.Model([imgs, labels], [pred, decoder_net(masked_by_labels)])

    def PrimaryCaps(self, inputs, dim_caps, n_channels, kernel_size, strides, padding):
        """Convolve, regroup the channels into ``dim_caps``-vectors, normalise and squash."""
        output = layers.Conv2D(filters=dim_caps*n_channels, kernel_size=kernel_size, strides=strides, padding=padding, kernel_initializer=initializers.RandomNormal(0,0.02))(inputs)
        output = layers.Reshape(target_shape=[-1, dim_caps])(output)
        output = layers.BatchNormalization(momentum=0.9, epsilon=1.01e-5, gamma_initializer=initializers.RandomNormal(1,0.02))(output)
        output = layers.Lambda(squash)(output)
        return output

    def loss_fn(self, y_true, y_pred):
        """Margin loss: penalise positives below 0.9 and (half-weighted) negatives above 0.1."""
        loss = y_true * K.square(K.maximum(0., 0.9 - y_pred)) +  0.5 * (1 - y_true) * K.square(K.maximum(0., y_pred - 0.1))
        return K.mean(K.sum(loss, 1))

    def compile(self, optimizer=None, **kwargs):
        """Compile through the parent class.

        When ``optimizer`` is given it is forwarded unchanged together with
        ``kwargs``. When it is omitted, the margin loss ``self.loss_fn`` is
        supplied as the default ``loss`` (unless the caller passed one) and
        the optimizer choice is left to the parent. Previously the loss
        function was passed *as the optimizer* in that case.

        NOTE(review): assumes ``Discriminator.compile`` accepts a ``loss``
        keyword and can be called without ``optimizer`` - confirm against
        the base class.
        """
        if optimizer is None:
            kwargs.setdefault('loss', self.loss_fn)
            super(VCapsNet, self).compile(**kwargs)
        else:
            super(VCapsNet, self).compile(optimizer=optimizer, **kwargs)

#### Matrix CapsNet

In [None]:
class CapsLayer(layers.Layer):
    """Matrix-capsule layer with EM routing.

    With ``kernel_size`` given, the layer works as a convolutional capsule
    layer and outputs poses concatenated with activations. Without it, the
    layer works as the class-capsule layer and outputs activations only.

    ``call`` reads its input as ``[batch, s0, s1, n_caps_in, p0*p1 + 1]``:
    the leading ``p0*p1`` channels are the flattened pose, the last channel
    is the activation.

    Args:
        n_caps_out: number of output capsule types.
        pose_shape: ``[p0, p1]`` shape of the output pose matrices.
        kernel_size: ``[k0, k1]`` for convolutional capsules, ``None`` for
            class capsules. The patch extraction uses ``kernel_size[0]`` for
            both axes.
        strides: stride of the patch extraction (convolutional capsules only).
        routing_iters: number of EM iterations.
    """

    def __init__(self, n_caps_out, pose_shape, kernel_size=None, strides=None, routing_iters=3, trainable=True, name=None, **kwargs):
        self.n_caps_out = n_caps_out
        self.pose_shape = pose_shape
        self.routing_iters = routing_iters
        self.strides = strides
        if kernel_size: # in case conv capsules
            self.kernel_size= kernel_size
            self.spatial_dim = [1, 1]
        else: # in case class capsules
            self.kernel_size = [1, 1]
            self.spatial_dim = []
        super(CapsLayer, self).__init__(trainable=trainable, name=name, **kwargs)

    def build(self, input_shape):
        # The input pose is taken to be square: side = sqrt(last_dim - 1),
        # the remaining channel being the activation.
        self.pose_shape_in = [int(np.sqrt(input_shape[-1]-1)),int(np.sqrt(input_shape[-1]-1))]
        self.n_caps_in = input_shape[-2]
        self.spatial_size_in=[int(input_shape[1]), int(input_shape[2])]
        self.spatial_size=self.spatial_size_in
        # beta_v: SHAPE=[1, (1, 1,) 1, O, 1], TYPE=tensor, VALUE= trainable parameter (vector of dim: # capsules in layer L+1)
        self.beta_v = self.add_weight(shape=[1,] + self.spatial_dim +[1, self.n_caps_out, 1],
                                    initializer=initializers.glorot_normal(),
                                    name='beta_v')
        # beta_a: SHAPE=[1, (1, 1,) 1, O, 1], TYPE=tensor, VALUE= trainable parameter (vector of dim: # capsules in layer L+1)
        self.beta_a = self.add_weight(shape=[1,] + self.spatial_dim +[1, self.n_caps_out, 1],
                                    initializer=initializers.glorot_normal(),
                                    name='beta_a')
        # W_ij: SHAPE=[1, 1, 1, (k0*k1*)I, O, p0, p1]
        # (author's note: maybe replace the leading 1 by the batch size)
        self.W_ij = self.add_weight(shape=[1, 1, 1, self.kernel_size[0]*self.kernel_size[1]*self.n_caps_in, self.n_caps_out, self.pose_shape[0], self.pose_shape[1]],
                                    initializer=initializers.RandomNormal(mean=0.0, stddev=0.05),
                                  name='W_ij')
        self.built = True
        super(CapsLayer, self).build(input_shape)

    def call(self, inputs):
        """Split poses/activations, compute the votes and run EM routing."""
        self.batch_size = K.shape(inputs)[0]
        ################ inputs ################
        # M_i:SHAPE=[b, s0, s1, I, p0*p1], TYPE= tensor, VALUE= pose matrix
        # a_i:SHAPE=[b, s0, s1, I], TYPE= tensor, VALUE= activations
        # The split point follows the input pose size instead of a literal 16.
        n_pose_in = self.pose_shape_in[0]*self.pose_shape_in[1]
        M_i = inputs[:,:,:,:,:n_pose_in]
        a_i = inputs[:,:,:,:,n_pose_in]

        M_i= K.reshape(M_i, shape=[-1, self.spatial_size[0], self.spatial_size[1], self.n_caps_in, self.pose_shape_in[0],self.pose_shape_in[1]])
        ################ depthwise conv ################
        # M_i:SHAPE=[b, s0', s1', (k0*k1*)I, p0*p1], TYPE= tensor, VALUE= pose matrix
        # a_i:SHAPE=[b, s0', s1', (k0*k1*)I], TYPE= tensor, VALUE= activations
        if len(self.spatial_dim): M_i, a_i = self.depthwise_conv(M_i, a_i)
        ################ spatial transformation ################
        # V_ij: SHAPE=[b, s0', s1', (k0*k1*)I, O, p0, p1], TYPE= tensor, VALUE= learn the spatial transformations of the features
        V_ij = self.spatial_transform(M_i)
        ################ coordinate addition ################
        # V_ij: SHAPE= [b, s0', s1', I, O, p0, p1], TYPE= tensor, VALUE= new vote matrix with values addition along an axis
        # Only applied for class capsules.
        if not len(self.spatial_dim): V_ij = self.coord_addition(V_ij)
        ################ EM routing ################
        # M_j: SHAPE=[b, s0', s1', O, p0, p0], TYPE= tensor, VALUE= pose matrix of the new capsules' layer
        # a_j: SHAPE=[b, s0', s1', O], TYPE= tensor, VALUE= activations of the new capsules' layer
        M_j, a_j = self.em_routing(V_ij, a_i)

        if len(self.spatial_dim):
            # Conv capsules: emit flattened poses with the activation appended as last channel.
            M_j = K.reshape(M_j, [-1, self.spatial_size[0], self.spatial_size[1], self.n_caps_out, self.pose_shape[0]*self.pose_shape[1]])
            a_j = K.expand_dims(a_j, -1)
            a_j = layers.Activation('sigmoid')(a_j)
            net = layers.Concatenate()([M_j,a_j])
        else:
            # Class capsules: emit activations only.
            net = layers.Activation('sigmoid')(a_j)
        return net

    def depthwise_conv(self, M_i, a_i):
        """Gather the k x k neighbourhood of every position for poses and activations.

        Also updates ``self.spatial_size`` to the spatial size after the
        (VALID-padded, strided) patch extraction.
        """
        def depthwise_operation (tensor, kernel, stride):
            # e.g. (?, 14, 14, 32x(16)=512)
            tensor_shape = tensor.get_shape()
            size = tensor_shape[4]*tensor_shape[5] if len(tensor_shape)>5 else 1
            tensor = tf.reshape(tensor, shape=[-1, tensor_shape[1], tensor_shape[2], tensor_shape[3]*size])
            # One-hot filters: output channel i*kernel+j copies the value at offset (i, j).
            tile_filter = np.zeros(shape=[kernel, kernel, tensor_shape[3],
                                          kernel * kernel], dtype=np.float32)
            for i in range(kernel):
                for j in range(kernel):
                    tile_filter[i, j, :, i * kernel + j] = 1.0 # e.g. (3, 3, 512, 9)
            tile_filter_op = tf.constant(tile_filter, dtype=tf.float32)
            # e.g. (?, 6, 6, 4608)
            output = tf.nn.depthwise_conv2d(tensor, tile_filter_op, strides=[
                                            1, stride, stride, 1], padding='VALID')
            output_shape = output.get_shape()
            output = tf.reshape(output, shape=[-1, output_shape[1], output_shape[2], tensor_shape[3], kernel * kernel])
            output = tf.transpose(output, perm=[0, 1, 2, 4, 3])
            return output
        # M_i: SHAPE=[b, s1, s2, I*p1*p2], TYPE= tensor, VALUE= prepare the tensor for a depthconv
        M_i= K.reshape(M_i, shape=[self.batch_size, self.spatial_size_in[0], self.spatial_size_in[1], self.n_caps_in*self.pose_shape_in[0]*self.pose_shape_in[1]])
        # M_i: SHAPE=[b, s1, s2, k1*k2, I*p1*p2], TYPE= tensor, VALUE= tiled pose matrix to be multiplied by the transformation matrices to generate the votes
        M_i = depthwise_operation(M_i, kernel=self.kernel_size[0], stride=self.strides)
        # spatial_size: TYPE= [int, int], VALUE= new spatial size of the capsule (after the convolution)
        self.spatial_size = [int(M_i.shape[1]), int(M_i.shape[2])]
        # M_i: SHAPE=[b, s0', s1', k1*k2*I, p1, p2], TYPE= tensor, VALUE= reshape the pose matrix back to its standard shape
        M_i= K.reshape(M_i, shape=[self.batch_size, self.spatial_size[0], self.spatial_size[1], self.kernel_size[0]*self.kernel_size[1]*self.n_caps_in, self.pose_shape_in[0],self.pose_shape_in[1]])
        # a_i: SHAPE=[b, s1', s2', k1*k2, I], TYPE= tensor, VALUE= tiled activations
        a_i = depthwise_operation(a_i, kernel=self.kernel_size[0], stride=self.strides)
        # a_i: SHAPE=[b, s1', s2', k1*k2*I], TYPE= tensor, VALUE= reshape the activation back to its standard shape
        a_i= K.reshape(a_i, shape=[self.batch_size, self.spatial_size[0], self.spatial_size[1], self.kernel_size[0]*self.kernel_size[1]*self.n_caps_in])
        return M_i, a_i

    def spatial_transform(self, M_i):
        """Multiply every input pose with the transformation matrices to obtain the votes."""
        # M_i: SHAPE=[b, s0', s1', (k0*k1*)I, 1, p0, p1], TYPE= tensor, VALUE= add an axis for the output capsules
        M_i = K.expand_dims(M_i, -3)
        # M_i: SHAPE=[b, s0', s1', (k0*k1*)I, O, p0, p1], TYPE= tensor, VALUE= tile along the output-capsule axis
        M_i = K.tile(M_i, [1, 1, 1, 1, self.n_caps_out, 1, 1])

        # W_ij: SHAPE=[b, s0', s1', (k0*k1*)I, O, p0, p1], VALUE= transformation matrices tiled over batch and positions
        W_ij= K.tile(self.W_ij, [self.batch_size, self.spatial_size[0], self.spatial_size[1], 1, 1, 1, 1])

        # V_ij: SHAPE=[b, s0', s1', (k0*k1*)I, O, p0, p1], TYPE= tensor, VALUE= vote matrices
        V_ij = K.batch_dot(M_i, W_ij)
        return V_ij

    def coord_addition(self, V_ij):
        """
        From the paper: "We therefore share the transformation matrices between different positions of the same capsule type and
        add the scaled coordinate (row, column) of the center of the receptive field of each capsule to the first
        two elements of the right-hand column of its vote matrix."

        Here the scaled row coordinate is added to flattened pose element 0
        and the scaled column coordinate to flattened pose element 1.
        """
        # V_ij: SHAPE=[b, s0', s1', k0*k1*I, O, p0*p1], TYPE= tensor, VALUE= adapt the shape for computation
        V_ij = K.reshape(V_ij, shape=[self.batch_size, self.spatial_size[0], self.spatial_size[1], self.kernel_size[0]*self.kernel_size[1]*self.n_caps_in, self.n_caps_out, self.pose_shape_in[0]*self.pose_shape_in[1]])
        # H_values: SHAPE=[1, s0', 1, 1, 1], TYPE= tensor, VALUE= scaled row coordinates
        H_values = K.reshape((tf.range(self.spatial_size[0], dtype=tf.float32) + 0.50) / self.spatial_size[0], [1, self.spatial_size[0], 1, 1, 1])
        # H_zeros: SHAPE=[1, s0', 1, 1, 1], TYPE= tensor, VALUE= zeros for the untouched pose elements
        H_zeros = tf.constant(0.0, shape=[1, self.spatial_size[0], 1, 1, 1], dtype=tf.float32)
        # H_offset: SHAPE=[1, s0', 1, 1, 1, p0*p1], TYPE= tensor, VALUE= row offset placed in pose element 0
        H_offset = tf.stack([H_values, H_zeros] + [H_zeros for _ in range(self.pose_shape_in[0]*self.pose_shape_in[1]-2)], axis=-1)
        # W_values: SHAPE=[1, 1, s1', 1, 1], TYPE= tensor, VALUE= scaled column coordinates
        W_values = tf.reshape((tf.range(self.spatial_size[1], dtype=tf.float32) + 0.50) / self.spatial_size[1], [1, 1, self.spatial_size[1], 1, 1])
        # W_zeros: SHAPE=[1, 1, s1', 1, 1], TYPE= tensor, VALUE= zeros for the untouched pose elements
        W_zeros = tf.constant(0.0, shape=[1, 1, self.spatial_size[1], 1, 1], dtype=tf.float32)
        # W_offset: SHAPE=[1, 1, s1', 1, 1, p0*p1], TYPE= tensor, VALUE= column offset placed in pose element 1
        W_offset = tf.stack([W_zeros, W_values] + [W_zeros for _ in range(self.pose_shape_in[0]*self.pose_shape_in[1]-2)], axis=-1)
        # V_ij: SHAPE=[b, s0', s1', I, O, p0*p1], TYPE= tensor, VALUE= V_ij in the new coordinates
        V_ij = V_ij + H_offset + W_offset
        # V_ij: SHAPE=[b, s0', s1', I, O, p0, p1], TYPE= tensor, VALUE= reshape back to the standard form
        V_ij = K.reshape(V_ij, shape=[self.batch_size, self.spatial_size[0], self.spatial_size[1], self.kernel_size[0]*self.kernel_size[1]*self.n_caps_in, self.n_caps_out, self.pose_shape_in[0], self.pose_shape_in[1]])
        return V_ij

    def em_routing(self, V_ij, a_i):
        """Run ``routing_iters`` EM iterations and return ``(M_j, a_j)``.

        Each iteration performs a maximisation step; all but the last are
        followed by an estimation step that updates the assignments.
        """
        def maximization(R_ij, V_ij, a_i, inv_temp):
            # R_ij: SHAPE=[b, s0', s1', k0*k1*I, O, 1] or [b, s0'*s1'*I, O, 1], TYPE= tensor, VALUE= assignments weighted by the input activations
            R_ij = R_ij * a_i
            # R_ij_sum: SHAPE=[b, (s0', s1',) 1, O, 1] , TYPE= tensor, VALUE= sum over all input capsules i
            R_ij_sum = K.sum(R_ij, axis=-3, keepdims=True)
            # M_j: SHAPE=[b, (s0', s1',) 1, O, p0*p1], TYPE= tensor, VALUE= mean of capsule j
            M_j = K.sum(R_ij * V_ij, axis=-3, keepdims=True ) / R_ij_sum
            # stdv_j: SHAPE=[b, (s0', s1',) 1, O, p0*p1], TYPE= tensor, VALUE= standard deviation of capsule j
            # NOTE(review): the squared deviations are multiplied by R_ij_sum (which then
            # cancels against the divisor) rather than by the per-input R_ij; the EM-routing
            # paper uses R_ij here. Left unchanged so existing weights keep their meaning -
            # confirm intent.
            stdv_j = K.sqrt(K.sum(R_ij_sum * tf.square(V_ij - M_j), axis=-3, keepdims=True) / R_ij_sum)
            # cost_j_h: SHAPE=[b, (s0', s1',) 1, O, p0*p1], TYPE= tensor, VALUE= expected energy of a capsule j
            cost_j_h = (self.beta_v + K.log(stdv_j + K.epsilon())) * R_ij_sum
            # cost_j: SHAPE=[b, (s0', s1',) 1, O, 1], TYPE= tensor, VALUE= expected energy
            cost_j = K.sum(cost_j_h, axis=-1, keepdims=True)
            # cost_j_mean: SHAPE=[b, (s0', s1',) 1, 1, 1], TYPE= tensor, VALUE= mean of the expected energy over the output capsules
            cost_j_mean = K.mean(cost_j, axis=-2, keepdims=True)
            # cost_j_stdv: SHAPE=[b, (s0', s1',) 1, 1, 1], TYPE= tensor, VALUE= standard deviation of the expected energy over the output capsules
            cost_j_stdv = K.sqrt(K.sum(K.square(cost_j - cost_j_mean), axis=-2, keepdims=True) / self.n_caps_out)
            # a_j_cost: SHAPE=[b, (s0', s1',) 1, O, 1], TYPE= tensor, VALUE= cost of the activation of capsule j
            a_j_cost = self.beta_a + (cost_j_mean - cost_j) / (cost_j_stdv + K.epsilon())
            # a_j: SHAPE=[b, (s0', s1',) 1, O, 1], TYPE= tensor, VALUE= activation of capsule j
            a_j = tf.sigmoid(inv_temp * a_j_cost)
            # a_j: SHAPE=[b, (s0', s1',) O], TYPE= tensor, VALUE= squeezed activation of capsule j
            a_j = K.squeeze(K.squeeze(a_j, axis=-3), axis=-1)
            # M_j: SHAPE=[b, (s0', s1',) O, p0*p1], TYPE= tensor, VALUE= squeezed mean of capsule j (pose matrix)
            M_j = K.squeeze(M_j, axis=-3)
            # stdv_j: SHAPE=[b, (s0', s1',) O, p0*p1], TYPE= tensor, VALUE= squeezed standard deviation of capsule j
            stdv_j = K.squeeze(stdv_j, axis=-3)
            return M_j, stdv_j, a_j

        def estimation(M_j, stdv_j, V_ij, a_j):
            # M_j: SHAPE=[b, (s0', s1',) 1, O, p0*p1], TYPE= tensor, VALUE= mean of capsule j with the input axis restored
            M_j = K.expand_dims(M_j, -3)
            # a_j: SHAPE=[b, (s0', s1',) 1, O, 1], TYPE= tensor, VALUE= activation of capsule j with input and pose axes restored
            a_j = K.expand_dims(K.expand_dims(a_j, -2), -1)
            # stdv_j: SHAPE=[b, (s0', s1',) 1, O, p0*p1], TYPE= tensor, VALUE= standard deviation with the input axis restored
            stdv_j = K.expand_dims(stdv_j, -3)
            # a_j_p_j: SHAPE= [b, s0', s1', k0*k1*I, O, 1] or [b, s0'*s1'*I, O, 1], TYPE= tensor, VALUE= log of activation times Gaussian likelihood
            a_j_p_j  = K.log(a_j + K.epsilon()) - K.sum(K.square(V_ij - M_j) /(2 * tf.square(stdv_j)), axis=-1, keepdims=True) - K.sum(tf.log(stdv_j + K.epsilon()), axis=-1, keepdims=True)
            # R_ij: same shape, TYPE= tensor, VALUE= assignments normalised over the output-capsule axis
            R_ij = tf.nn.softmax(a_j_p_j, dim=len(a_j_p_j.get_shape().as_list())-2)
            return R_ij

        if len(self.spatial_dim):
            # V_ij: SHAPE=[b, s0', s1', k0*k1*I, O, p0*p1], TYPE= tensor, VALUE= adapt the shape for computation
            V_ij = K.reshape(V_ij, shape=[self.batch_size, self.spatial_size[0], self.spatial_size[1], self.kernel_size[0]*self.kernel_size[1]*self.n_caps_in, self.n_caps_out, self.pose_shape[0]*self.pose_shape[1]])
            # a_i: SHAPE=[b, s0', s1', k0*k1*I, 1, 1], TYPE= tensor, VALUE= expanded i activations
            a_i = K.expand_dims(K.expand_dims(a_i,-1),-1)
            # R_ij: SHAPE=[k0*k1*I, O, 1], TYPE= tensor, VALUE= uniform initial assignment of each input capsule (i) to each output capsule (j)
            # The leading size is k0*k1*I so that it matches V_ij and a_i above.
            R_ij = K.constant(1.0/self.n_caps_out, shape=(self.kernel_size[0]*self.kernel_size[1]*self.n_caps_in, self.n_caps_out, 1))
        else:
            # V_ij: SHAPE=[b, s0'*s1'*I, O, p0*p1], TYPE= tensor, VALUE= adapt the shape for computation
            V_ij = K.reshape(V_ij, shape=[self.batch_size, self.spatial_size[0]*self.spatial_size[1]*self.n_caps_in, self.n_caps_out, self.pose_shape[0]*self.pose_shape[1]])
            # a_i: SHAPE=[b, s0'*s1'*I, 1, 1], TYPE= tensor, VALUE= reshape to standard form
            a_i = K.reshape(a_i, shape=[self.batch_size, self.spatial_size[0]*self.spatial_size[1]*self.n_caps_in, 1, 1])
            # R_ij: SHAPE=[s0'*s1'*I, O, 1], TYPE= tensor, VALUE= uniform initial assignment of each input capsule (i) to each output capsule (j)
            # The leading size is s0'*s1'*I so that it matches V_ij and a_i above.
            R_ij = K.constant(1.0/self.n_caps_out, shape=(self.spatial_size[0]*self.spatial_size[1]*self.n_caps_in, self.n_caps_out, 1))
        for step in range(self.routing_iters):
            # inv_temp: TYPE= float, VALUE= inverse temperature, rising linearly from 1 to min(routing_iters, 3)
            inv_temp = 1.0 + (min(self.routing_iters, 3.0) - 1.0) * step / max(1.0, self.routing_iters - 1.0)
            # M_j: SHAPE=[b, (s0', s1',) O, p0*p1], TYPE= tensor, VALUE= mean of capsule j
            # stdv_j: SHAPE=[b, (s0', s1',) O, p0*p1], TYPE= tensor, VALUE= standard deviation of capsule j
            # a_j: SHAPE= [b, (s0', s1',) O], TYPE= tensor, VALUE= activation of capsule j
            M_j, stdv_j, a_j = maximization(R_ij, V_ij, a_i, inv_temp=inv_temp)
            # R_ij: SHAPE= [b, (s0', s1',) k0*k1*I, O, 1], TYPE= tensor, VALUE= updated assignments
            if step < self.routing_iters - 1:
                R_ij = estimation(M_j, stdv_j, V_ij, a_j)
        # M_j: SHAPE=[b, (s0', s1',) O, p0, p1], TYPE= tensor, VALUE= reshape back to the standard form
        M_j = K.reshape(M_j, shape=[self.batch_size, self.spatial_size[0], self.spatial_size[1], self.n_caps_out, self.pose_shape[0], self.pose_shape[1]])
        return M_j, a_j

    def compute_output_shape(self, input_shape):
        # Uses self.spatial_size, which depthwise_conv updates for conv capsules.
        if len(self.spatial_dim):
            # [b, s0', s1', O, p0*p1 + 1]: flattened pose plus one activation channel
            output_sh=[input_shape[0], self.spatial_size[0], self.spatial_size[1], self.n_caps_out, self.pose_shape[0]*self.pose_shape[1]+1]
        else:
            output_sh=[input_shape[0], self.n_caps_out]
        return tuple(output_sh)

In [None]:
class MCapsNet(Discriminator):
    """Capsule-network discriminator built from ``PrimaryCaps`` and ``CapsLayer``.

    ``batch_size`` is kept on the instance because ``loss_fn`` reshapes its
    masked predictions with it.
    """
    def __init__ (self, batch_size, name='MCapsNet', **kwargs):
        # Set before delegating so the attribute exists if the base
        # constructor triggers anything that reads it.
        self.batch_size = batch_size
        super(MCapsNet, self).__init__(name=name, **kwargs)
        
    def build_sequential(self, input_shape, output_shape, L1_n, L2_n, L3_n, L4_n, pose_shape=[4,4], routing=3, decoder=False): #L3_n is the same as output_shape     
        """Assemble the capsule network and return it as a Keras ``Model``.

        Args:
            input_shape: shape of one input sample (without the batch axis).
            output_shape: target shape; ``np.prod(output_shape)`` class capsules
                are created. Also stored as ``self.output_shape``.
            L1_n: number of filters of the first convolution.
            L2_n: number of primary capsule types.
            L3_n: number of capsule types in ``ConvCaps1``.
            L4_n: number of capsule types in ``ConvCaps2``.
            pose_shape: ``[p0, p1]`` shape of a pose matrix.
            routing: number of routing iterations passed to every ``CapsLayer``.
            decoder: not used by this method.

        Returns:
            ``models.Model`` mapping the input to the class-capsule layer output.
        """
        self.output_shape = output_shape
        # inputs = img_shape : Input
        inputs = layers.Input(shape=input_shape)
        # net = [b, s0, s1, A] : ReLU Conv1
        net = layers.Conv2D(filters=L1_n, kernel_size=[5,5], strides=2, padding='SAME', activation='relu', name='ReLU_Conv1')(inputs) # add batch normalization ?? 
        # net = [?, s0, s1, B, p0*p1 + 1] : PrimaryCaps (poses concatenated with activations)
        net = self.PrimaryCaps(net, pose_shape=pose_shape, n_caps_out=L2_n, kernel_size=[1,1], strides=1, padding='VALID', name='PrimaryCaps')
        # net : ConvCaps1 (3x3 kernel, stride 2)
        net = CapsLayer(n_caps_out=L3_n, pose_shape=pose_shape,  kernel_size=[3,3], strides=2,routing_iters=routing, name='ConvCaps1')(net)
        # net : ConvCaps2 (3x3 kernel, stride 1)
        net = CapsLayer(n_caps_out=L4_n, pose_shape=pose_shape,  kernel_size=[3,3], strides=1,routing_iters=routing, name='ConvCaps2')(net)
        # net : Class Capsules, one capsule per output class
        net = CapsLayer(n_caps_out=np.prod(output_shape), pose_shape=pose_shape, routing_iters=routing, name='Class_Capsules')(net)
        return models.Model(inputs, net)
    
    def PrimaryCaps(self,inputs, pose_shape, n_caps_out, kernel_size, strides, padding, name):
        """Build the primary capsule layer on top of a convolutional feature map.

        Two parallel convolutions produce the flattened pose matrices and the
        sigmoid activations; both are reshaped to one entry per capsule and
        concatenated along the last axis.

        Args:
            inputs: 4-D feature-map tensor.
            pose_shape: ``[p0, p1]`` shape of a pose matrix.
            n_caps_out: number of capsule types.
            kernel_size, strides, padding: forwarded to both ``Conv2D`` layers.
            name: not used by this method.

        Returns:
            Tensor of shape ``[b, s0, s1, n_caps_out, p0*p1 + 1]``.
        """
        #M = [b, s0, s1, I*p0*p1] : generate the pose matrices of the caps
        M = layers.Conv2D(filters=n_caps_out*pose_shape[0]*pose_shape[1], kernel_size=kernel_size, strides=strides, padding=padding)(inputs)
        #M = [b, s0, s1, I, p0*p1] : one flattened pose matrix per capsule
        pose_dims = M.get_shape().as_list()
        M = layers.Reshape(target_shape=[pose_dims[1], pose_dims[2], n_caps_out, pose_shape[0]*pose_shape[1]])(M)
        #a = [b, s0, s1, I] : generate the activation for the caps
        a = layers.Conv2D(filters=n_caps_out, kernel_size=kernel_size, strides=strides, padding=padding)(inputs)
        a = layers.Activation('sigmoid')(a)
        # Use the spatial size of the activation map itself (not of `inputs`):
        # the two differ whenever the convolution changes the spatial size.
        act_dims = a.get_shape().as_list()
        a = layers.Reshape(target_shape=[act_dims[1], act_dims[2], n_caps_out, 1])(a)
        #net = [b, s0, s1, I, p0*p1 + 1] : poses followed by the activation
        net = layers.Concatenate()([M,a])
        return net
    
    def compile(self, batch_size, n_samples, optimizer=None, **kwargs):
        """Create the margin schedule and delegate to the parent ``compile``.

        Args:
            batch_size: stored as ``self.batch_size``.
            n_samples: number of training samples; with ``batch_size`` it gives
                the iterations per epoch used for the schedule boundaries.
            optimizer: forwarded to the parent; see the review note below for
                the ``None`` default.
            **kwargs: forwarded to the parent ``compile``.
        """
        self.batch_size = batch_size
        iterations_per_epoch = int(n_samples / batch_size)
        # Margin steps through 0.2, 0.3, ..., 0.9 at seven boundaries.
        # NOTE(review): the step tensor is a fresh non-trainable variable
        # initialised to 1 and nothing in this class increments it, so the
        # schedule looks like it never advances -- confirm.
        self.margin = tf.train.piecewise_constant(tf.Variable(1, trainable=False, dtype=tf.int32),
                                                 boundaries=[ int(iterations_per_epoch * 10.0 * x /7) for x in range(1, 8)], 
                                                 values=[x / 10.0 for x in range(2, 10)])
        # NOTE(review): a missing optimizer is replaced by the loss function;
        # this looks like it was meant to default the loss -- confirm against
        # Discriminator.compile before changing.
        if optimizer is None: optimizer = self.loss_fn
        super(MCapsNet, self).compile(optimizer=optimizer, **kwargs)
        
    def fit(self, x, y, batch_size, **kwargs):
        """Train via the parent ``fit`` and return whatever it returns."""
        return super(MCapsNet, self).fit(x=x, y=y, batch_size=batch_size, **kwargs)
    
    def loss_fn(self, y_true, y_pred):
        """Summed squared hinge between the true-class score and every other score.

        Assumes ``y_true`` holds exactly one 1 per row and 0 elsewhere, so that
        the masked tensors can be reshaped with ``self.batch_size``.
        """
        # y_pred_true = [b, 1] : score of the true class
        y_pred_true = K.reshape(tf.boolean_mask(y_pred,tf.equal(y_true, 1)), shape=(self.batch_size, 1))
        # y_pred_false = [b, n_classes-1] : scores of the other classes
        y_pred_false = K.reshape(tf.boolean_mask(y_pred,tf.equal(y_true, 0)), shape=(self.batch_size, np.prod(self.output_shape)-1))
        # loss = scalar : sum of squared margin violations
        loss = K.sum(K.square(K.maximum(0., self.margin - (y_pred_true - y_pred_false))))
        return loss

### GAN

#### Generative model

Training and evaluation engines for the GAN (evaluation includes the Inception Score and the Fréchet Inception Distance).
The `fit_generator` function has a structure similar to that of a Keras training engine.

In [None]:
class GAN:
    def __init__(self, name='GAN', **kwargs):
        self.name = name
        self.model = None
        self.G = None
        self.D = None
        
    def build_compile(self):
        pass
    
    def split_kwargs(self, **kwargs):
        """Split per-model keyword specs into generator and discriminator dicts.

        Every keyword value must be a mapping; its ``'G'`` entry (if any) goes
        to the first returned dict and its ``'D'`` entry (if any) to the second,
        both under the original keyword name.

        Returns:
            ``(kwargs_gen, kwargs_disc)``.
        """
        for_generator = {arg: spec['G'] for arg, spec in kwargs.items() if 'G' in spec.keys()}
        for_discriminator = {arg: spec['D'] for arg, spec in kwargs.items() if 'D' in spec.keys()}
        return for_generator, for_discriminator
        
    def write_log(self,callback, names, logs, batch_no):
        """Write one scalar summary per ``(name, value)`` pair via ``callback.writer``.

        Args:
            callback: object exposing a ``writer`` with ``add_summary`` and ``flush``.
            names: summary tags.
            logs: scalar values, paired with ``names`` by position.
            batch_no: step index attached to every summary.
        """
        for tag, scalar in zip(names, logs):
            record = tf.Summary()
            entry = record.value.add()
            entry.simple_value = scalar
            entry.tag = tag
            callback.writer.add_summary(record, batch_no)
            # Flushed after every entry, so each scalar is written out immediately.
            callback.writer.flush()
            
    def get_eval_scores(self, splits=10, n_samples=None):
        """Compute Inception Score and Fréchet Inception Distance for the generator.

        ``n_samples`` latent vectors are drawn from a standard normal, passed
        through ``self.G.model`` and compared against the first ``n_samples``
        entries of ``self.imgs``. Single-channel images are replicated to three
        channels before being fed to Inception.

        Args:
            splits: number of chunks the generated predictions are divided into
                for the Inception Score mean/std. It is also forwarded as
                ``num_inception_images`` of the FID helper.
                NOTE(review): that second use looks unintended -- confirm.
            n_samples: number of samples to evaluate; defaults to
                ``self.n_samples``. Must be at least 100 (the Inception batch size).

        Returns:
            ``(is_mean, is_std, fid)``.

        Raises:
            ValueError: if fewer than one full Inception batch is available.
        """
        def preprocessing(imgs):
            # Single-channel batches (rank 3, or a trailing axis of size 1) are
            # stacked three times along a new last axis.
            if (len(np.shape(imgs))==3 or np.shape(imgs)[-1]==1):
                imgs = np.squeeze(imgs)
                imgs = np.stack((imgs,)*3, -1)
                imgs = np.rollaxis(imgs, 3, -1)
            return imgs
        
        def inception_logits(images, num_splits=1):
            # Resize to Inception's 299x299 input and run the classifier per split.
            size = 299
            images = tf.image.resize_bilinear(images, [size, size])
            generated_images_list = array_ops.split(
            images, num_or_size_splits=num_splits)
            logits = functional_ops.map_fn(
                fn=functools.partial(tf.contrib.gan.eval.run_inception, output_tensor='logits:0'),
                elems=array_ops.stack(generated_images_list),
                parallel_iterations=1,
                back_prop=False,
                swap_memory=True,
                name='RunClassifier')
            logits = array_ops.concat(array_ops.unstack(logits), 0)
            return logits

        if n_samples is None:
            n_samples = self.n_samples
        G_z = self.G.model.predict(np.random.normal(0, 1, (n_samples,)+self.G.input_shape))
        x = self.imgs[:n_samples]
        imgs_g = preprocessing(G_z)
        imgs_x = preprocessing(x)
        BATCH_SIZE= 100

        def get_inception_probs(inps, inception_imgs, logits):
            # Only full batches are evaluated; a trailing partial batch is dropped.
            n_batches = len(inps)//BATCH_SIZE
            if n_batches == 0:
                raise ValueError('get_eval_scores needs at least %d samples, got %d'
                                 % (BATCH_SIZE, len(inps)))
            preds = []
            for i in range(n_batches):
                inp = inps[i * BATCH_SIZE:(i + 1) * BATCH_SIZE]
                pred = logits.eval({inception_imgs:inp})[:,:1000]
                preds.append(pred)
                
            preds = np.concatenate(preds, 0)
            # Softmax with the row maximum subtracted first so exp() cannot overflow.
            preds = preds - np.max(preds, 1, keepdims=True)
            exp_preds = np.exp(preds)
            return exp_preds/np.sum(exp_preds, 1, keepdims=True)

        def preds2IS(preds_g, splits):
            # exp(mean KL(p(y|x) || p(y))) per split, then mean/std over splits.
            scores = []
            for i in range(splits):
                part = preds_g[(i * preds_g.shape[0] // splits):((i + 1) * preds_g.shape[0] // splits), :]
                kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0)))
                kl = np.mean(np.sum(kl, 1))
                scores.append(np.exp(kl))
            return np.mean(scores), np.std(scores)
        
        def pred2FID(real_images, generated_images, batch_size,
                                   num_inception_images):
            size = 299
            resized_real_images = tf.image.resize_bilinear(real_images, [size, size])
            resized_generated_images = tf.image.resize_bilinear(
              generated_images, [size, size])

            # Compute Frechet Inception Distance.
            num_batches = batch_size // num_inception_images
            tfgan = tf.contrib.gan
            fid = tfgan.eval.frechet_inception_distance(
              resized_real_images, resized_generated_images, num_batches=num_batches)
            # The context manager closes the session even if evaluation fails.
            with tf.Session() as sess:
                sess.run(tf.global_variables_initializer())
                return sess.run(fid)
            
        def get_score(splits):
            # Only the generated images go through the classifier: the Inception
            # Score does not use real-image predictions.
            preds_g = get_inception_probs(imgs_g, inception_images_g, logits_g)
            mean,std = preds2IS(preds_g,splits)
            fid = pred2FID(imgs_x, imgs_g, BATCH_SIZE, splits)
            return mean,std, fid  # Reference values: 11.34 for 49984 CIFAR-10 training set images, or mean=11.31, std=0.08 if in 10 splits (default).

        session = tf.InteractiveSession()
        try:
            # Run images through Inception.
            inception_images_g=tf.placeholder(tf.float32,[BATCH_SIZE,None,None, 3])
            logits_g=inception_logits(inception_images_g)
            score = get_score(splits)
        finally:
            # Always release the interactive session, also on failure.
            session.close()
        return score
    
    def plot_generated_samples(self, imgs, grid=[10,10], imgsize=[10,10], img_range=[0,255], cmap='gray', logdir=None, img_name=None):
        """Draw the first ``grid[0]*grid[1]`` images on a grid and optionally save it.

        Args:
            imgs: batch of images; only the first ``grid[0]*grid[1]`` are used.
            grid: ``[rows, cols]`` of the subplot grid.
            imgsize: figure size passed to ``plt.figure``.
            img_range: ``[low, high]`` value range of ``imgs``; pixels are
                rescaled from it to 0..255 and cast to ``uint8``.
            cmap: colormap for ``imshow``. With ``'gray'`` the trailing channel
                axis is squeezed away first.
            logdir: if given together with ``img_name``, the figure is saved as
                ``<logdir>/generated_images/<img_name>.svg``.
            img_name: file name without extension.
        """
        fig= plt.figure(figsize=(imgsize[0], imgsize[1]))
        n_cells = grid[0]*grid[1]
        imgs= imgs[0:n_cells]
        # Compare by value: identity checks against literals are unreliable.
        if cmap == 'gray': imgs = np.squeeze(imgs, -1)
        if np.shape(imgs)[0] == 1: imgs = np.squeeze(imgs, 0)
        imgs = ((imgs-img_range[0])*255/(img_range[1]-img_range[0])).astype(np.uint8)

        # Show images
        for i in range(n_cells):
            fig.add_subplot(grid[0], grid[1], i+1)
            plt.imshow(imgs[i], cmap=cmap)
            plt.axis('off')
        
        if logdir and img_name:
            out_dir = os.path.join(logdir, 'generated_images')
            # exist_ok avoids the check-then-create race of a separate exists() test.
            os.makedirs(out_dir, exist_ok=True)
            fig.savefig(os.path.join(out_dir, img_name+'.svg'))
            
    def train_generator_gan(self, x, batch_size, noise_shape, shift_fraction=0.):
        """Yield training batches for the stacked GAN and the discriminator, forever.

        Real images come from an ``ImageDataGenerator`` that shifts them by up
        to ``shift_fraction`` in width and height. Each yielded item is
        ``(inputs, targets)`` where both are dicts with keys ``'G'`` and ``'D'``:

        * ``inputs['G']``: fresh noise; ``inputs['D']``: ``[real_batch, generated_batch]``
          (the generated batch comes from a second, independent noise draw);
        * ``targets['G']``: rows ``[1, 0]``; ``targets['D']``: rows ``[1, 0]``
          for the real batch and ``[0, 1]`` for the generated batch.

        Args:
            x: array of real images.
            batch_size: batch size of the underlying flow.
            noise_shape: tuple, shape of one latent sample.
            shift_fraction: maximum width/height shift passed to the augmenter.
        """
        placeholder_labels = np.ones(x.shape[0])  # flow() wants labels; they are discarded
        augmenter = ImageDataGenerator(width_shift_range=shift_fraction,
                                       height_shift_range=shift_fraction)  # shift up to 2 pixel for MNIST
        batches = augmenter.flow(x, placeholder_labels, batch_size=batch_size)
        while True:
            real_batch, _ = batches.next()
            n_rows = np.shape(real_batch)[0]
            real_col = np.ones((n_rows, 1))
            fake_col = np.zeros((n_rows, 1))
            latent_dims = (n_rows,)+noise_shape
            noise = np.random.normal(0, 1, latent_dims)
            generated_batch = self.G.model.predict(np.random.normal(0, 1, latent_dims))
            inputs = {'G': noise, 'D': [real_batch, generated_batch]}
            targets = {'G': np.concatenate((real_col, fake_col), axis=-1),
                       'D': [np.concatenate((real_col, fake_col), axis=-1),
                             np.concatenate((fake_col, real_col), axis=-1)]}
            yield (inputs, targets)

    def test_generator_gan(self, x, noise_shape):
        ones = np.ones((x.shape[0], 1))
        zeros = np.zeros((x.shape[0], 1))
        noise = np.random.normal(0, 1, (x.shape[0],)+ noise_shape)
        G_noise = self.G.model.predict(np.random.normal(0, 1, (x.shape[0],)+ noise_shape))
        return ({'G': noise, 'D': [x, G_noise]}, {'G': np.append(ones, zeros, axis=-1), 'D': [np.append(ones, zeros, axis=-1), np.append(zeros, ones, axis=-1)]})
    
    def fit(self, x, batch_size, epochs, disc_iters=1, gen_iters=1, callbacks=[],  grid=[10,10], imgsize=[10,10], img_range=[-1, 1], y=None, eval_scores=None, load_weights={'G':None, 'D':None}, validation_data=None,
            PlotGenetaredSamples=True, PlotModel = True, TensorBoard=True, debug=False, ModelCheckpoint= True, CSVLogger=True, logdir='./', fixed_latent=False, checkpoint_interval=None,  **kwargs):
        self.imgs = x
        cb = []
        cb += callbacks
        self.grid = grid
        self.eval_scores = eval_scores
        self.disc_iters = disc_iters
        self.gen_iters = gen_iters
        self.n_samples = x.shape[0]
        self.batch_size= batch_size
        self.target_scale = img_range
        self.imgsize = imgsize
        self.logdir = logdir
        self.PlotGenetaredSamples = PlotGenetaredSamples
        self.fixed_latent = fixed_latent
        self.checkpoint_interval = checkpoint_interval
        if CSVLogger: cb.append(cbks.CSVLogger(os.path.join(logdir, 'history.csv')))
        if TensorBoard: cb.append(cbks.TensorBoard(log_dir=os.path.join(logdir, 'tb'),
                                   batch_size=self.batch_size, histogram_freq=debug))
        
        if PlotModel:
            plot_model(self.G.sequential, to_file=os.path.join(logdir, 'gen_'+self.G.name+'.svg'), show_shapes=True)
            plot_model(self.D.sequential, to_file=os.path.join(logdir, 'disc_'+self.D.name+'.svg'), show_shapes=True)
            plot_model(self.model, to_file=os.path.join(logdir, self.name+'.svg'), show_shapes=True)
        if load_weights['G'] : 
            print("loading generator's weights - start")
            self.G.model.load_weights(load_weights['G'])
            print("loading generator's weights - end")
        if load_weights['D'] :
            print("loading discriminator's weights - start")
            self.D.model.load_weights(load_weights['D'])
            print("loading discriminator's weights - end")
        self.fit_generator(generator=self.train_generator_gan(x, batch_size, self.G.input_shape, 0.1),
                           steps_per_epoch=int(x.shape[0] / batch_size),
                           epochs=epochs,
                           validation_data=self.test_generator_gan(x, self.G.input_shape),
                           callbacks= cb, ModelCheckpoint= ModelCheckpoint, logdir=logdir,
                           **kwargs)
    def fit_generator(self,generator,
                      steps_per_epoch=None,
                      epochs=1,
                      verbose=1,
                      callbacks=None,
                      ModelCheckpoint=None,
                      validation_data=None,
                      validation_steps=None,
                      class_weight=None,
                      logdir='./',
                      max_queue_size=10,
                      workers=1,
                      use_multiprocessing=False,
                      shuffle=True,
                      initial_epoch=0):
        """Alternately train the discriminator and the stacked GAN from a generator.

        Modelled on Keras' ``Model.fit_generator``. Every item produced by
        ``generator`` is a tuple ``(x, y)`` or ``(x, y, sample_weight)`` whose
        elements are dicts with a ``'G'`` entry (for ``self.model``) and a
        ``'D'`` entry (for ``self.D.model``).

        One step draws a batch to size the batch logs, then performs
        ``self.disc_iters`` discriminator updates followed by ``self.gen_iters``
        generator updates, each on a freshly drawn batch. At the end of every
        epoch both models are evaluated on the validation arrays (if given),
        IS/FID are computed when ``self.eval_scores`` is truthy, generated
        samples are plotted when ``self.PlotGenetaredSamples`` is truthy, and
        weights are saved when ``ModelCheckpoint`` is truthy.

        The method reads ``eval_scores``, ``disc_iters``, ``gen_iters``,
        ``fixed_latent``, ``checkpoint_interval``, ``PlotGenetaredSamples``,
        ``grid``, ``imgsize``, ``target_scale``, ``logdir`` and ``n_samples``
        from the instance; ``fit`` sets all of them.

        Args:
            generator: Python generator or ``keras.utils.Sequence``.
            steps_per_epoch: steps per epoch; may be ``None`` only for a Sequence.
            epochs: index of the last epoch (exclusive).
            verbose: if truthy, a progress-bar logger is attached.
            callbacks: extra Keras callbacks. They are bound to the stacked GAN
                model and receive its validation data.
            ModelCheckpoint: if truthy, save both models' weights under
                ``<logdir>/models`` after every epoch.
            validation_data: tuple ``(x, y)`` or ``(x, y, sample_weight)`` of
                ``{'G': ..., 'D': ...}`` dicts holding arrays, or ``None`` to
                skip validation.
            validation_steps: accepted for signature parity with Keras; unused
                because validation data is always array based.
            class_weight: forwarded to ``train_on_batch`` of both models.
            logdir: directory for the weight checkpoints.
            max_queue_size, workers, use_multiprocessing, shuffle: forwarded to
                the Keras enqueuer that feeds the training batches.
            initial_epoch: epoch index to start from.

        Returns:
            ``(self.model.history, self.D.model.history)``.

        Raises:
            ValueError: if ``steps_per_epoch`` is ``None`` for a plain
                generator, or if a batch / the validation data is not a 2- or
                3-tuple.
        """
        import warnings  # local import: the notebook header does not import it

        wait_time = 0.01  # in seconds
        epoch = initial_epoch

        do_validation = bool(validation_data)
        if do_validation:
            # Split the per-model dicts; both results are plain lists.
            gan_validation_data = [part['G'] for part in validation_data]
            disc_validation_data = [part['D'] for part in validation_data]
        self.model._make_train_function()
        self.D.model._make_train_function()
        if do_validation:
            self.model._make_test_function()
            self.D.model._make_test_function()

        is_sequence = isinstance(generator, utils.data_utils.Sequence)
        if not is_sequence and use_multiprocessing and workers > 1:
            warnings.warn(
                UserWarning('Using a generator with `use_multiprocessing=True`'
                            ' and multiple workers may duplicate your data.'
                            ' Please consider using the`keras.utils.Sequence'
                            ' class.'))
        if steps_per_epoch is None:
            if is_sequence:
                steps_per_epoch = len(generator)
            else:
                raise ValueError('`steps_per_epoch=None` is only valid for a'
                                 ' generator based on the '
                                 '`keras.utils.Sequence`'
                                 ' class. Please specify `steps_per_epoch` '
                                 'or use the `keras.utils.Sequence` class.')

        # Prepare display labels.
        # NOTE: this renames the models' metrics in place, so the prefixes
        # accumulate if the method is called again on the same models.
        for i in range(len(self.model.metrics_names)):
            self.model.metrics_names[i] = 'cb'+str(i)+'_'+ self.model.metrics_names[i]
        for i in range(len(self.D.model.metrics_names)):
            self.D.model.metrics_names[i] = 'cb'+str(i)+'_'+ self.D.model.metrics_names[i]
        out_labels = ['gen_'+x for x in self.model.metrics_names] + ['disc_'+x for x in self.D.model.metrics_names]

        callback_metrics = out_labels + ['val_' + n for n in out_labels]

        if self.eval_scores:
            callback_metrics += ['IS_mean', 'IS_stdv', 'FID']
        # prepare callbacks
        self.model.history = cbks.History()
        self.D.model.history = cbks.History()
        # Every training metric is reported as-is (stateful), not averaged.
        _callbacks = [cbks.BaseLogger(stateful_metrics=list(out_labels))]
        if verbose:
            _callbacks.append(
                cbks.ProgbarLogger(
                    count_mode='steps',
                    stateful_metrics=list(out_labels)))
        _callbacks += (callbacks or []) + [self.model.history] + [self.D.model.history]
        _callbacks = cbks.CallbackList(_callbacks)

        # it's possible to callback a different model than self:
        if hasattr(self.model, 'callback_model') and self.model.callback_model:
            gan_callback_model = self.model.callback_model
        else:
            gan_callback_model = self.model
        if hasattr(self.D.model, 'callback_model') and self.D.model.callback_model:
            disc_callback_model = self.D.model.callback_model
        else:
            disc_callback_model = self.D.model
        _callbacks.set_model(gan_callback_model)

        _callbacks.set_params({
            'epochs': epochs,
            'steps': steps_per_epoch,
            'verbose': verbose,
            'do_validation': do_validation,
            'metrics': callback_metrics,
        })
        _callbacks.on_train_begin()
        enqueuer = None

        try:
            if do_validation:
                # Prepare data for validation
                if len(gan_validation_data) == 2:
                    gan_val_x, gan_val_y = gan_validation_data
                    disc_val_x, disc_val_y = disc_validation_data
                    gan_val_sample_weight = None
                    disc_val_sample_weight = None
                elif len(gan_validation_data) == 3:
                    gan_val_x, gan_val_y, gan_val_sample_weight = gan_validation_data
                    disc_val_x, disc_val_y, disc_val_sample_weight = disc_validation_data
                else:
                    raise ValueError('`validation_data` should be a tuple '
                                     '`(val_x, val_y, val_sample_weight)` '
                                     'or `(val_x, val_y)`. Found: ' +
                                     str(gan_validation_data) + str(disc_validation_data))
                gan_val_x, gan_val_y, gan_val_sample_weights = self.model._standardize_user_data(
                    gan_val_x, gan_val_y, gan_val_sample_weight)
                disc_val_x, disc_val_y, disc_val_sample_weights = self.D.model._standardize_user_data(
                    disc_val_x , disc_val_y,
                    disc_val_sample_weight)
                gan_val_data = gan_val_x + gan_val_y + gan_val_sample_weights
                if self.model.uses_learning_phase and not isinstance(K.learning_phase(),
                                                                int):
                    gan_val_data += [0.]
                # The callbacks are bound to the stacked GAN model (set_model
                # above), so they must get that model's validation tensors.
                for cbk in _callbacks:
                    cbk.validation_data = gan_val_data

            if workers > 0:
                if is_sequence:
                    enqueuer = utils.data_utils.OrderedEnqueuer(
                        generator,
                        use_multiprocessing=use_multiprocessing,
                        shuffle=shuffle)
                else:
                    enqueuer = utils.data_utils.GeneratorEnqueuer(
                        generator,
                        use_multiprocessing=use_multiprocessing,
                        wait_time=wait_time)
                enqueuer.start(workers=workers, max_queue_size=max_queue_size)
                output_generator = enqueuer.get()
            else:
                if is_sequence:
                    output_generator = iter(generator)
                else:
                    output_generator = generator
            gan_callback_model.stop_training = False
            disc_callback_model.stop_training = False

            def generator_next():
                # Draw one item and split it into the per-model (x, y, weight) parts.
                generator_output = next(output_generator)
                gan_generator_output = [part['G'] for part in generator_output]
                disc_generator_output = [part['D'] for part in generator_output]

                if len(gan_generator_output) == 2:
                    gan_x, gan_y = gan_generator_output
                    disc_x, disc_y = disc_generator_output
                    gan_sample_weight = None
                    disc_sample_weight = None
                elif len(gan_generator_output) == 3:
                    gan_x, gan_y, gan_sample_weight = gan_generator_output
                    disc_x, disc_y, disc_sample_weight = disc_generator_output
                else:
                    raise ValueError('Output of generator should be '
                                     'a tuple `(x, y, sample_weight)` '
                                     'or `(x, y)`. Found: ' +
                                     str(gan_generator_output))
                return gan_x, gan_y, gan_sample_weight, disc_x, disc_y, disc_sample_weight

            # Construct epoch logs.
            epoch_logs = {}
            # Latent batch for the sample grid, sized from the generator's own
            # input shape. With fixed_latent it is drawn once and reused;
            # otherwise a fresh batch is drawn for every plot.
            latent_dims = (self.grid[0]*self.grid[1],) + tuple(self.G.input_shape)
            noise = None
            if self.fixed_latent:
                noise = np.random.normal(0, 1, latent_dims)

            while epoch < epochs:
                for m in self.model.stateful_metric_functions:
                    m.reset_states()
                for m in self.D.model.stateful_metric_functions:
                    m.reset_states()
                _callbacks.on_epoch_begin(epoch)
                steps_done = 0
                batch_index = 0
                while steps_done < steps_per_epoch:
                    # This first batch is only used to size the batch logs.
                    gan_x, gan_y, gan_sample_weight, disc_x, disc_y, disc_sample_weight = generator_next()
                    # build batch logs
                    batch_logs = {}
                    if gan_x is None or len(gan_x) == 0:
                        # Handle data tensors support when no input given
                        # step-size = 1 for data tensors
                        batch_size = 1
                    elif isinstance(gan_x, list):
                        batch_size = gan_x[0].shape[0]
                    elif isinstance(gan_x, dict):
                        batch_size = list(gan_x.values())[0].shape[0]
                    else:
                        batch_size = gan_x.shape[0]
                    batch_logs['batch'] = batch_index
                    batch_logs['size'] = batch_size
                    _callbacks.on_batch_begin(batch_index, batch_logs)
                    for _ in range(self.disc_iters):
                        gan_x, gan_y, gan_sample_weight, disc_x, disc_y, disc_sample_weight = generator_next()
                        disc_outs= self.D.model.train_on_batch(disc_x, disc_y, sample_weight=disc_sample_weight, class_weight=class_weight)
                    self.D.sequential.trainable= False
                    self.D.model.trainable= False
                    for _ in range(self.gen_iters):
                        gan_x, gan_y, gan_sample_weight, disc_x, disc_y, disc_sample_weight = generator_next()
                        # train generator: z->D(G(z)) real with discriminator is not trainable
                        gan_outs = self.model.train_on_batch(gan_x, gan_y, sample_weight=gan_sample_weight, class_weight=class_weight)
                    self.D.sequential.trainable= True
                    self.D.model.trainable= True

                    disc_outs = utils.generic_utils.to_list(disc_outs)
                    gan_outs = utils.generic_utils.to_list(gan_outs)

                    for l, o in zip(out_labels, gan_outs+disc_outs):
                        batch_logs[l] = o

                    _callbacks.on_batch_end(batch_index, batch_logs)

                    batch_index += 1
                    steps_done += 1

                    # Epoch finished.
                    if steps_done >= steps_per_epoch and do_validation:
                        gan_val_outs = self.model.evaluate(
                            gan_val_x, gan_val_y,
                            batch_size=batch_size,
                            sample_weight=gan_val_sample_weights,
                            verbose=0)
                        disc_val_outs = self.D.model.evaluate(
                            disc_val_x, disc_val_y,
                            batch_size=batch_size,
                            sample_weight=disc_val_sample_weights,
                            verbose=0)
                        if self.eval_scores:
                            is_mean, is_stdv, fid = self.get_eval_scores(splits=10, n_samples=self.n_samples//100)
                            epoch_logs['IS_mean']=is_mean
                            epoch_logs['IS_stdv']=is_stdv
                            epoch_logs['FID']=fid
                        gan_val_outs = utils.generic_utils.to_list(gan_val_outs)
                        disc_val_outs = utils.generic_utils.to_list(disc_val_outs)
                        # Same labels assumed.
                        for l, o in zip(out_labels, gan_val_outs+disc_val_outs):
                            epoch_logs['val_' + l] = o
                    if gan_callback_model.stop_training:
                        break
                    if disc_callback_model.stop_training:
                        break

                    if self.checkpoint_interval:
                        bool_checkpt = (steps_done % self.checkpoint_interval) == 0
                    else:
                        bool_checkpt = False
                    if (steps_done >= steps_per_epoch or bool_checkpt) and self.PlotGenetaredSamples:
                        if not self.fixed_latent:
                            noise = np.random.normal(0, 1, latent_dims)
                        G_z= self.G.model.predict(noise)
                        self.plot_generated_samples(G_z, grid=self.grid , imgsize=self.imgsize, img_range=self.target_scale,
                                        cmap=(None if (self.G.output_shape[-1]) == 3 else 'gray'), logdir=self.logdir, img_name='epoch%d'%(epoch))
                _callbacks.on_epoch_end(epoch, epoch_logs)
                epoch += 1

                if ModelCheckpoint:
                    os.makedirs(os.path.join(logdir, 'models'), exist_ok=True)
                    self.G.model.save_weights(os.path.join(logdir, 'models/gen_'+self.G.name+'.h5'))
                    self.D.model.save_weights(os.path.join(logdir, 'models/disc_'+self.D.name+'.h5'))

                if gan_callback_model.stop_training:
                    break
                if disc_callback_model.stop_training:
                    break

        finally:
            # Stop the background workers feeding the training batches.
            if enqueuer is not None:
                enqueuer.stop()

        _callbacks.on_train_end()
        return self.model.history, self.D.model.history

#### DCGAN

In [None]:
class DCGAN(GAN):
    """GAN that pairs a `ConvNet_Up` generator with a `ConvNet` discriminator."""

    def __init__(self, name='DCGAN', **kwargs):
        """Create the (still unbuilt) generator `self.G` and discriminator `self.D`.

        name: forwarded to the `GAN` base class.
        kwargs: forwarded unchanged to the `GAN` base class.
        """
        super(DCGAN, self).__init__(name=name, **kwargs)
        self.G = ConvNet_Up()
        self.D = ConvNet()

    def build_compile(self, input_shape, output_shape, loss, optimizer, metrics={'D':None}, pretrained_generator=None, pretrained_discriminator=None, **kwargs):
        """Build and compile G, D and the stacked model z -> G -> D.

        input_shape, output_shape, loss, optimizer, metrics and the extra
        kwargs are divided into a generator part and a discriminator part by
        `self.split_kwargs` (presumably keyed by 'G'/'D', as the callers pass
        dicts with those keys -- confirm against `split_kwargs`).
        pretrained_generator / pretrained_discriminator are accepted but not
        used by this method.

        Returns the compiled stacked model, which is also stored in `self.model`.
        """
        # Build both sub-networks from their share of the shape arguments.
        g_shapes, d_shapes = self.split_kwargs(input_shape=input_shape, output_shape=output_shape)
        self.G.build(len_io=1, **g_shapes)
        self.D.build(len_io=2, **d_shapes)

        # Compile both sub-networks from their share of the compile arguments.
        g_compile_args, d_compile_args = self.split_kwargs(loss=loss, optimizer=optimizer, metrics=metrics, **kwargs)
        self.G.compile(**g_compile_args)
        self.D.compile(**d_compile_args)

        # Freeze the discriminator before it is wired into the stacked model.
        self.D.sequential.trainable = False
        self.D.model.trainable = False

        # Stacked model: latent input -> generator -> discriminator.
        latent_in = layers.Input(shape=self.G.input_shape)
        generated = self.G.model(latent_in)
        verdict = self.D.sequential(generated)
        self.model = models.Model(latent_in, verdict)

        # With three or more local devices, replicate over all but one of them.
        n_local_devices = len(device_lib.list_local_devices())
        if n_local_devices >= 3:
            self.model = multi_gpu_model(self.model, gpus=n_local_devices - 1)

        print("************************************GENERATOR_IN_GAN*************************************")
        self.model.summary()
        self.model.compile(**g_compile_args)
        return self.model

#### WGAN-GP

In [None]:
class WGAN_GP(GAN):
    """GAN that pairs a `ConvNet_Up` generator with a `ConvNet_Critic` critic."""

    def __init__(self, batch_size, name='WGAN_GP', **kwargs):
        """Create the (still unbuilt) generator `self.G` and critic `self.D`.

        batch_size: stored on the instance and passed to `ConvNet_Critic`.
        name: forwarded to the `GAN` base class.
        kwargs: forwarded unchanged to the `GAN` base class.
        """
        super(WGAN_GP, self).__init__(name=name, **kwargs)
        self.batch_size = batch_size
        self.G = ConvNet_Up()
        self.D = ConvNet_Critic(batch_size)

    def build_compile(self, input_shape, output_shape, loss, optimizer, metrics={'D':None}, pretrained_generator=None, pretrained_discriminator=None, **kwargs):
        """Build and compile G, D and the stacked model z -> G -> D.

        input_shape, output_shape, loss, optimizer, metrics and the extra
        kwargs are divided into a generator part and a critic part by
        `self.split_kwargs` (presumably keyed by 'G'/'D' -- confirm against
        `split_kwargs`). pretrained_generator / pretrained_discriminator are
        accepted but not used by this method.

        Returns the compiled stacked model, which is also stored in `self.model`.
        """
        gen_io_shape, disc_io_shape = self.split_kwargs(input_shape=input_shape, output_shape=output_shape)
        self.G.build(**gen_io_shape, len_io=1)
        self.D.build(**disc_io_shape, len_io=2)
        gen_compiler, disc_compiler = self.split_kwargs(loss=loss, optimizer=optimizer, metrics=metrics, **kwargs)
        self.G.compile(**gen_compiler)
        self.D.compile(**disc_compiler)

        # Freeze the critic before it is wired into the stacked model; the
        # generator must stay trainable. (The original wrote to a misspelled
        # attribute `trainbale`, which had no effect.)
        self.D.sequential.trainable = False
        self.D.model.trainable = False
        self.G.model.trainable = True

        # Stacked model: latent input -> generator -> critic.
        z = layers.Input(shape=self.G.input_shape)
        G_z = self.G.model(z)
        D_G_z = self.D.sequential(G_z)
        self.model = models.Model(z, D_G_z)

        # With three or more local devices, replicate over all but one of them.
        n_local_devices = len(device_lib.list_local_devices())
        if n_local_devices >= 3:
            self.model = multi_gpu_model(self.model, gpus=n_local_devices - 1)

        print("************************************GENERATOR_IN_GAN*************************************")
        self.model.summary()
        self.model.compile(**gen_compiler)
        return self.model

    @staticmethod
    def _wgan_targets(n_samples):
        """Return the target dict for `n_samples` samples.

        Every target has shape (n_samples, 2): 'G' rows are [-1, 1]; the three
        'D' targets have rows [-1, 1], [1, -1] and [0, 0].
        NOTE(review): the all-zero third target is presumably a dummy for a
        gradient-penalty output of the critic -- confirm against ConvNet_Critic.
        """
        ones = np.ones((n_samples, 1))
        zeros = np.zeros((n_samples, 1))
        return {'G': np.append(-ones, ones, axis=-1),
                'D': [np.append(-ones, ones, axis=-1),
                      np.append(ones, -ones, axis=-1),
                      np.append(zeros, zeros, axis=-1)]}

    def train_generator_gan(self, x, batch_size, noise_shape, shift_fraction=0.):
        """Yield training batches for the generator and the critic forever.

        x: array of real images; batches are drawn from it through an
            `ImageDataGenerator` with random width/height shifts.
        batch_size: batch size handed to the image iterator.
        noise_shape: tuple, shape of one latent sample (without batch axis).
        shift_fraction: value used for both width and height shift range.

        Each item is `(inputs, targets)`: inputs are `{'G': noise,
        'D': [real_batch, generated_batch]}`, targets come from
        `_wgan_targets`. The generated batch is predicted from a second,
        independent noise draw, not from the noise returned under 'G'.
        """
        # The image iterator needs labels; they are discarded afterwards.
        Y_trash = np.ones(x.shape[0])
        train_datagen = ImageDataGenerator(width_shift_range=shift_fraction,
                                           height_shift_range=shift_fraction)
        generator = train_datagen.flow(x, Y_trash, batch_size=batch_size)
        while True:
            x_batch, _ = generator.next()
            n_batch = np.shape(x_batch)[0]
            noise = np.random.normal(0, 1, (n_batch,) + noise_shape)
            G_noise = self.G.model.predict(np.random.normal(0, 1, (n_batch,) + noise_shape))
            yield ({'G': noise, 'D': [x_batch, G_noise]}, self._wgan_targets(n_batch))

    def test_generator_gan(self, x, noise_shape):
        """Return one `(inputs, targets)` pair covering all of `x`.

        x: array of real images, used as a single batch.
        noise_shape: tuple, shape of one latent sample (without batch axis).

        The structure matches the items yielded by `train_generator_gan`.
        """
        n_samples = x.shape[0]
        noise = np.random.normal(0, 1, (n_samples,) + noise_shape)
        G_noise = self.G.model.predict(np.random.normal(0, 1, (n_samples,) + noise_shape))
        return ({'G': noise, 'D': [x, G_noise]}, self._wgan_targets(n_samples))
    

#### VCapsGAN

In [None]:
class VCapsGAN(GAN):
    """GAN that pairs a `ConvNet_Up` generator with a `VCapsNet` discriminator."""

    def __init__(self, name='VCapsGAN', **kwargs):
        """Create the (still unbuilt) generator `self.G` and discriminator `self.D`.

        name: forwarded to the `GAN` base class.
        kwargs: forwarded unchanged to the `GAN` base class.
        """
        super(VCapsGAN, self).__init__(name=name, **kwargs)
        self.G = ConvNet_Up()
        self.D = VCapsNet()

    def build_compile(self, input_shape, output_shape, loss, optimizer, L1_n, L2_n, L2_dim, L3_dim, routing,  metrics={'D':None}, pretrained_generator=None, pretrained_discriminator=None, **kwargs):
        """Build and compile G, D and the stacked model z -> G -> D.

        The shape arguments and the layer hyperparameters (L1_n, L2_n, L2_dim,
        L3_dim, routing) are divided into generator/discriminator build
        arguments by `self.split_kwargs`; loss, optimizer, metrics and the
        extra kwargs are divided the same way into compile arguments
        (presumably keyed by 'G'/'D' -- confirm against `split_kwargs`).
        pretrained_generator / pretrained_discriminator are accepted but not
        used by this method.

        Returns the compiled stacked model, which is also stored in `self.model`.
        """
        g_build_args, d_build_args = self.split_kwargs(
            input_shape=input_shape, output_shape=output_shape,
            L1_n=L1_n, L2_n=L2_n, L2_dim=L2_dim, L3_dim=L3_dim, routing=routing)
        self.G.build(len_io=1, **g_build_args)
        self.D.build(len_io=2, **d_build_args)

        # Here the discriminator is compiled before the generator.
        g_compile_args, d_compile_args = self.split_kwargs(loss=loss, optimizer=optimizer, metrics=metrics, **kwargs)
        self.D.compile(**d_compile_args)
        self.G.compile(**g_compile_args)

        # Freeze the discriminator before it is wired into the stacked model.
        self.D.sequential.trainable = False
        self.D.model.trainable = False

        # Stacked model: latent input -> generator -> discriminator.
        latent_in = layers.Input(shape=self.G.input_shape)
        generated = self.G.model(latent_in)
        verdict = self.D.sequential(generated)
        self.model = models.Model(latent_in, verdict)

        # With three or more local devices, replicate over all but one of them.
        n_local_devices = len(device_lib.list_local_devices())
        if n_local_devices >= 3:
            self.model = multi_gpu_model(self.model, gpus=n_local_devices - 1)

        print("************************************GENERATOR_IN_GAN*************************************")
        self.model.summary()
        self.model.compile(**g_compile_args)
        return self.model

#### MCapsGAN

In [None]:
class MCapsGAN(GAN):
    """GAN that pairs a `ConvNet_Up` generator with an `MCapsNet` discriminator."""

    def __init__(self, batch_size, name='MCapsGAN', **kwargs):
        """Create the (still unbuilt) generator `self.G` and discriminator `self.D`.

        batch_size: passed to the `MCapsNet` constructor.
        name: forwarded to the `GAN` base class.
        kwargs: forwarded unchanged to the `GAN` base class.
        """
        super(MCapsGAN, self).__init__(name=name, **kwargs)
        self.G = ConvNet_Up()
        self.D = MCapsNet(batch_size=batch_size)

    def build_compile(self, input_shape, output_shape, batch_size, n_samples, loss, optimizer, L1_n, L2_n, L3_n, L4_n, pose_shape={'D':[4,4]}, routing={'D':3}, decoder={'D':False},  metrics={'D':None}, pretrained_generator=None, pretrained_discriminator=None, **kwargs):
        """Build and compile G, D and the stacked model z -> G -> D.

        The shape arguments and the layer hyperparameters (L1_n .. L4_n,
        pose_shape, routing, decoder) are divided into generator/discriminator
        build arguments by `self.split_kwargs`; loss, optimizer, metrics and
        the extra kwargs are divided the same way into compile arguments
        (presumably keyed by 'G'/'D' -- confirm against `split_kwargs`).
        batch_size and n_samples are passed only to the discriminator's
        `compile`. pretrained_generator / pretrained_discriminator are
        accepted but not used by this method.

        Returns the compiled stacked model, which is also stored in `self.model`.
        """
        gen_build_args, disc_build_args = self.split_kwargs(
            input_shape=input_shape, output_shape=output_shape,
            L1_n=L1_n, L2_n=L2_n, L3_n=L3_n, L4_n=L4_n,
            pose_shape=pose_shape, routing=routing, decoder=decoder)
        self.G.build(**gen_build_args, len_io=1)
        self.D.build(**disc_build_args, len_io=2)

        # Here the discriminator is compiled before the generator.
        gen_compile_args, disc_compile_args = self.split_kwargs(loss=loss, optimizer=optimizer, metrics=metrics, **kwargs)
        self.D.compile(batch_size=batch_size, n_samples=n_samples, **disc_compile_args)
        self.G.compile(**gen_compile_args)

        # Freeze the discriminator before it is wired into the stacked model.
        self.D.sequential.trainable = False
        self.D.model.trainable = False

        # Stacked model: latent input -> generator -> discriminator.
        # (A duplicate construction using the legacy `input=`/`output=`
        # keywords was removed; its result was overwritten immediately.)
        z = layers.Input(shape=self.G.input_shape)
        G_z = self.G.model(z)
        D_G_z = self.D.sequential(G_z)
        self.model = models.Model(z, D_G_z)

        # With three or more local devices, replicate over all but one of them.
        n_local_devices = len(device_lib.list_local_devices())
        if n_local_devices >= 3:
            self.model = multi_gpu_model(self.model, gpus=n_local_devices - 1)

        print("************************************GENERATOR_IN_GAN*************************************")
        self.model.summary()
        self.model.compile(**gen_compile_args)
        return self.model

## Training the network

### Callbacks

Note: writing a checkpoint at every iteration would make training slow and would produce a huge amount of checkpoint data.

In [None]:
class TrainValTensorBoard(cbks.TensorBoard):
    """TensorBoard callback that keeps training and validation logs apart.

    The parent `TensorBoard` callback logs to `<log_dir>/training`; the
    `val_*` entries of the epoch logs are written to `<log_dir>/validation`
    with the `val_` prefix removed, so both runs share the same tags.
    Selected metrics are additionally written after every batch. Both the
    per-batch and the validation summaries use `self.step`, the number of
    batches seen so far, as their step.
    """

    def __init__(self, metrics=None, log_dir="./logs", **kwargs):
        """Set up the log directories and the per-batch writer.

        metrics: names of the batch-log entries to write after each batch;
            defaults to ['acc', 'loss'].
        log_dir: root directory holding the 'training' and 'validation'
            sub-directories.
        kwargs: forwarded to `keras.callbacks.TensorBoard`.
        """
        # Make the original `TensorBoard` log to a subdirectory 'training'
        self.training_log_dir = os.path.join(log_dir, 'training')
        super(TrainValTensorBoard, self).__init__(self.training_log_dir, **kwargs)
        # Log the validation metrics to a separate subdirectory
        self.val_log_dir = os.path.join(log_dir, 'validation')
        self.step = 0
        # Copy so that neither a shared default nor the caller's list is aliased.
        self.metrics = ['acc', 'loss'] if metrics is None else list(metrics)
        self.batch_writer = tf.summary.FileWriter(self.training_log_dir)

    @staticmethod
    def _scalar_summary(tag, value):
        """Return a `tf.Summary` holding the single scalar `value` under `tag`."""
        summary = tf.Summary()
        summary_value = summary.value.add()
        # NumPy scalars/arrays expose `.item()`; plain Python numbers do not.
        summary_value.simple_value = value.item() if hasattr(value, 'item') else value
        summary_value.tag = tag
        return summary

    def set_model(self, model):
        """Create the validation writer, then defer to the parent callback."""
        # Setup writer for validation metrics
        self.val_writer = tf.summary.FileWriter(self.val_log_dir)
        super(TrainValTensorBoard, self).set_model(model)

    def on_epoch_end(self, epoch, logs=None):
        """Optionally save model weights, then log the epoch metrics.

        Reads the notebook globals TRAIN, DISCRIMINATOR and GENERATOR: when
        TRAIN["param"]["checkpoint"]["models"]["save"] is truthy, the
        discriminator weights (and the generator weights when two models are
        being trained) are saved under `<logdir>/models`.
        """
        checkpt_dict = TRAIN["param"]["checkpoint"]
        # save trained models
        if checkpt_dict["models"]["save"]:
            models_dir = os.path.join(checkpt_dict["logdir"], 'models')
            # The target directory must exist before the weights are written.
            os.makedirs(models_dir, exist_ok=True)
            DISCRIMINATOR["train"].save_weights(os.path.join(models_dir, "discriminator.h5"))
            if len(TRAIN["models_to_train"]) == 2:
                GENERATOR["train"].save_weights(os.path.join(models_dir, "generator.h5"))

        logs = logs or {}
        # Validation entries go to the validation writer without the prefix.
        val_logs = {k.replace('val_', ''): v for k, v in logs.items() if k.startswith('val_')}
        for name, value in val_logs.items():
            self.val_writer.add_summary(self._scalar_summary(name, value), self.step)
        self.val_writer.flush()

        # Only the remaining (training) entries are handed to the parent.
        train_logs = {k: v for k, v in logs.items() if not k.startswith('val_')}
        super(TrainValTensorBoard, self).on_epoch_end(epoch, train_logs)

    def on_batch_end(self, batch, logs=None):
        """Write the selected batch metrics and advance the step counter."""
        logs = logs or {}
        for name, value in logs.items():
            if name in self.metrics:
                self.batch_writer.add_summary(self._scalar_summary(name, value), self.step)
        self.batch_writer.flush()
        self.step += 1
        # Pass `batch` through as well; the original passed `logs` in its place.
        super(TrainValTensorBoard, self).on_batch_end(batch, logs)

    def on_train_end(self, logs=None):
        """Let the parent finish, then close both extra writers."""
        super(TrainValTensorBoard, self).on_train_end(logs)
        self.val_writer.close()
        self.batch_writer.close()

### Train

In [None]:
# Shortcuts into the configuration dictionaries used by the training code.
disc_dict = DISCRIMINATOR["param"]
disc_net_dict = disc_dict[disc_dict["topology"]]
train_dict = TRAIN["param"]
checkpt_dict = train_dict["checkpoint"]

# Dump the complete configuration to <logdir>/parameters.txt. The `with`
# block closes the file even if one of the writes raises.
with open(os.path.join(checkpt_dict["logdir"], 'parameters.txt'), "w") as f:
    for section_name, section_dict in (("DATASET", DATASET),
                                       ("GENERATOR", GENERATOR),
                                       ("DISCRIMINATOR", DISCRIMINATOR),
                                       ("COMBINED", COMBINED)):
        f.write(section_name + ":\n" + str(section_dict) + "\n \n")
    # The last section is written without a trailing separator.
    f.write("TRAIN:\n" + str(TRAIN))

if len(TRAIN["models_to_train"]) == 1:
    # Train the discriminator on its own as a plain classifier.
    # Topology names are compared with `==`: `is` tests object identity and
    # only matched by accident of string interning.
    classifier_topology = disc_dict["topology"]
    classifier_io_shapes = {"input_shape": disc_dict["inputs_shape"],
                            "output_shape": disc_dict["output_shape"]}

    if classifier_topology == "ConvNet":
        network = ConvNet()
        network.build(**classifier_io_shapes)
        network.compile(optimizer=disc_net_dict["optimizer"], loss=network.loss_fn, metrics=['accuracy'])

    elif classifier_topology == "VCapsNet":
        network = VCapsNet()
        vcaps_build_args = {"L1_n": disc_net_dict["L1_n"],
                            "L2_n": disc_net_dict["L2_n"],
                            "L2_dim": disc_net_dict["L2_dim"],
                            "L3_dim": disc_net_dict["L3_dim"],
                            "routing": disc_net_dict["routing_iters"]}
        if disc_net_dict["decoder"]:
            # With a decoder the network gets two extra layer sizes and a
            # second ('mse') loss.
            vcaps_build_args.update(decoder=disc_net_dict["decoder"],
                                    L4_n=disc_net_dict["L4_n"],
                                    L5_n=disc_net_dict["L5_n"])
            vcaps_loss = [network.loss_fn, 'mse']
        else:
            vcaps_loss = network.loss_fn
        network.build(**classifier_io_shapes, **vcaps_build_args)
        network.compile(optimizer=disc_net_dict["optimizer"], loss=vcaps_loss, metrics=['accuracy'])

    elif classifier_topology == "MCapsNet":
        network = MCapsNet(batch_size=train_dict["batch_size"])
        network.build(L1_n=disc_net_dict["L1_n"],
                      L2_n=disc_net_dict["L2_n"],
                      L3_n=disc_net_dict["L3_n"],
                      L4_n=disc_net_dict["L4_n"],
                      routing=disc_net_dict["routing_iters"],
                      pose_shape=disc_net_dict["pose_shape"],
                      **classifier_io_shapes)
        network.compile(optimizer=disc_net_dict["optimizer"], batch_size=train_dict["batch_size"],
                        n_samples=DATASET['train'].labels.shape[0], loss=network.loss_fn, metrics=['accuracy'])

    else:
        # Fail with a clear message instead of a NameError on `network` below.
        raise ValueError("Unsupported discriminator topology: {!r}".format(classifier_topology))

    network.fit(x=DATASET['train'].imgs, y=DATASET['train'].labels,
                batch_size=train_dict["batch_size"], epochs=train_dict["epochs"],
                validation_data=[DATASET['test'].imgs, DATASET['test'].labels],
                logdir=checkpt_dict["logdir"], TensorBoard=False)

if len(TRAIN["models_to_train"]) == 2:
    # Train a generator/discriminator pair as a GAN.
    gen_dict = GENERATOR["param"]
    disc_dict = DISCRIMINATOR["param"]
    gan_dict = COMBINED["param"]
    # Topology names are compared with `==`: `is` tests object identity and
    # only matched by accident of string interning.
    gan_topology = gan_dict["topology"]

    # Arguments shared by every topology's build_compile() call.
    gan_io_shapes = {
        "input_shape": {'G': gen_dict["inputs_shape"], 'D': disc_dict["inputs_shape"]},
        "output_shape": {'G': gen_dict["output_shape"], 'D': disc_dict["output_shape"]},
    }
    # Arguments shared by every topology's fit() call.
    gan_fit_args = {
        "x": DATASET['train'].imgs,
        "batch_size": train_dict["batch_size"],
        "epochs": train_dict["epochs"],
        "validation_data": [DATASET['test'].imgs, DATASET['test'].labels],
        "logdir": checkpt_dict["logdir"],
        "fixed_latent": True,
        "checkpoint_interval": checkpt_dict["interval"],
        "load_weights": {'G': TRAIN["trained_models"]['G'], 'D': TRAIN["trained_models"]['D']},
    }

    if gan_topology == "DCGAN":
        gen_net_dict = gen_dict["DeConvNet"]
        disc_net_dict = disc_dict["ConvNet"]
        network = DCGAN()
        network.build_compile(loss={'G': network.D.loss_fn, 'D': [network.D.loss_fn, network.D.loss_fn]},
                              optimizer={'G': gen_net_dict["optimizer"], 'D': disc_net_dict["optimizer"]},
                              loss_weights={'D': [1, 1]}, metrics={'D': ['accuracy']},
                              **gan_io_shapes)
        gen_history, disc_history = network.fit(disc_iters=disc_net_dict["iters"],
                                                gen_iters=gen_net_dict["iters"],
                                                **gan_fit_args)

    elif gan_topology == "WGAN_GP":
        gen_net_dict = gen_dict["DeConvNet"]
        disc_net_dict = disc_dict["Critic"]
        network = WGAN_GP(batch_size=train_dict["batch_size"])
        # NOTE(review): the generator is given the critic's optimizer entry,
        # unlike the other topologies, which use gen_net_dict["optimizer"] --
        # confirm this is intended.
        # NOTE(review): 'D' lists two losses but three loss weights;
        # presumably the critic adds its own third loss -- confirm against
        # ConvNet_Critic.
        network.build_compile(loss={'G': network.D.loss_fn, 'D': [network.D.loss_fn, network.D.loss_fn]},
                              optimizer={'G': disc_net_dict["optimizer"], 'D': disc_net_dict["optimizer"]},
                              loss_weights={'D': [1, 1, 10]}, metrics={'D': ['accuracy']},
                              **gan_io_shapes)
        gen_history, disc_history = network.fit(disc_iters=disc_net_dict["iters"],
                                                gen_iters=gen_net_dict["iters"],
                                                **gan_fit_args)

    elif gan_topology == "VCapsGAN":
        gen_net_dict = gen_dict["DeConvNet"]
        disc_net_dict = disc_dict["VCapsNet"]
        network = VCapsGAN()
        network.build_compile(loss={'G': network.D.loss_fn, 'D': [network.D.loss_fn, network.D.loss_fn]},
                              optimizer={'G': gen_net_dict["optimizer"], 'D': disc_net_dict["optimizer"]},
                              metrics={'D': ['accuracy']},
                              L1_n={'D': disc_net_dict["L1_n"]},
                              L2_n={'D': disc_net_dict["L2_n"]},
                              L2_dim={'D': disc_net_dict["L2_dim"]},
                              L3_dim={'D': disc_net_dict["L3_dim"]},
                              routing={'D': disc_net_dict["routing_iters"]},
                              **gan_io_shapes)
        gen_history, disc_history = network.fit(eval_scores=True, **gan_fit_args)

    elif gan_topology == "MCapsGAN":
        gen_net_dict = gen_dict["DeConvNet"]
        disc_net_dict = disc_dict["MCapsNet"]
        network = MCapsGAN(batch_size=train_dict["batch_size"])
        network.build_compile(loss={'G': network.D.loss_fn, 'D': [network.D.loss_fn, network.D.loss_fn]},
                              optimizer={'G': gen_net_dict["optimizer"], 'D': disc_net_dict["optimizer"]},
                              metrics={'D': ['accuracy']},
                              L1_n={'D': disc_net_dict["L1_n"]},
                              L2_n={'D': disc_net_dict["L2_n"]},
                              L3_n={'D': disc_net_dict["L3_n"]},
                              L4_n={'D': disc_net_dict["L4_n"]},
                              routing={'D': disc_net_dict["routing_iters"]},
                              pose_shape={'D': disc_net_dict["pose_shape"]},
                              n_samples=DATASET['train'].imgs.shape[0],
                              batch_size=train_dict["batch_size"],
                              **gan_io_shapes)
        gen_history, disc_history = network.fit(**gan_fit_args)

    else:
        # Fail loudly instead of silently training nothing.
        raise ValueError("Unsupported GAN topology: {!r}".format(gan_topology))

## Evaluation

Print the Inception Score (IS) and the Fréchet Inception Distance (FID) after training, and plot 100 generated samples.

### Inception Score & Fréchet Inception Distance

In [None]:
# Query the trained network for its evaluation scores; the first three
# entries are reported as IS mean, IS standard deviation and FID.
score = network.get_eval_scores(n_samples=1000, splits=10)
inception_mean, inception_stdv, fid_value = score[0], score[1], score[2]
print("IS: mean {}, stdv {} \n FID: {}".format(inception_mean, inception_stdv, fid_value))

### Generated Image Quality

In [None]:
# Draw 100 latent vectors of length 100 from N(0, 1) and run them through
# the generator.
latent_batch = np.random.normal(0, 1, (100, 100))
G_z = network.G.model.predict(latent_batch)

# Three output channels are plotted with the default colormap, anything else
# in grayscale.
sample_cmap = 'gray' if network.G.output_shape[-1] != 3 else None
network.plot_generated_samples(G_z,
                               grid=network.grid,
                               imgsize=network.imgsize,
                               img_range=network.target_scale,
                               cmap=sample_cmap,
                               logdir=network.logdir,
                               img_name='epoch')

### Additional visualizations for the discriminator as a simple classifier

#### False predicted samples

In [None]:
if len(TRAIN["models_to_train"]) != 2:
    def show_false_predictions():
        """Print and display every test image the discriminator misclassifies.

        Runs DISCRIMINATOR['train'] on DATASET['test'].imgs, compares the
        argmax of the predictions with the argmax of DATASET['test'].labels,
        and for each mismatch prints both class indices and shows the image.
        """
        test_imgs = DATASET['test'].imgs
        probabilities = DISCRIMINATOR['train'].predict(x=test_imgs)
        predictions = probabilities.argmax(axis=-1)
        labels = DATASET['test'].labels.argmax(axis=-1)

        for idx, predicted in enumerate(predictions):
            if predicted == labels[idx]:
                continue
            print("prediction {} true {}".format(predicted, labels[idx]))

            # Select the pixels and colormap from the per-sample shape:
            # single-channel and channel-less images are shown in grayscale.
            sample_shape = test_imgs.shape[1:]
            if len(sample_shape) != 3:
                pixels, colormap = test_imgs[idx][:, :], 'gray'
            elif sample_shape[2] == 1:
                pixels, colormap = test_imgs[idx][:, :, 0], 'gray'
            else:
                pixels, colormap = test_imgs[idx][:, :, :], None
            plt.imshow(np.asarray(pixels), cmap=colormap)
            plt.show()
    #show_false_predictions()