# Generative adversarial network

In [4]:
import numpy as np
import pandas as pd
from typing import Union, Tuple

import time
import copy
import matplotlib.pyplot as plt
from abc import ABC, abstractmethod

import shutil
import os
import pickle
from zipfile import ZipFile

![SegmentLocal](models/gan/mnist.gif "segment")

In [5]:
data_path = './data/digit_data/'

if os.path.exists(data_path):
    shutil.rmtree(data_path)
    
with ZipFile('./data/digit-recognizer.zip') as f:
    f.extractall(data_path)
    
train_set = pd.read_csv('./data/digit_data/train.csv')

shutil.rmtree(data_path)
train_set.head()

Unnamed: 0,label,pixel0,pixel1,pixel2,pixel3,pixel4,pixel5,pixel6,pixel7,pixel8,...,pixel774,pixel775,pixel776,pixel777,pixel778,pixel779,pixel780,pixel781,pixel782,pixel783
0,1,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
1,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
2,1,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3,4,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
4,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


In [6]:
train_data = train_set.drop('label', axis=1).values
train_data = (train_data.astype(float) - 127.5) / 127.5 # normalize to scale [-1, 1]
train_data.shape

(42000, 784)

In [7]:
def log_loss(y_true: np.array, a_pred: np.array) -> float:
    """
    LogLoss for binary classification tasks
    :param y_true: true values (
    :param a_pred: predicted probabilities of each class [0, 1] (after sigmoid function or others)
    :return: LogLoss
    """
    return np.mean(-y_true * np.log(a_pred+1e-15) - (1 - y_true) * np.log(1 - a_pred+1e-15))


def log_loss_derivative(y_true: np.array, a_pred: np.array) -> np.array:
    """
    LogLoss derivative for binary classification tasks
    :param y_true: true values (true labels, i.e. [0, 1, 1, 0...])
    :param a_pred: predicted probabilities of each class [0, 1] (after sigmoid function or others)
    :return: np.array with derivatives
    """
    eps = 1e-15  # to avoid zero division
    return (-y_true / (a_pred + eps) + (1 - y_true) / (1 - a_pred + eps)) / len(y_true)

In [8]:
def sigmoid(z: Union[np.array, float, int, list]) -> Union[np.array, float]:
    """
    Sigmoid function
    """
    return 1 / (1 + np.exp(-z))


def sigmoid_derivative(z: Union[np.array, float, int, list]) -> Union[np.array, float]:
    """
    Sigmoid function derivative
    """
    s = 1 / (1 + np.exp(-z))
    return s * (1 - s)

def tanh(z: Union[np.array, float, int, list]) -> Union[np.array, float]:
    """
    Tanh function
    """
    return np.tanh(z)

def tanh_derivative(z: Union[np.array, float, int, list]) -> Union[np.array, float]:
    """
    Tanh function derivative
    """
    return 1 - np.tanh(z) ** 2

def leaky_relu(z: Union[np.array, float, int, list], alpha=0.2) -> np.array:
    """
    Leaky ReLU function
    """
    return np.where(z >= 0, z, alpha * z)

def leaky_relu_derivative(z: Union[np.array, float, int, list], alpha=0.2) -> np.array:
    """
    Leaky ReLU function derivative
    """
    return np.where(z >= 0, 1, alpha)    

In [9]:
class BaseOptimizer(ABC):
    @abstractmethod
    def __init__(self) -> None:
        pass

    @abstractmethod
    def set_weight(self, weight: np.array) -> None:
        pass

    @abstractmethod
    def step(self, grad: np.array) -> np.array:
        pass
    

