<a href="https://colab.research.google.com/github/GarlandZhang/gans_in_action_notes/blob/master/sgan_impl.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%matplotlib inline

In [None]:
import matplotlib.pyplot as plt
import numpy as np

In [None]:
from keras import backend as K
from keras.datasets import mnist
from keras.layers import Activation, BatchNormalization, Concatenate, Dense, Dropout, Flatten, Input, Lambda, Reshape
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import Conv2D, Conv2DTranspose
from keras.models import Model, Sequential
from keras.optimizers import Adam
from keras.utils import to_categorical

In [None]:
class Dataset:
  """MNIST wrapper that splits the training set into labeled / unlabeled parts.

  The first ``num_labeled`` training examples are treated as labeled; every
  remaining training example is treated as unlabeled (its label is never used
  by the batch helpers).

  Args:
    num_labeled: number of leading training examples whose labels are used.
  """

  def __init__(self, num_labeled):
    self.num_labeled = num_labeled
    (self.x_train, self.y_train), (self.x_test, self.y_test) = mnist.load_data()

    self.x_train = self.preprocess_imgs(self.x_train)
    self.y_train = self.preprocess_labels(self.y_train)

    self.x_test = self.preprocess_imgs(self.x_test)
    self.y_test = self.preprocess_labels(self.y_test)

  def preprocess_imgs(self, x):
    """Rescale pixels with x / 127.5 - 1 and append a trailing channel axis.

    For 8-bit pixels (0..255) this maps values to [-1, 1], matching a tanh
    generator output. Input is expected to be (N, H, W); output is (N, H, W, 1)
    float32.
    """
    x = x.astype(np.float32) / 127.5 - 1
    x = np.expand_dims(x, axis=3)
    return x

  def preprocess_labels(self, y):
    """Reshape a label vector into an (N, 1) column."""
    return y.reshape(-1, 1)

  def batch_labeled(self, batch_size):
    """Sample (with replacement) a batch of labeled images and their labels."""
    idx = np.random.randint(0, self.num_labeled, batch_size)
    imgs = self.x_train[idx]
    labels = self.y_train[idx]
    return imgs, labels

  def batch_unlabeled(self, batch_size):
    """Sample (with replacement) a batch of images from the unlabeled part."""
    idx = np.random.randint(self.num_labeled, self.x_train.shape[0], batch_size)
    imgs = self.x_train[idx]
    return imgs

  def training_set(self):
    """Return copies of the labeled subset (first ``num_labeled`` examples).

    self.x_train is the full MNIST training set; this is the subset whose
    labels the semi-supervised model is allowed to see.
    """
    x_train = self.x_train[range(self.num_labeled)]
    y_train = self.y_train[range(self.num_labeled)]
    return x_train, y_train

  def test_set(self):
    """Return the preprocessed MNIST test images and labels."""
    # Bug fix: previously returned bare `x_test, y_test`, which are not defined
    # in this scope (NameError, or silently unrelated globals).
    return self.x_test, self.y_test

In [None]:
def build_generator(z_dim):
  """Build the generator that maps a latent vector to a 28 x 28 x 1 image.

  Shapes: z_dim -> 7x7x256 -> 14x14x128 -> 14x14x64 -> 28x28x1, with a tanh
  output so pixel values lie in [-1, 1].

  Args:
    z_dim: length of the latent input vector.

  Returns:
    An uncompiled keras Sequential model.
  """
  layer_stack = [
      Dense(256 * 7 * 7, input_dim=z_dim),
      Reshape((7, 7, 256)),
  ]
  # Two transposed-conv stages, each followed by BatchNorm + LeakyReLU:
  # stride 2 doubles the spatial size (7 -> 14), stride 1 keeps it (14 -> 14).
  for n_filters, step in ((128, 2), (64, 1)):
    layer_stack += [
        Conv2DTranspose(n_filters, kernel_size=3, strides=step, padding='same'),
        BatchNormalization(),
        LeakyReLU(alpha=0.01),
    ]
  # Final upsample to 28 x 28 with a single channel, squashed by tanh.
  layer_stack += [
      Conv2DTranspose(1, kernel_size=3, strides=2, padding='same'),
      Activation('tanh'),
  ]
  return Sequential(layer_stack)

