<a href="https://colab.research.google.com/github/Machine-Learning-Tokyo/Intro-to-GANs/blob/master/more_advanced/Fashion_DCGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Fashion DCGAN

### Imports

In [0]:
from keras.models import Model
from keras.layers import Input, Dense, BatchNormalization, Reshape, Flatten
from keras.layers import UpSampling2D, Conv2D
from keras.layers.advanced_activations import LeakyReLU
from keras.initializers import RandomNormal
from keras.datasets import fashion_mnist
from keras.optimizers import Adam

import numpy as np
import matplotlib.pyplot as plt
from IPython.display import Image

### Function to build the generator

In [0]:
def build_generator(noise_size, img_shape):
  """Build the DCGAN generator: a noise vector -> an image in [-1, 1].

  The noise is projected to a 4x4x512 tensor and then upsampled three times
  (4 -> 8 -> 16 -> 32) through 5x5 convolutions with 256 and 128 filters.
  Each stage uses batch normalization. A final tanh convolution emits
  img_shape[-1] channels.

  Args:
    noise_size: length of the input noise vector.
    img_shape: target image shape; only its last entry (channels) is used,
      since the spatial size is fixed at 32x32 by the three upsamplings.

  Returns:
    An uncompiled keras Model.
  """
  base_filters = 512
  kernel = (5, 5)
  init = RandomNormal(0, 0.02)

  z = Input((noise_size,))

  h = Dense(4 * 4 * base_filters, activation='relu', kernel_initializer=init)(z)
  h = Reshape((4, 4, base_filters))(h)  # -> 4x4
  h = BatchNormalization()(h)
  h = UpSampling2D()(h)  # -> 8x8

  # Two identical conv/BN/upsample stages: 8x8 -> 16x16 -> 32x32,
  # with the filter count halved (256) and then quartered (128).
  for divisor in (2, 4):
    h = Conv2D(base_filters // divisor, kernel, padding='same', activation='relu',
               kernel_initializer=init)(h)
    h = BatchNormalization()(h)
    h = UpSampling2D()(h)

  out = Conv2D(img_shape[-1], kernel, padding='same', activation='tanh',
               kernel_initializer=init)(h)

  return Model(z, out)

### Function to build the discriminator