class ADAM(BaseOptimizer):
    """
    Implements Adam algorithm.
    learning_rate (float, optional) – learning rate (default: 1e-3)
    beta1, beta2 (Tuple[float, float], optional) –
    coefficients used for computing running averages of gradient and its square (default: (0.9, 0.999))
    eps (float, optional) – term added to the denominator to improve numerical stability (default: 1e-8)
    """

    def __init__(self, beta1: float = 0.9, beta2: float = 0.999, eps: float = 1e-8,
                 learning_rate: float = 3e-4, weight_decay: float = 0) -> None:
        self.beta1 = beta1
        self.beta2 = beta2
        self.eps = eps
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay

        self.EMA1 = None
        self.EMA2 = None

        self.weight = None

    def set_weight(self, weight: np.array) -> None:
        self.weight = weight.copy()
        self.EMA1 = np.zeros(shape=self.weight.shape)
        self.EMA2 = np.zeros(shape=self.weight.shape)

    def step(self, grad: np.array) -> np.array:
        assert self.weight is not None, 'You should set the weight'
        grad = grad.copy() + self.weight_decay * self.weight
        self.EMA1 = (1 - self.beta1) * grad + self.beta1 * self.EMA1
        self.EMA2 = (1 - self.beta2) * grad ** 2 + self.beta2 * self.EMA2
        self.weight -= self.learning_rate * self.EMA1 / (np.sqrt(self.EMA2) + self.eps)

        return self.weight.copy()

In [10]:
class BaseLayer(ABC):
    @abstractmethod
    def __init__(self) -> None:
        pass

    def __call__(self, x: np.array, grad: bool = True) -> np.array:
        return self.forward(x, grad)

    @abstractmethod
    def forward(self, x: np.array, grad: bool = True) -> np.array:
        pass

    @abstractmethod
    def backward(self, output_error: np.array) -> np.array:
        pass


class Linear(BaseLayer):
    """
    Linear class permorms ordinary FC layer in neural networks
    Parameters:
    n_input - size of input neurons
    n_output - size of output neurons
    Methods:
    set_optimizer(optimizer) - is used for setting an optimizer for gradient descent
    forward(x) - performs forward pass of the layer
    backward(output_error, learning_rate) - performs backward pass of the layer
    """

    def __init__(self, n_input: int, n_output: int) -> None:
        super().__init__()
        self.input = None
        self.n_input = n_input
        self.n_output = n_output
        
        limit = 1 / np.sqrt(self.n_input) 
        self.w = np.random.uniform(-limit, limit, size=(n_input, n_output))
        self.b = np.zeros(shape=(1, n_output))

        self.w_optimizer = None
        self.b_optimizer = None

    def set_optimizer(self, optimizer) -> None:
        self.w_optimizer = copy.copy(optimizer)
        self.b_optimizer = copy.copy(optimizer)

        self.w_optimizer.set_weight(self.w)
        self.b_optimizer.set_weight(self.b)

    def forward(self, x: np.array, grad: bool = True) -> np.array:
        self.input = x
        return x.dot(self.w) + self.b

    def backward(self, output_error: np.array) -> np.array:
        assert self.w_optimizer is not None and self.b_optimizer is not None, 'You should set an optimizer'
        w_grad = self.input.T.dot(output_error)
        b_grad = np.ones((1, len(output_error))).dot(output_error)
        input_error = output_error.dot(self.w.T)

        self.w = self.w_optimizer.step(w_grad)
        self.b = self.b_optimizer.step(b_grad)
        return input_error


class Activation(BaseLayer):
    """
    Activation class is used for activation function of the FC layer
    Params:
    activation_function - activation function (e.g. sigmoid, RElU, tanh)
    activation_derivative - derivative of the activation function
    Methods:
    forward(x) - performs forward pass of the layer
    backward(output_error, learning_rate) - performs backward pass of the layer
    """

    def __init__(self, activation_function: callable, activation_derivative: callable) -> None:
        super().__init__()
        self.input = None
        self.activation = activation_function
        self.derivative = activation_derivative

    def forward(self, x: np.array, grad: bool = True) -> np.array:
        self.input = x
        return self.activation(x)

    def backward(self, output_error: np.array) -> np.array:
        return output_error * self.derivative(self.input)