In [None]:
def build_discriminator(img_shape, num_classes):
  """Build the core discriminator network shared by both classifier heads.

  Three stride-2 convolutions with 'same' padding downsample the image
  (28 -> 14 -> 7 -> 4 per side for MNIST). Dropout follows, then a Dense layer
  emits one unnormalized logit per class. No output activation is applied here;
  the supervised (softmax) and unsupervised (real/fake) heads add their own.

  Args:
    img_shape: (rows, cols, channels) of the input images.
    num_classes: number of output logits.

  Returns:
    An uncompiled keras Sequential model.
  """
  model = Sequential([
      # Only the first layer needs input_shape; Sequential infers all later
      # shapes. The original also passed input_shape to the 2nd/3rd Conv2D,
      # where it is unused and misleading.
      Conv2D(32, kernel_size=3, strides=2, input_shape=img_shape, padding='same'),
      LeakyReLU(alpha=0.01),

      Conv2D(64, kernel_size=3, strides=2, padding='same'),
      BatchNormalization(),
      LeakyReLU(alpha=0.01),

      Conv2D(128, kernel_size=3, strides=2, padding='same'),
      BatchNormalization(),
      LeakyReLU(alpha=0.01),
      # "the dropout layer is added after batch normalization and not the other
      # way around; this has shown to have superior performance due to the
      # interplay between the two techniques"
      Dropout(0.5),
      Flatten(),
      Dense(num_classes),
  ])
  return model

In [None]:
def build_discriminator_supervised(discriminator_net):
  """Wrap the shared core discriminator with a softmax over its class logits.

  The returned model reuses ``discriminator_net`` itself (not a copy), so
  training it also updates the core network's weights.
  """
  classifier = Sequential()
  classifier.add(discriminator_net)
  classifier.add(Activation('softmax'))
  return classifier

In [None]:
def build_discriminator_unsupervised(discriminator_net):
  """Wrap the shared core discriminator with a real-vs-fake probability output.

  The 10 class logits x_k from the core network are turned into one binary
  prediction D = Z / (Z + 1), where Z = sum_k exp(x_k). Large logits push D
  towards 1 (real); very negative logits push it towards 0 (fake).

  The returned model reuses ``discriminator_net`` itself (not a copy).
  """
  def predict(x):
    # Z / (Z + 1) == sigmoid(log Z). Computing log Z with logsumexp (which
    # subtracts the max logit internally) avoids exp() overflowing to inf for
    # large logits. The original 1 - 1/(sum(exp(x)) + 1) form risks inf/NaN
    # gradients in that case. The result is the same (batch, 1) shape.
    return K.sigmoid(K.logsumexp(x, axis=-1, keepdims=True))

  model = Sequential([
      discriminator_net,
      Lambda(predict),
  ])

  return model

In [None]:
def build_gan(generator, discriminator):
  """Chain generator -> discriminator into one model used to train the generator.

  Both sub-models are added as-is (shared, not copied).
  """
  combined = Sequential()
  for stage in (generator, discriminator):
    combined.add(stage)
  return combined

In [None]:
# Build and compile every model used for SGAN training.
# NOTE(review): this cell reads img_shape, num_classes and z_dim, which are only
# assigned in a later cell. On a fresh kernel run top-to-bottom, this cell will
# raise NameError unless the hyperparameter cell is executed first.
# The core network is shared: both heads below wrap the same discriminator_net
# object, so weight updates made through either head are seen by the other.
discriminator_net = build_discriminator(img_shape, num_classes)
# Supervised head: softmax over num_classes, trained on labeled examples.
discriminator_supervised = build_discriminator_supervised(discriminator_net)
discriminator_supervised.compile(loss='categorical_crossentropy', metrics=['accuracy'], optimizer=Adam())
# Unsupervised head: single real-vs-fake probability, trained on unlabeled real
# images and generated images.
discriminator_unsupervised = build_discriminator_unsupervised(discriminator_net)
discriminator_unsupervised.compile(loss='binary_crossentropy', metrics=['accuracy'], optimizer=Adam())
generator = build_generator(z_dim)
# Freeze the discriminator for the combined model compiled below; the two
# discriminator models were already compiled above, so they should still train.
# NOTE(review): this relies on Keras applying `trainable` at compile() time.
# Confirm this for the installed Keras version. This pattern is what triggers the
# "Discrepancy between trainable weights and collected trainable" warning seen
# in this notebook's output.
discriminator_unsupervised.trainable = False
gan = build_gan(generator, discriminator_unsupervised)
gan.compile(loss='binary_crossentropy', metrics=['accuracy'], optimizer=Adam())