In [0]:
def build_discriminator(img_shape):
  """Build the DCGAN discriminator: an image -> probability it is real.

  Three stride-2 5x5 convolutions (128, 256, then 512 filters) are each
  followed by batch normalization and LeakyReLU(0.2). Every stage halves the
  spatial size (32x32 -> 16x16 -> 8x8 -> 4x4 for the 32x32 images used in
  this notebook). The features are flattened and mapped to a single
  sigmoid output.

  Args:
    img_shape: (height, width, channels) of the input images.

  Returns:
    An uncompiled keras Model mapping an image batch to shape (batch, 1).
  """
  filters = 512
  k_size = 5, 5
  k_init = RandomNormal(0, 0.02)

  img = Input(img_shape)  # 32, 32

  # Each stage must consume the previous stage's output `x`, not `img`.
  # Otherwise the earlier stages are dropped from the graph and only the
  # last conv (on a 16x16 map) survives.
  x = Conv2D(filters // 4, k_size, strides=(2, 2), padding='same', kernel_initializer=k_init)(img)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)  # 16, 16

  x = Conv2D(filters // 2, k_size, strides=(2, 2), padding='same', kernel_initializer=k_init)(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)  # 8, 8

  x = Conv2D(filters, k_size, strides=(2, 2), padding='same', kernel_initializer=k_init)(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)  # 4, 4

  x = Flatten()(x)
  validity = Dense(1, activation='sigmoid')(x)

  discriminator = Model(img, validity)
  return discriminator

### Function to compile the models

In [0]:
def get_compiled_models(generator, discriminator, noise_size):
  """Compile the discriminator and the stacked generator->discriminator model.

  Args:
    generator: keras Model mapping noise of length `noise_size` to images.
    discriminator: keras Model mapping images to a real/fake probability.
    noise_size: length of the generator's noise input.

  Returns:
    (generator, discriminator, combined):
    - discriminator is compiled with binary cross-entropy and accuracy.
    - combined maps noise -> discriminator(generator(noise)) and is compiled
      with binary cross-entropy; only the generator's weights train through it.
  """
  # One optimizer per model: the two models update different variables, and a
  # shared Adam instance would share its internal state between them (newer
  # Keras optimizers refuse to be applied to a second set of variables).
  d_optimizer = Adam(0.0002, 0.5)
  g_optimizer = Adam(0.0002, 0.5)

  discriminator.compile(d_optimizer, loss='binary_crossentropy', metrics=['accuracy'])
  # Freeze only after compiling the discriminator itself. Its own
  # train_on_batch still updates its weights, while `combined` (compiled
  # below) treats the discriminator as frozen and updates only the generator.
  discriminator.trainable = False

  noise = Input((noise_size,))
  img = generator(noise)
  validity = discriminator(img)
  combined = Model(noise, validity)

  combined.compile(g_optimizer, loss='binary_crossentropy')

  return generator, discriminator, combined

### Function to sample and save generated images

In [0]:
def sample_imgs(generator, noise_size, step, plot_img=True, cond=False, num_classes=10,
                save_dir='/content/images'):
  """Generate a grid of sample images and save it as '{save_dir}/{step}.png'.

  The noise comes from a fixed seed (0), so grids saved at different training
  steps show the same latent points. Outputs are mapped from [-1, 1] (the
  generator's tanh range) to [0, 1] for display.

  Args:
    generator: keras Model. Its predict output must have (N, H, W) as the
      first three axes, optionally followed by a channel axis.
    noise_size: length of each noise vector.
    step: training step; used as the PNG file name.
    plot_img: if True, also show the figure inline.
    cond: if True, the generator is conditional. It is called with
      [noise, one_hot_labels]; row i shows class i and column j reuses the
      same noise vector across rows.
    num_classes: number of grid rows (one per class when `cond` is True).
    save_dir: existing directory the PNG is written to.
  """
  # A private RNG seeded with 0 yields exactly the same draws as
  # np.random.seed(0) + np.random.normal, but never touches the global NumPy
  # RNG. That state therefore can't be left stuck at seed 0 if plotting raises.
  rng = np.random.RandomState(0)

  r, c = num_classes, 10
  if cond:
    # Only needed for conditional generators; not in the file-level imports.
    from keras.utils import to_categorical

    noise = rng.normal(0, 1, (c, noise_size))
    noise = np.tile(noise, (r, 1))

    sampled_labels = np.arange(r).reshape(-1, 1)
    sampled_labels = to_categorical(sampled_labels, r)
    sampled_labels = np.repeat(sampled_labels, c, axis=0)

    imgs = generator.predict([noise, sampled_labels])
  else:
    noise = rng.normal(0, 1, (r * c, noise_size))
    imgs = generator.predict_on_batch(noise)

  imgs = imgs / 2 + 0.5
  # Always 5-D: (rows, cols, H, W, channels).
  imgs = np.reshape(imgs, [r, c, imgs.shape[1], imgs.shape[2], -1])

  # squeeze=False keeps `axs` 2-D even when r or c is 1.
  fig, axs = plt.subplots(r, c, figsize=(1 * c, 1 * r), squeeze=False)

  for i in range(r):
    for j in range(c):
      # Single channel: drop the channel axis so imshow uses the gray cmap.
      # Several channels: show the full (H, W, C) image.
      img = imgs[i, j, :, :, 0] if imgs.shape[-1] == 1 else imgs[i, j]
      axs[i, j].imshow(img, cmap='gray')
      axs[i, j].axis('off')
  plt.subplots_adjust(wspace=0.1, hspace=0.1)
  fig.savefig(f'{save_dir}/{step}.png')
  if plot_img:
    plt.show()
  plt.close(fig)

### Function to train the models

In [0]:
def _load_fashion_mnist(img_shape):
  """Return all 70,000 Fashion-MNIST images as an (N, H, W, 1) float32 array.

  Pixels are scaled from [0, 255] to [-1, 1] to match the generator's tanh
  output. The 28x28 images are centred on an img_shape[0] x img_shape[1]
  canvas padded with -1, the scaled value of the dataset's 0 background.

  Raises:
    ValueError: if img_shape is smaller than the 28x28 source images.
  """
  (X_train, _), (X_val, _) = fashion_mnist.load_data()
  # float32 halves memory versus NumPy's float64 default; float32 is also
  # Keras' default floatx.
  imgs = np.concatenate((X_train, X_val)).astype(np.float32) / 127.5 - 1

  pad_h = img_shape[0] - imgs.shape[1]
  pad_w = img_shape[1] - imgs.shape[2]
  if pad_h < 0 or pad_w < 0:
    raise ValueError(f'img_shape {img_shape} is smaller than the '
                     f'{imgs.shape[1]}x{imgs.shape[2]} Fashion-MNIST images')
  # For 32x32 this is (2, 2) on both axes; odd differences put the extra
  # row/column at the bottom/right.
  imgs = np.pad(imgs,
                ((0, 0),
                 (pad_h // 2, pad_h - pad_h // 2),
                 (pad_w // 2, pad_w - pad_w // 2)),
                'constant', constant_values=-1)
  return np.expand_dims(imgs, axis=-1)


def train(models, noise_size, img_shape, batch_size, steps):
  """Train the GAN on Fashion-MNIST for `steps` alternating updates.

  Each step trains the discriminator on one real batch and one generated
  batch. It then trains the generator through `combined`, labelling fresh
  fakes as real. Progress is printed every 50 steps and a sample grid is
  saved (via sample_imgs) every 200 steps.

  Args:
    models: (generator, discriminator, combined), as returned by
      get_compiled_models.
    noise_size: length of the generator's noise input.
    img_shape: (height, width, channels); real images are padded to this
      height and width.
    batch_size: images per batch for each update.
    steps: number of training steps.
  """
  generator, discriminator, combined = models
  real_data = _load_fashion_mnist(img_shape)

  # Target labels never change, so build them once.
  real_labels = np.ones(batch_size)
  fake_labels = np.zeros(batch_size)

  for step in range(1, steps + 1):
    # train discriminator on a random real batch and a freshly generated one
    inds = np.random.randint(0, real_data.shape[0], batch_size)
    real_imgs = real_data[inds]

    noise = np.random.normal(0, 1, (batch_size, noise_size))
    gen_imgs = generator.predict(noise)

    real_loss = discriminator.train_on_batch(real_imgs, real_labels)
    fake_loss = discriminator.train_on_batch(gen_imgs, fake_labels)
    disc_loss = np.add(real_loss, fake_loss) / 2  # [loss, accuracy]

    # train generator: push the (frozen) discriminator to call fakes real
    noise = np.random.normal(0, 1, (batch_size, noise_size))
    gen_loss = combined.train_on_batch(noise, real_labels)

    # print progress
    if step % 50 == 0:
      print('step: %d, D_loss: %f, D_accuracy: %.2f%%, G_loss: %f' % (step, disc_loss[0],
                                                                      disc_loss[1] * 100, gen_loss))

    # save samples
    if step % 200 == 0:
      sample_imgs(generator, noise_size, step)

### Define hyperparameters

In [0]:
%rm -r /content/images
%mkdir /content/images
noise_size = 100
img_shape = 32, 32, 1
batch_size = 64
steps = 10000

rm: cannot remove '/content/images': No such file or directory


### Generate the models

In [0]:
# Build both networks, then compile them together with the stacked
# generator -> discriminator model used to train the generator.
generator = build_generator(noise_size=noise_size, img_shape=img_shape)
discriminator = build_discriminator(img_shape=img_shape)
compiled_models = get_compiled_models(
    generator=generator, discriminator=discriminator, noise_size=noise_size)

### Train the models

In [0]:
# Logs every 50 steps and saves a sample grid to /content/images every 200.
train(models=compiled_models, noise_size=noise_size, img_shape=img_shape,
      batch_size=batch_size, steps=steps)

## Plot results

In [0]:
%%capture
# Fetch the helper that defines `plot_results` from a GitHub gist and import
# it. %%capture (which must stay the first line) hides this cell's output.
# import_ipynb is a package for importing .ipynb notebooks as modules.
!pip install import_ipynb
%cd /content
# Remove an earlier clone so `git clone` below can recreate the folder.
%rm -r /content/0a16ae419d9eba160ddb4f48862fb9e2
!git clone https://gist.github.com/dkatsios/0a16ae419d9eba160ddb4f48862fb9e2.git
# Work from inside the clone so AnimationDisplay is importable from the
# current directory; return to /content afterwards.
%cd /content/0a16ae419d9eba160ddb4f48862fb9e2
import import_ipynb
from IPython.display import HTML
# NOTE(review): presumably AnimationDisplay is a notebook in the gist, loaded
# through import_ipynb; confirm against the gist contents.
from AnimationDisplay import plot_results
%cd /content

In [0]:
# Animate the sample grids that `train` saved every 200 steps.
# NOTE(review): plot_results comes from the external gist; assumed to return
# an object with .to_jshtml() (e.g. a matplotlib animation). Confirm.
path = '/content/images/{}.png'
iterator = range(200, steps + 1, 200)
HTML(
    plot_results(path, iterator).to_jshtml()
)

## Save and download the generator

In [0]:
# Save the trained generator to an HDF5 file, then trigger a browser download
# of that file through Colab's `files` helper.
gen_path = '/content/gen.h5'
generator.save(gen_path)
# Colab-only import: outside Colab this line fails after the model is saved.
from google.colab import files
files.download(gen_path)