In [16]:
class MnistGenerator:
    '''
    Generator class.
    Generates fake MNIST images of size 28*28 from Normal noise (0, 1) of size 100
    '''
  
    def __init__(self):
        self.layers = [
            Linear(100, 256),
            Activation(leaky_relu, leaky_relu_derivative),
            Linear(256, 512),
            Activation(leaky_relu, leaky_relu_derivative),
            Linear(512, 1024),
            Activation(leaky_relu, leaky_relu_derivative),
            Linear(1024, 28*28),
            Activation(tanh, tanh_derivative),
        ]
        self.grad = True
        
    def __call__(self, x):
        return self.forward(x)
        
    def set_optimizer(self, optimizer):
        for layer in self.layers:
            if 'set_optimizer' in layer.__dir__():
                layer.set_optimizer(optimizer)

    def forward(self, x):
        self.input = x
        for layer in self.layers:
            x = layer(x, self.grad)
        return x
    
    def backward(self, output_error):
        for layer in reversed(self.layers):
            output_error = layer.backward(output_error)
        return output_error

In [17]:
class MnistDiscriminator:
    '''
    Discriminator class.
    Classificator which takes 28*28 image and says if it's fake or not
    '''

    def __init__(self):
        self.layers = [
            Linear(28*28, 1024),
            Activation(leaky_relu, leaky_relu_derivative),
            Linear(1024, 512),
            Activation(leaky_relu, leaky_relu_derivative),
            Linear(512, 256),
            Activation(leaky_relu, leaky_relu_derivative),
            Linear(256, 1),
            Activation(sigmoid, sigmoid_derivative),
        ]
        self.grad = True
        
    def __call__(self, x):
        return self.forward(x)
    
    def set_optimizer(self, optimizer):
        for layer in self.layers:
            if 'set_optimizer' in layer.__dir__():
                layer.set_optimizer(optimizer)

    def forward(self, x):
        for layer in self.layers:
            x = layer(x, self.grad)
        return x
    
    def backward(self, output_error, gen_phase=False):
        if gen_phase:
            # we don't need to teach discriminator during generator learning stage
            layers = copy.deepcopy(self.layers)
            for layer in reversed(layers):
                output_error = layer.backward(output_error)
            return output_error
        
        for layer in reversed(self.layers):
            output_error = layer.backward(output_error)
        return output_error