In [None]:
supervised_losses = []  # supervised classification loss, recorded every sample_interval iterations
iteration_checkpoints = []  # iteration numbers matching supervised_losses entries


def train(iterations, batch_size, sample_interval):
  """Run the SGAN training loop.

  Each iteration:
    1. trains the supervised head on a labeled batch,
    2. trains the unsupervised head on an unlabeled real batch (target 1) and a
       generated batch (target 0),
    3. trains the generator through the combined `gan` model (target 1).

  Uses the notebook-level globals dataset, generator, discriminator_supervised,
  discriminator_unsupervised, gan, num_classes and z_dim.

  Args:
    iterations: number of training iterations.
    batch_size: examples per batch for every train_on_batch call.
    sample_interval: every this many iterations, record the supervised loss
      and print progress.
  """
  real = np.ones((batch_size, 1))
  fake = np.zeros((batch_size, 1))
  for iteration in range(iterations):
    # ---- Discriminator ----
    imgs, labels = dataset.batch_labeled(batch_size)
    labels = to_categorical(labels, num_classes=num_classes)
    imgs_unlabeled = dataset.batch_unlabeled(batch_size)
    z = np.random.normal(0, 1, (batch_size, z_dim))  # latent space vectors
    gen_imgs = generator.predict(z)

    d_loss_supervised, sup_acc = discriminator_supervised.train_on_batch(imgs, labels)
    d_loss_real, _ = discriminator_unsupervised.train_on_batch(imgs_unlabeled, real)
    d_loss_fake, _ = discriminator_unsupervised.train_on_batch(gen_imgs, fake)
    d_loss_unsupervised = 0.5 * np.add(d_loss_real, d_loss_fake)

    # ---- Generator ----
    z = np.random.normal(0, 1, (batch_size, z_dim))
    # gan was compiled with metrics, so this is presumably [loss, accuracy].
    g_loss = gan.train_on_batch(z, real)

    if (iteration + 1) % sample_interval == 0:
      supervised_losses.append(d_loss_supervised)
      iteration_checkpoints.append(iteration + 1)
      # Bug fix: this was a plain string literal (missing the `f` prefix), so
      # the placeholders were printed verbatim instead of the values.
      print(f'{iteration + 1} [D loss supervised: {d_loss_supervised}, acc.: {100 * sup_acc}] [D loss unsupervised: {d_loss_unsupervised}] [G loss: {g_loss}]')

In [None]:
# Only the first `num_labeled` MNIST training images keep their labels; the
# remaining training images are sampled as unlabeled data.
num_labeled = 100
dataset = Dataset(num_labeled=num_labeled)

In [None]:
# Hyperparameters and training launch.
# NOTE(review): the model-building cell reads img_shape, num_classes and z_dim,
# so those assignments need to run before that cell does.

# MNIST image geometry: height, width, channels.
img_rows, img_cols, channels = 28, 28, 1
img_shape = (img_rows, img_cols, channels)

z_dim = 100  # length of the generator's latent input vector
num_classes = 10  # one output per digit class

# Training schedule.
iterations = 8000
batch_size = 32
sample_interval = 800  # record loss and print progress this often

train(iterations, batch_size, sample_interval)

  'Discrepancy between trainable weights and collected trainable'


TypeError: ignored