In [18]:
class MnistGAN:
    def __init__(self, generator, discriminator):
        self.generator = generator
        self.discriminator = discriminator
        self.loss = log_loss
        self.loss_derivative = log_loss_derivative

    def disc_step(self, real_images, num_images):
        """
        One learning step for discriminator
        """
        # create noise to pass it to generator
        # noise = [n_images, 100] (100 is chosen noise size)
        noise = np.random.randn(num_images, 100)
        # generator generates fake images
        # fake_images = [n_images, 28*28]
        fake_images = self.generator(noise)
        
        # concatenate real and fake images for loss computation convenience
        # real_and_fake_images = [2*n_images, 28*28]
        real_and_fake_images = np.concatenate((real_images, fake_images), axis=0)
        # discriminator makes its predictions
        # disc_pred = [2*n_images, 1]
        disc_pred = self.discriminator(real_and_fake_images)
        
        # making target for discriminator, tell him which images are real and which are fake
        y_real = np.ones(shape=(num_images, 1))
        y_fake = np.zeros(shape=(num_images, 1))
        # y_real_and_fake = [2*n_images, 1]
        y_real_and_fake = np.concatenate((y_real, y_fake), axis=0)
        
        # compute loss for discriminator and perform backward pass
        loss = self.loss(y_real_and_fake, disc_pred)
        # output_error = [2*n_images, 1]
        output_error = self.loss_derivative(y_real_and_fake, disc_pred)
        self.discriminator.backward(output_error)
        return loss

    def gen_step(self, num_images):
        """
        One learning step for generator
        """
        # create fake images
        # noise = [n_images, 100] (100 is chosen noise size)
        noise = np.random.randn(num_images, 100)
        # fake_images = [n_images, 28*28]
        fake_images = self.generator(noise)
        
        # discriminator makes its predictions
        # disc_pred_fake = [n_images, 1]
        disc_pred_fake = self.discriminator(fake_images)
        
        # set fake images as real for generator training
        # y_fake = [n_images, 1]
        y_fake = np.ones(shape=(num_images, 1)) # we want to trick discriminator
        loss = self.loss(y_fake, disc_pred_fake)
        
        # output_error = [n_images, 1]
        output_error = self.loss_derivative(y_fake, disc_pred_fake)
        
        # output_error = [n_images, 28*28]
        output_error = self.discriminator.backward(output_error, gen_phase=True) # disc doesn't learn
        self.generator.backward(output_error)
        return loss
        
    def fit(self, x, batch_size, n_epochs, echo=True, save_period=100, save_path='./'):
        # saving initial model
        self.save_model(save_path)
        self.save_imgs(0, save_path)
        
        batch_size = batch_size or len(x)
        amount_of_batches = np.ceil(len(x) / batch_size).astype(int)

        try:
            for epoch in range(n_epochs):
                start_time = time.time()
                
                # it is good to do permutations in each epoch
                idxs = np.random.permutation(len(x))
                disc_error = 0
                gen_error = 0

                for batch_idx in range(amount_of_batches):
                    batch_slice = idxs[batch_idx * batch_size:batch_idx * batch_size + batch_size]
                    
                    # real_images = [batch_size or less (if batch is last in dataset), 28*28]
                    real_images = x[batch_slice]
                    num_images = real_images.shape[0]
                    
                    # learning phase
                    disc_loss = self.disc_step(real_images, num_images)
                    gen_loss = self.gen_step(num_images)
                    
                    disc_error += disc_loss
                    gen_error += gen_loss

                if echo:
                    print('*' * 30)
                    print(f"""Epoch {epoch} Time: {time.time()-start_time:.3f}
                    disc_loss:{disc_error / amount_of_batches}
                    gen_loss:{gen_error / amount_of_batches}""")
                    
                if save_period is not None and (epoch+1) % save_period == 0:
                    self.save_model(save_path)
                    self.save_imgs(epoch+1, save_path)
                    
        except KeyboardInterrupt:
            print('Interrupted by user')
            return self

        return self
    
    def save_model(self, path):
        with open(f'{path}/gan.pkl', 'wb') as f:
            pickle.dump(self, f)
            
    def save_imgs(self, epoch, path):
        noise = np.random.randn(10, 100)

        generated = self.generator(noise).reshape(-1, 28, 28)
        generated = 0.5 * generated + 0.5

        fig, ax = plt.subplots(2, 5, figsize=(25, 10))
        cnt = 0
        for i in range(2):
            for j in range(5):
                ax[i,j].imshow(generated[cnt,:,:], cmap='viridis')
                ax[i,j].axis('off')
                cnt += 1
        fig.savefig(f"{path}/mnist_{epoch}.png", dpi=199)
        plt.close()

In [19]:
optimizer = ADAM(learning_rate=0.0002, beta1=0.5)

generator = MnistGenerator()
discriminator = MnistDiscriminator()

generator.set_optimizer(optimizer)
discriminator.set_optimizer(optimizer)

gan = MnistGAN(generator, discriminator)

In [20]:
gan.fit(train_data, batch_size=2048, n_epochs=1000, save_period=10, save_path='./models/gan')

******************************
Epoch 0 Time: 20.975
                    disc_loss:0.9525861633319297
                    gen_loss:0.8509840700962551
******************************
Epoch 1 Time: 21.147
                    disc_loss:0.6913206833275453
                    gen_loss:0.6836755111314303
******************************
Epoch 2 Time: 20.744
                    disc_loss:0.6605337285352725
                    gen_loss:0.8521307226433646
******************************
Epoch 3 Time: 20.896
                    disc_loss:0.676342492744422
                    gen_loss:0.9361226731628595
******************************
Epoch 4 Time: 21.478
                    disc_loss:0.6941911520842641
                    gen_loss:0.7490055717523966
******************************
Epoch 5 Time: 20.868
                    disc_loss:0.7125136301090992
                    gen_loss:0.8085544598606557
******************************
Epoch 6 Time: 21.366
                    disc_loss:0.6766183479231866
      

******************************
Epoch 55 Time: 19.543
                    disc_loss:0.5642621099162594
                    gen_loss:1.252414915670838
******************************
Epoch 56 Time: 19.462
                    disc_loss:0.4665103101178718
                    gen_loss:1.3532406589170864
******************************
Epoch 57 Time: 19.532
                    disc_loss:0.5715102372571614
                    gen_loss:1.3190011056264583
******************************
Epoch 58 Time: 19.589
                    disc_loss:0.47115288399591604
                    gen_loss:1.4588062846965137
******************************
Epoch 59 Time: 19.640
                    disc_loss:0.46301601861743485
                    gen_loss:1.4176761471308068
******************************
Epoch 60 Time: 20.492
                    disc_loss:0.504198865777188
                    gen_loss:1.6500796112833862
******************************
Epoch 61 Time: 20.138
                    disc_loss:0.486151690884477

******************************
Epoch 110 Time: 20.886
                    disc_loss:0.4008674838309089
                    gen_loss:1.7736196723222897
******************************
Epoch 111 Time: 20.501
                    disc_loss:0.4651263350987829
                    gen_loss:1.6901591018378206
******************************
Epoch 112 Time: 19.579
                    disc_loss:0.4259479785639364
                    gen_loss:1.6579806715714691
******************************
Epoch 113 Time: 19.547
                    disc_loss:0.39046567086801204
                    gen_loss:1.6602529065217417
******************************
Epoch 114 Time: 19.796
                    disc_loss:0.4833672890812179
                    gen_loss:1.7520220116318106
******************************
Epoch 115 Time: 19.749
                    disc_loss:0.3870437618592121
                    gen_loss:1.5378575418558982
******************************
Epoch 116 Time: 19.574
                    disc_loss:0.3934404

******************************
Epoch 165 Time: 19.564
                    disc_loss:0.5281280668880802
                    gen_loss:1.3423592773717616
******************************
Epoch 166 Time: 19.682
                    disc_loss:0.48705651764695834
                    gen_loss:1.2996764932859393
******************************
Epoch 167 Time: 19.742
                    disc_loss:0.5264348433908418
                    gen_loss:1.3409911930263398
******************************
Epoch 168 Time: 19.563
                    disc_loss:0.49562909437218583
                    gen_loss:1.2751573029578238
******************************
Epoch 169 Time: 19.671
                    disc_loss:0.5090009920402894
                    gen_loss:1.324819544183253
******************************
Epoch 170 Time: 20.620
                    disc_loss:0.4977839551393547
                    gen_loss:1.3332158066059105
******************************
Epoch 171 Time: 20.374
                    disc_loss:0.4994431

******************************
Epoch 220 Time: 20.415
                    disc_loss:0.5248496693680313
                    gen_loss:1.252273982712948
******************************
Epoch 221 Time: 20.646
                    disc_loss:0.5227923988672613
                    gen_loss:1.279928544180483
******************************
Epoch 222 Time: 19.754
                    disc_loss:0.5394602361435131
                    gen_loss:1.294997538084169
******************************
Epoch 223 Time: 19.995
                    disc_loss:0.525530029910234
                    gen_loss:1.268335697063563
******************************
Epoch 224 Time: 19.734
                    disc_loss:0.5535250649141764
                    gen_loss:1.3325535035213654
******************************
Epoch 225 Time: 19.656
                    disc_loss:0.5172502258175695
                    gen_loss:1.278189263131394
******************************
Epoch 226 Time: 19.346
                    disc_loss:0.52092913242404

******************************
Epoch 275 Time: 19.657
                    disc_loss:0.5292631584681486
                    gen_loss:1.3469472283460466
******************************
Epoch 276 Time: 19.356
                    disc_loss:0.5574222217149424
                    gen_loss:1.3184183807852978
******************************
Epoch 277 Time: 19.743
                    disc_loss:0.5356733725584198
                    gen_loss:1.2834444237365215
******************************
Epoch 278 Time: 19.651
                    disc_loss:0.5253804971686179
                    gen_loss:1.2692951492255378
******************************
Epoch 279 Time: 19.800
                    disc_loss:0.5212511207934913
                    gen_loss:1.2430805464931847
******************************
Epoch 280 Time: 20.549
                    disc_loss:0.5303134033682723
                    gen_loss:1.2570889690718656
******************************
Epoch 281 Time: 20.565
                    disc_loss:0.52776045

******************************
Epoch 330 Time: 20.635
                    disc_loss:0.5334172890720177
                    gen_loss:1.336010134200736
******************************
Epoch 331 Time: 20.211
                    disc_loss:0.5195824602003409
                    gen_loss:1.341248059957853
******************************
Epoch 332 Time: 19.577
                    disc_loss:0.5164708091954635
                    gen_loss:1.324806688763405
******************************
Epoch 333 Time: 19.528
                    disc_loss:0.5229987467800414
                    gen_loss:1.312381910367657
******************************
Epoch 334 Time: 19.582
                    disc_loss:0.5216810696739707
                    gen_loss:1.3287385491571473
******************************
Epoch 335 Time: 19.920
                    disc_loss:0.5166365615200962
                    gen_loss:1.286030435350373
******************************
Epoch 336 Time: 19.917
                    disc_loss:0.5218778033143

******************************
Epoch 385 Time: 19.821
                    disc_loss:0.5048948907978643
                    gen_loss:1.345029196796697
******************************
Epoch 386 Time: 19.982
                    disc_loss:0.507154232089648
                    gen_loss:1.36723962789272
******************************
Epoch 387 Time: 19.815
                    disc_loss:0.5181234382055763
                    gen_loss:1.403116103043055
******************************
Epoch 388 Time: 20.189
                    disc_loss:0.5069408521388058
                    gen_loss:1.3843355213201225
******************************
Epoch 389 Time: 19.970
                    disc_loss:0.5013859555136002
                    gen_loss:1.3930112809474697
******************************
Epoch 390 Time: 20.784
                    disc_loss:0.49538237981522165
                    gen_loss:1.3998231496204936
******************************
Epoch 391 Time: 20.645
                    disc_loss:0.505496531399

******************************
Epoch 440 Time: 21.053
                    disc_loss:0.4760677147931589
                    gen_loss:1.4453244407755834
******************************
Epoch 441 Time: 20.549
                    disc_loss:0.5008545845841843
                    gen_loss:1.4526579954431982
******************************
Epoch 442 Time: 19.953
                    disc_loss:0.48076296734349067
                    gen_loss:1.487816461113093
******************************
Epoch 443 Time: 19.979
                    disc_loss:0.4812606171455397
                    gen_loss:1.4416134256948983
******************************
Epoch 444 Time: 20.014
                    disc_loss:0.48171960922272866
                    gen_loss:1.4372918921142814
******************************
Epoch 445 Time: 19.723
                    disc_loss:0.47265659175047614
                    gen_loss:1.4251249639323411
******************************
Epoch 446 Time: 20.170
                    disc_loss:0.496472

******************************
Epoch 495 Time: 20.178
                    disc_loss:0.4697846415729225
                    gen_loss:1.4952721842872736
******************************
Epoch 496 Time: 19.929
                    disc_loss:0.4621164876246903
                    gen_loss:1.506110831774709
******************************
Epoch 497 Time: 19.820
                    disc_loss:0.4732362215604345
                    gen_loss:1.5159363541569941
******************************
Epoch 498 Time: 20.170
                    disc_loss:0.486973692680159
                    gen_loss:1.5327807643232636
******************************
Epoch 499 Time: 20.057
                    disc_loss:0.46989748199900816
                    gen_loss:1.5384639595862115
******************************
Epoch 500 Time: 20.903
                    disc_loss:0.4800950630383466
                    gen_loss:1.5467214970287382
******************************
Epoch 501 Time: 20.966
                    disc_loss:0.464005189

******************************
Epoch 550 Time: 20.759
                    disc_loss:0.4509181307900915
                    gen_loss:1.5430050219319793
******************************
Epoch 551 Time: 20.437
                    disc_loss:0.45941859050051165
                    gen_loss:1.5745735618154575
******************************
Epoch 552 Time: 20.035
                    disc_loss:0.4579922246982361
                    gen_loss:1.5404804100849223
******************************
Epoch 553 Time: 20.315
                    disc_loss:0.45817290279627754
                    gen_loss:1.553662173506296
******************************
Epoch 554 Time: 19.770
                    disc_loss:0.4507570770471999
                    gen_loss:1.5760881882249629
******************************
Epoch 555 Time: 20.073
                    disc_loss:0.45876762110532254
                    gen_loss:1.569533474536002
******************************
Epoch 556 Time: 19.447
                    disc_loss:0.4605197

******************************
Epoch 605 Time: 19.729
                    disc_loss:0.4342027570225707
                    gen_loss:1.6121673441151416
******************************
Epoch 606 Time: 19.769
                    disc_loss:0.42044577865864485
                    gen_loss:1.5452643681688119
******************************
Epoch 607 Time: 19.770
                    disc_loss:0.43671039399013023
                    gen_loss:1.6287996387926609
******************************
Epoch 608 Time: 19.652
                    disc_loss:0.42991166930238933
                    gen_loss:1.625460241750336
******************************
Epoch 609 Time: 19.703
                    disc_loss:0.4599576962059156
                    gen_loss:1.649569435454935
******************************
Epoch 610 Time: 20.688
                    disc_loss:0.43883624108645497
                    gen_loss:1.6118029785090766
******************************
Epoch 611 Time: 20.315
                    disc_loss:0.434196

******************************
Epoch 660 Time: 20.969
                    disc_loss:0.42142329976376863
                    gen_loss:1.6415253291959173
******************************
Epoch 661 Time: 20.021
                    disc_loss:0.42994711268652785
                    gen_loss:1.681838976534738
******************************
Epoch 662 Time: 19.610
                    disc_loss:0.4122846715933059
                    gen_loss:1.6706168516289264
******************************
Epoch 663 Time: 19.680
                    disc_loss:0.42943241427392714
                    gen_loss:1.6872069111533374
******************************
Epoch 664 Time: 19.637
                    disc_loss:0.42322325191108123
                    gen_loss:1.6598250970628636
******************************
Epoch 665 Time: 20.753
                    disc_loss:0.4243229850853265
                    gen_loss:1.675256561542428
******************************
Epoch 666 Time: 20.630
                    disc_loss:0.721633

******************************
Epoch 715 Time: 19.589
                    disc_loss:0.408328150330422
                    gen_loss:1.7424332186141023
******************************
Epoch 716 Time: 19.820
                    disc_loss:0.4117739341274368
                    gen_loss:1.7860300151983668
******************************
Epoch 717 Time: 19.911
                    disc_loss:0.41885459155151766
                    gen_loss:1.7649418157047547
******************************
Epoch 718 Time: 19.983
                    disc_loss:0.3925198464037533
                    gen_loss:1.724270111013045
******************************
Epoch 719 Time: 19.815
                    disc_loss:0.4331877565744295
                    gen_loss:1.855936184366653
******************************
Epoch 720 Time: 20.826
                    disc_loss:1.0630699933340133
                    gen_loss:2.180415996384505
******************************
Epoch 721 Time: 20.500
                    disc_loss:0.52993234398

******************************
Epoch 770 Time: 21.668
                    disc_loss:0.3825765212213212
                    gen_loss:1.779055385497929
******************************
Epoch 771 Time: 21.106
                    disc_loss:0.39335115059449055
                    gen_loss:1.767118576282534
******************************
Epoch 772 Time: 19.873
                    disc_loss:0.38688043480194184
                    gen_loss:1.7971729635678917
******************************
Epoch 773 Time: 20.137
                    disc_loss:0.38670785809475655
                    gen_loss:1.7864542730248198
******************************
Epoch 774 Time: 19.900
                    disc_loss:0.38855464979127213
                    gen_loss:1.7869614707357173
******************************
Epoch 775 Time: 20.347
                    disc_loss:0.3945108355035944
                    gen_loss:1.8079925942751953
******************************
Epoch 776 Time: 20.083
                    disc_loss:0.396064

******************************
Epoch 825 Time: 19.512
                    disc_loss:0.3659694712346205
                    gen_loss:1.8051008532491597
******************************
Epoch 826 Time: 19.686
                    disc_loss:0.3738619995245405
                    gen_loss:1.8545748767443666
******************************
Epoch 827 Time: 19.808
                    disc_loss:0.3706522114985007
                    gen_loss:1.8426379121288246
******************************
Epoch 828 Time: 19.707
                    disc_loss:0.36763883982845674
                    gen_loss:1.8243732356478906
******************************
Epoch 829 Time: 19.792
                    disc_loss:0.3737409271480178
                    gen_loss:1.849949131674037
******************************
Epoch 830 Time: 20.764
                    disc_loss:0.3752308869369664
                    gen_loss:1.8690053189714544
******************************
Epoch 831 Time: 20.541
                    disc_loss:0.38041998

******************************
Epoch 880 Time: 20.757
                    disc_loss:0.37068859572646984
                    gen_loss:1.9097346397418957
******************************
Epoch 881 Time: 19.729
                    disc_loss:0.35266146618043975
                    gen_loss:1.9673586025853753
******************************
Epoch 882 Time: 19.654
                    disc_loss:0.3645503660044822
                    gen_loss:1.9635157697258339
******************************
Epoch 883 Time: 19.605
                    disc_loss:0.36998811717163
                    gen_loss:1.9540979730609993
******************************
Epoch 884 Time: 19.431
                    disc_loss:0.37444680854076884
                    gen_loss:1.9320252670779177
******************************
Epoch 885 Time: 19.816
                    disc_loss:0.3638098283009853
                    gen_loss:1.9616235772596577
******************************
Epoch 886 Time: 21.001
                    disc_loss:0.3672211

******************************
Epoch 935 Time: 19.876
                    disc_loss:0.3671163986400781
                    gen_loss:1.845291275718326
******************************
Epoch 936 Time: 19.678
                    disc_loss:0.36058294931080975
                    gen_loss:1.8576163259911536
******************************
Epoch 937 Time: 20.158
                    disc_loss:0.3357738209418278
                    gen_loss:1.8865327817199617
******************************
Epoch 938 Time: 19.563
                    disc_loss:0.33301584375383464
                    gen_loss:1.8843453942328043
******************************
Epoch 939 Time: 19.877
                    disc_loss:0.33954025003605076
                    gen_loss:1.9284356092891535
******************************
Epoch 940 Time: 20.707
                    disc_loss:0.3330927826698964
                    gen_loss:1.9098396922515075
******************************
Epoch 941 Time: 20.206
                    disc_loss:0.333823

******************************
Epoch 990 Time: 21.370
                    disc_loss:0.38453175679825546
                    gen_loss:2.1082809106518647
******************************
Epoch 991 Time: 20.434
                    disc_loss:0.3366199568031604
                    gen_loss:2.0511951925329757
******************************
Epoch 992 Time: 19.794
                    disc_loss:0.3269109094983551
                    gen_loss:2.043931849141099
******************************
Epoch 993 Time: 19.598
                    disc_loss:0.3218386265927546
                    gen_loss:2.060815985355429
******************************
Epoch 994 Time: 19.690
                    disc_loss:0.33773091573510144
                    gen_loss:2.107911435339888
******************************
Epoch 995 Time: 19.603
                    disc_loss:0.34322815806925044
                    gen_loss:2.039864572443089
******************************
Epoch 996 Time: 19.632
                    disc_loss:0.330771902

<__main__.MnistGAN at 0x1be235673c8>

In [28]:
# create gif

import glob
import contextlib
from PIL import Image

# filepaths
fp_in = "./models/gan/mnist_*.png"
fp_out = "./models/gan/mnist.gif"

# use exit stack to automatically close opened images
with contextlib.ExitStack() as stack:

    # lazily load images
    imgs = (stack.enter_context(Image.open(f))
            for f in sorted(glob.glob(fp_in)))

    # extract  first image from iterator
    img = next(imgs)

    # https://pillow.readthedocs.io/en/stable/handbook/image-file-formats.html#gif
    img.save(fp=fp_out, format='GIF', append_images=imgs,
             save_all=True, duration=150, loop=0)

In [23]:
from IPython.display import HTML
with open('./style.css') as f:
    style = f.read()
HTML(style)