In [110]:
from google.colab import drive

# Mount Google Drive; later cells write sample grids to a relative
# 'drive/My Drive/...' path, presumably resolved from Colab's /content cwd.
drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [0]:
# importing the requirements
from numpy import expand_dims
from numpy import zeros
from numpy import ones
from numpy.random import randn
from numpy.random import randint
from keras.datasets.cifar10 import load_data
from keras.optimizers import Adam
from keras.models import Model
from keras.layers import Input
from keras.layers import Dense
from keras.layers import Reshape
from keras.layers import Flatten
from keras.layers import Conv2D
from keras.layers import Conv2DTranspose
from keras.layers import LeakyReLU
from keras.layers import Dropout
from keras.layers import Embedding
from keras.layers import Concatenate
from numpy import asarray
from keras.models import load_model
from matplotlib import pyplot
from keras.layers import Lambda
import tensorflow as tf
from keras.callbacks import TensorBoard
from keras.applications.inception_v3 import InceptionV3
from numpy import vstack
from keras.datasets.cifar10 import load_data
import numpy
from numpy import cov
from numpy import trace
from numpy import iscomplexobj
from numpy import asarray
from numpy.random import shuffle
from scipy.linalg import sqrtm
from keras.applications.inception_v3 import preprocess_input
from skimage.transform import resize
from keras.datasets import cifar10
from keras import backend as K
from keras.engine import *
from keras.legacy import interfaces
from keras import activations
from keras import initializers
from keras import regularizers
from keras import constraints
from keras.utils.generic_utils import func_dump
from keras.utils.generic_utils import func_load
from keras.utils.generic_utils import deserialize_keras_object
from keras.utils.generic_utils import has_arg
from keras.utils import conv_utils
from keras.legacy import interfaces
from keras.layers import Dense, Conv1D, Conv2D, Conv3D, Conv2DTranspose, Embedding
import tensorflow as tf

In [0]:
# TensorBoard callback writing to 'log/sagan_log'.
# NOTE(review): per the Keras 2.x docs, `batch_size` and `write_grads` only
# matter for histogram computation, which histogram_freq=0 disables.
tensorboard = TensorBoard(
  log_dir='log/sagan_log',
  write_graph=True,
  histogram_freq=0,
  write_grads=True,
  batch_size=128,
)

In [0]:
# DenseSN and ConvSN2D are spectrally-normalised versions of Keras' Dense and
# Conv2D; the discriminator below is built from them so it trains with
# spectral normalization applied.
class DenseSN(Dense):
    """Dense layer whose kernel is divided by a power-iteration estimate of its spectral norm.

    A non-trainable row vector ``u`` (shape ``(1, units)``) is kept between
    calls. Each call runs one power-iteration step from ``u``, divides the
    kernel by ``sigma = v . W . u^T`` and, unless ``training`` is 0/False,
    stores the new ``u`` back into the layer.
    """
    def build(self, input_shape):
        """Create ``kernel``, optional ``bias`` and the power-iteration vector ``u``."""
        assert len(input_shape) >= 2
        input_dim = input_shape[-1]
        self.kernel = self.add_weight(shape=(input_dim, self.units),
                                      initializer=self.kernel_initializer,
                                      name='kernel',
                                      regularizer=self.kernel_regularizer,
                                      constraint=self.kernel_constraint)
        if self.use_bias:
            self.bias = self.add_weight(shape=(self.units,),
                                        initializer=self.bias_initializer,
                                        name='bias',
                                        regularizer=self.bias_regularizer,
                                        constraint=self.bias_constraint)
        else:
            self.bias = None
        # Power-iteration state: shape (1, units), standard-normal init, and
        # excluded from the optimizer (trainable=False); updated in call().
        self.u = self.add_weight(shape=tuple([1, self.kernel.shape.as_list()[-1]]),
                                 initializer=initializers.RandomNormal(0, 1),
                                 name='sn',
                                 trainable=False)
        # NOTE(review): InputSpec is not imported by name; presumably it comes
        # from `from keras.engine import *` — confirm.
        self.input_spec = InputSpec(min_ndim=2, axes={-1: input_dim})
        self.built = True
        
    def call(self, inputs, training=None):
        """Apply the dense transform with the spectrally-normalised kernel.

        ``training`` of 0/False skips the ``u`` update; any other value,
        including the default None, updates ``u``.
        """
        def _l2normalize(v, eps=1e-12):
            # Divide by the L2 norm of the whole tensor; eps avoids /0.
            return v / (K.sum(v ** 2) ** 0.5 + eps)
        def power_iteration(W, u):
            # One power-iteration step: v = norm(u W^T), u = norm(v W).
            _u = u
            _v = _l2normalize(K.dot(_u, K.transpose(W)))
            _u = _l2normalize(K.dot(_v, W))
            return _u, _v
        W_shape = self.kernel.shape.as_list()
        #Flatten the Tensor to 2-D, keeping the output (last) axis
        W_reshaped = K.reshape(self.kernel, [-1, W_shape[-1]])
        _u, _v = power_iteration(W_reshaped, self.u)
        #Calculate Sigma = v . W . u^T, a (1, 1) estimate of the largest singular value
        sigma=K.dot(_v, W_reshaped)
        sigma=K.dot(sigma, K.transpose(_u))
        #normalize it
        W_bar = W_reshaped / sigma
        #reshape weight tensor back to the kernel's shape
        if training in {0, False}:
            W_bar = K.reshape(W_bar, W_shape)
        else:
            # Tie the assignment of the new `u` to the reshape so it runs
            # whenever the normalised kernel is evaluated.
            with tf.control_dependencies([self.u.assign(_u)]):
                 W_bar = K.reshape(W_bar, W_shape)  
        output = K.dot(inputs, W_bar)
        if self.use_bias:
            output = K.bias_add(output, self.bias, data_format='channels_last')
        if self.activation is not None:
            output = self.activation(output)
        return output 

In [0]:
class ConvSN2D(Conv2D):
    """Conv2D layer whose kernel is divided by a power-iteration estimate of its spectral norm.

    The kernel ``(kh, kw, in_channels, filters)`` is viewed as a 2-D matrix
    ``(kh*kw*in_channels, filters)``. A non-trainable row vector ``u`` of
    shape ``(1, filters)`` is carried between calls and refreshed by one
    power-iteration step per call (unless ``training`` is 0/False).
    """

    def build(self, input_shape):
        """Create ``kernel``, optional ``bias`` and the power-iteration vector ``u``.

        Raises:
            ValueError: if the channel dimension of ``input_shape`` is None.
        """
        if self.data_format == 'channels_first':
            channel_axis = 1
        else:
            channel_axis = -1
        if input_shape[channel_axis] is None:
            raise ValueError('The channel dimension of the inputs '
                             'should be defined. Found `None`.')
        input_dim = input_shape[channel_axis]
        kernel_shape = self.kernel_size + (input_dim, self.filters)

        self.kernel = self.add_weight(shape=kernel_shape,
                                      initializer=self.kernel_initializer,
                                      name='kernel',
                                      regularizer=self.kernel_regularizer,
                                      constraint=self.kernel_constraint)

        if self.use_bias:
            self.bias = self.add_weight(shape=(self.filters,),
                                        initializer=self.bias_initializer,
                                        name='bias',
                                        regularizer=self.bias_regularizer,
                                        constraint=self.bias_constraint)
        else:
            self.bias = None
            
        # Power-iteration state: shape (1, filters), standard-normal init,
        # excluded from the optimizer (trainable=False); updated in call().
        self.u = self.add_weight(shape=tuple([1, self.kernel.shape.as_list()[-1]]),
                         initializer=initializers.RandomNormal(0, 1),
                         name='sn',
                         trainable=False)
        
        # Set input spec.
        # NOTE(review): `self.rank` is inherited from the Keras conv base
        # (presumably 2 here); InputSpec presumably comes from
        # `from keras.engine import *` — confirm.
        self.input_spec = InputSpec(ndim=self.rank + 2,
                                    axes={channel_axis: input_dim})
        self.built = True
    def call(self, inputs, training=None):
        """Convolve ``inputs`` with the spectrally-normalised kernel.

        ``training`` of 0/False skips the ``u`` update; any other value,
        including the default None, updates ``u``.
        """
        def _l2normalize(v, eps=1e-12):
            # Divide by the L2 norm of the whole tensor; eps avoids /0.
            return v / (K.sum(v ** 2) ** 0.5 + eps)
        def power_iteration(W, u):
            #According to the spectral-normalization paper, one power-iteration
            #step per call is enough (u is carried over between calls).
            _u = u
            _v = _l2normalize(K.dot(_u, K.transpose(W)))
            _u = _l2normalize(K.dot(_v, W))
            return _u, _v
        #Spectral Normalization
        W_shape = self.kernel.shape.as_list()
        #Flatten the Tensor to (kh*kw*in_channels, filters)
        W_reshaped = K.reshape(self.kernel, [-1, W_shape[-1]])
        _u, _v = power_iteration(W_reshaped, self.u)
        #Calculate Sigma = v . W . u^T, a (1, 1) estimate of the largest singular value
        sigma=K.dot(_v, W_reshaped)
        sigma=K.dot(sigma, K.transpose(_u))
        #normalize it
        W_bar = W_reshaped / sigma
        #reshape weight tensor back to the 4-D kernel shape
        if training in {0, False}:
            W_bar = K.reshape(W_bar, W_shape)
        else:
            # Tie the assignment of the new `u` to the reshape so it runs
            # whenever the normalised kernel is evaluated.
            with tf.control_dependencies([self.u.assign(_u)]):
                W_bar = K.reshape(W_bar, W_shape)
                
        outputs = K.conv2d(
                inputs,
                W_bar,
                strides=self.strides,
                padding=self.padding,
                data_format=self.data_format,
                dilation_rate=self.dilation_rate)
        if self.use_bias:
            outputs = K.bias_add(
                outputs,
                self.bias,
                data_format=self.data_format)
        if self.activation is not None:
            return self.activation(outputs)
        return outputs

In [0]:
# Plain TF convolution used by `attention` to project feature maps.
def conv(x, channels, kernel=[1,1,1,1], stride=1):
  """Convolve NHWC tensor ``x`` with a freshly created random filter bank.

  Args:
    x: 4-D NHWC tensor; its last (channel) dimension must be statically known.
    channels: number of output channels.
    kernel: spatial kernel size. An int ``k`` means ``k x k``; a 2-element
      sequence is ``(kernel_h, kernel_w)``; a 4-element sequence is read in
      NHWC-strides style ``[1, kernel_h, kernel_w, 1]`` (so the default is 1x1).
    stride: spatial stride used for both height and width.

  Returns:
    Tensor of shape ``(batch, out_h, out_w, channels)`` ('SAME' padding).

  Raises:
    ValueError: on an unsupported ``kernel`` spec or an unknown input depth.

  NOTE(review): the filter is a bare ``tf.Variable`` created on every call.
  Inside a Keras ``Lambda`` it is presumably not registered as a layer
  weight, so the optimizer never updates it and each re-application of the
  Lambda creates a new one. Confirm; a custom ``Layer`` using ``add_weight``
  would be needed to train it.
  """
  if isinstance(kernel, (list, tuple)):
    if len(kernel) == 2:
      kernel_h, kernel_w = kernel
    elif len(kernel) == 4:
      kernel_h, kernel_w = kernel[1], kernel[2]
    else:
      raise ValueError('kernel must be an int, (kh, kw) or [1, kh, kw, 1]; got %r' % (kernel,))
  else:
    kernel_h = kernel_w = kernel
  kernel_h, kernel_w = int(kernel_h), int(kernel_w)

  # The filter's input depth must match x's channels, which generally differ
  # from the requested number of output channels.
  input_channels = x.get_shape().as_list()[-1]
  if input_channels is None:
    raise ValueError('conv needs a statically known channel dimension')

  filters = tf.Variable(tf.truncated_normal(
      [kernel_h, kernel_w, input_channels, channels], stddev=0.5))
  return tf.nn.conv2d(x, filters=filters, strides=[1, stride, stride, 1], padding='SAME')

In [0]:
def attention(x):
  """SAGAN-style self-attention across all spatial positions of ``x``.

  With N = height * width, attention weights are computed between every pair
  of the N positions (not only within an image row), and the attended
  features are added back as ``gamma * o + x``.

  Args:
    x: 4-D NHWC tensor with a statically known channel count >= 8.

  Returns:
    Tensor with the same shape as ``x``.

  NOTE(review): ``gamma`` is created with ``tf.get_variable`` inside a Keras
  ``Lambda``. It is presumably not among the model's trainable weights, so it
  stays at its 0.0 initial value and the block returns ``x`` unchanged. The
  fixed scope with AUTO_REUSE also makes every call (generator and
  discriminator) share the same ``gamma``. Confirm; a custom Layer would be
  needed for a learned gamma. Only ``tf`` and ``conv`` are referenced so the
  Lambda can be reloaded with ``custom_objects={'conv': conv, 'tf': tf}``.
  """
  channels = x.get_shape().as_list()[-1]
  if channels is None or channels < 8:
    raise ValueError('attention needs a known channel count >= 8, got %r' % (channels,))
  x_shape = tf.shape(x)  # dynamic [bs, h, w, c]

  def hw_flatten(t):
    # [bs, h, w, c] -> [bs, N, c], N = h * w
    return tf.reshape(t, [x_shape[0], -1, t.get_shape().as_list()[-1]])

  f = conv(x, channels // 8, kernel=1, stride=1)  # [bs, h, w, c']
  g = conv(x, channels // 8, kernel=1, stride=1)  # [bs, h, w, c']
  h = conv(x, channels, kernel=1, stride=1)  # [bs, h, w, c]

  s = tf.matmul(hw_flatten(g), hw_flatten(f), transpose_b=True)  # [bs, N, N]
  beta = tf.nn.softmax(s)  # attention map over the last axis
  o = tf.matmul(beta, hw_flatten(h))  # [bs, N, c]
  o = tf.reshape(o, x_shape)  # back to [bs, h, w, c]

  with tf.variable_scope("gamma", reuse=tf.AUTO_REUSE):
    gamma = tf.get_variable("gamma", [1], initializer=tf.constant_initializer(0.0))

  o = conv(o, channels, kernel=1, stride=1)
  return gamma * o + x

In [0]:
# Conditional discriminator built from the spectrally-normalised layers, with
# a self-attention Lambda after the first down-sampling stage.
def define_discriminator(in_shape=(32,32,3), n_classes=10):
  """Build and compile the class-conditional discriminator.

  Args:
    in_shape: (height, width, channels) of the images to score.
    n_classes: number of distinct class ids the label input may carry.

  Returns:
    A compiled Keras ``Model`` mapping ``[image, label]`` to one sigmoid
    score per sample.

  NOTE(review): compiled with ``categorical_hinge`` while ``define_gan``
  compiles the combined model with ``binary_crossentropy``; confirm the two
  objectives are meant to differ.
  """
  height, width = in_shape[0], in_shape[1]

  # Label branch: embed the class id, project to height*width values and lay
  # them out as one extra image channel.
  label_input = Input(shape=(1,))
  label_plane = Embedding(n_classes, 50)(label_input)
  label_plane = DenseSN(height * width)(label_plane)
  label_plane = Reshape((height, width, 1))(label_plane)

  # Image branch conditioned on the label channel.
  image_input = Input(shape=in_shape)
  features = Concatenate()([image_input, label_plane])

  # Two stride-2 SN convolutions; self-attention follows only the first one.
  for stage in range(2):
    features = ConvSN2D(128, (3,3), strides=(2,2), padding='same')(features)
    features = LeakyReLU(alpha=0.2)(features)
    if stage == 0:
      features = Lambda(attention)(features)

  features = Flatten()(features)
  features = Dropout(0.4)(features)
  score = DenseSN(1, activation='sigmoid')(features)

  model = Model([image_input, label_input], score)
  model.compile(loss='categorical_hinge', optimizer=Adam(lr=0.0002, beta_1=0.5), metrics=['accuracy'])
  return model


In [0]:
# Conditional generator: 8x8 seed map, two stride-2 transposed convolutions
# (self-attention after the first), tanh RGB output.
def define_generator(latent_dim, n_classes=10):
  """Build the (uncompiled) class-conditional generator.

  Args:
    latent_dim: length of the noise vector input.
    n_classes: number of distinct class ids the label input may carry.

  Returns:
    A Keras ``Model`` mapping ``[noise, label]`` to a 3-channel tanh image
    that is 4x the seed map's height/width.
  """
  base = 8  # side length of the initial feature map

  # Label branch: embed the class id into one base x base channel.
  label_input = Input(shape=(1,))
  label_plane = Embedding(n_classes, 50)(label_input)
  label_plane = Dense(base * base)(label_plane)
  label_plane = Reshape((base, base, 1))(label_plane)

  # Noise branch: project to a base x base x 128 feature map.
  noise_input = Input(shape=(latent_dim,))
  x = Dense(128 * base * base)(noise_input)
  x = LeakyReLU(alpha=0.2)(x)
  x = Reshape((base, base, 128))(x)

  x = Concatenate()([x, label_plane])
  for stage in range(2):
    x = Conv2DTranspose(128, (4,4), strides=(2,2), padding='same')(x)
    x = LeakyReLU(alpha=0.2)(x)
    if stage == 0:
      x = Lambda(attention)(x)

  image = Conv2D(3, (3,3), activation='tanh', padding='same')(x)
  return Model([noise_input, label_input], image)


In [0]:
def define_gan(g_model, d_model):
    """Stack generator and discriminator into the model used to train G.

    Args:
        g_model: generator taking ``[noise, label]``.
        d_model: discriminator taking ``[image, label]``; its ``trainable``
            flag is set to False here (the standard Keras freezing pattern,
            applied before the combined model is compiled).

    Returns:
        A compiled Keras ``Model`` mapping ``[noise, label]`` to D's score.
    """
    d_model.trainable = False
    noise_in, label_in = g_model.input
    # The same label conditions both the generator and the discriminator.
    validity = d_model([g_model.output, label_in])
    combined = Model([noise_in, label_in], validity)
    combined.compile(loss='binary_crossentropy', optimizer=Adam(lr=0.0002, beta_1=0.5))
    return combined

In [0]:
# Attach the TensorBoard callback to the combined GAN model.
# `set_model` expects a built Keras model, not the `define_gan` factory.
# `gan_model` is only created further down the notebook, so the call is
# guarded to keep Restart & Run All working.
# NOTE(review): re-run this cell after `gan_model` exists to attach it.
if 'gan_model' in globals():
    tensorboard.set_model(gan_model)

In [0]:
# Helpers below are shared by the DCGAN and SAGAN implementations.
def load_real_samples():
    """Load the CIFAR-10 training split with pixels scaled to [-1, 1].

    Returns:
        ``[images, labels]``: float32 images and the raw training labels.
    """
    (train_images, train_labels), _ = load_data()
    # [0, 255] -> [-1, 1], the range of the generator's tanh output.
    scaled = (train_images.astype('float32') - 127.5) / 127.5
    return [scaled, train_labels]


In [0]:
def generate_real_samples(dataset, n_samples):
    """Draw ``n_samples`` random (image, label) pairs, with replacement.

    Returns:
        ``([images, labels], targets)`` where ``targets`` is an
        ``(n_samples, 1)`` array of ones ("real").
    """
    all_images, all_labels = dataset
    picks = randint(0, all_images.shape[0], n_samples)
    targets = ones((n_samples, 1))
    return [all_images[picks], all_labels[picks]], targets


In [0]:
def generate_latent_points(latent_dim, n_samples, n_classes=10):
    """Sample standard-normal noise vectors and uniform random class labels.

    Returns:
        ``[noise, labels]``: noise of shape ``(n_samples, latent_dim)`` and
        integer labels in ``[0, n_classes)``.
    """
    noise = randn(latent_dim * n_samples).reshape(n_samples, latent_dim)
    class_ids = randint(0, n_classes, n_samples)
    return [noise, class_ids]


In [0]:
def generate_fake_samples(generator, latent_dim, n_samples):
    """Run ``generator`` on fresh noise/labels and return the fakes.

    Returns:
        ``([images, labels], targets)`` where ``targets`` is an
        ``(n_samples, 1)`` array of zeros ("fake").
    """
    noise, class_ids = generate_latent_points(latent_dim, n_samples)
    fake_images = generator.predict([noise, class_ids])
    return [fake_images, class_ids], zeros((n_samples, 1))


In [0]:
# `generate_latent_points` is already defined, identically, in an earlier
# cell; re-defining it here would only shadow that version.

def save_plot(examples, n):
    """Show the first ``n * n`` images of ``examples`` in an ``n x n`` grid.

    Args:
        examples: array of images shaped (num_images, height, width, channels)
            with values ``pyplot.imshow`` can display (e.g. floats in [0, 1]).
        n: grid side length.

    Raises:
        ValueError: if fewer than ``n * n`` images are supplied.

    NOTE(review): a later cell defines ``save_plot(examples, epoch, n=8)``,
    which shadows this version before it is ever called.
    """
    if len(examples) < n * n:
        raise ValueError('save_plot needs at least %d examples for a %dx%d grid, got %d'
                         % (n * n, n, n, len(examples)))
    for i in range(n * n):
        pyplot.subplot(n, n, 1 + i)
        pyplot.axis('off')
        image = examples[i]
        if image.shape[-1] == 1:
            # Single-channel data: inverted grey rendering.
            pyplot.imshow(image[:, :, 0], cmap='gray_r')
        else:
            # Colour data (e.g. CIFAR-10): show all channels, not just the first.
            pyplot.imshow(image)
    pyplot.show()

In [0]:
def scale_images(images, new_shape):
    """Resize each image in ``images`` to ``new_shape``.

    Uses ``skimage.transform.resize`` with ``order=0`` (nearest-neighbour).

    Returns:
        A numpy array stacking the resized images.
    """
    return asarray([resize(img, new_shape, order=0) for img in images])

def calculate_fid(model, images1, images2, eps=1e-6):
    """Frechet Inception Distance between the activations of two image sets.

    FID = ||mu1 - mu2||^2 + Tr(S1 + S2 - 2 * sqrtm(S1 S2)), where mu/S are the
    mean and covariance of ``model.predict`` activations (rows = samples).

    Args:
        model: any object with ``predict(images) -> (num_samples, features)``.
        images1, images2: image batches accepted by ``model.predict``.
        eps: diagonal offset added to both covariances if the matrix square
            root of their product is not finite (e.g. too few samples for the
            feature dimension makes the covariances singular).

    Returns:
        The FID as a scalar.
    """
    act1 = model.predict(images1)
    act2 = model.predict(images2)
    # atleast_2d: numpy's cov returns a 0-d array for a single feature.
    mu1, sigma1 = act1.mean(axis=0), numpy.atleast_2d(cov(act1, rowvar=False))
    mu2, sigma2 = act2.mean(axis=0), numpy.atleast_2d(cov(act2, rowvar=False))
    ssdiff = numpy.sum((mu1 - mu2) ** 2.0)
    covmean = sqrtm(sigma1.dot(sigma2))
    if not numpy.isfinite(covmean).all():
        # Regularise near-singular covariances and retry.
        offset = numpy.eye(sigma1.shape[0]) * eps
        covmean = sqrtm((sigma1 + offset).dot(sigma2 + offset))
    if iscomplexobj(covmean):
        # Numerical error can leave tiny imaginary parts; keep the real part.
        covmean = covmean.real
    return ssdiff + trace(sigma1 + sigma2 - 2.0 * covmean)

In [0]:
def train(g_model, d_model, gan_model, dataset, latent_dim, n_epochs=50, n_batch=128):
    """Alternate discriminator and generator updates, then save the generator.

    Each step trains D on half a batch of real samples and half a batch of
    generated samples, then trains G (through ``gan_model``) on a full batch
    labelled as real. The generator is saved to 'cgan_generator.h5' at the end.
    """
    batches_per_epoch = dataset[0].shape[0] // n_batch
    half = n_batch // 2
    for epoch in range(n_epochs):
        for step in range(batches_per_epoch):
            # Discriminator on real samples.
            [real_images, real_labels], real_targets = generate_real_samples(dataset, half)
            d_loss_real, _ = d_model.train_on_batch([real_images, real_labels], real_targets)
            # Discriminator on generated samples.
            [fake_images, fake_labels], fake_targets = generate_fake_samples(g_model, latent_dim, half)
            d_loss_fake, _ = d_model.train_on_batch([fake_images, fake_labels], fake_targets)
            # Generator through the combined model, targets all "real".
            noise, noise_labels = generate_latent_points(latent_dim, n_batch)
            g_loss = gan_model.train_on_batch([noise, noise_labels], ones((n_batch, 1)))

            print('>%d, %d/%d, d1=%.3f, d2=%.3f g=%.3f' %
                  (epoch + 1, step + 1, batches_per_epoch, d_loss_real, d_loss_fake, g_loss))
    g_model.save('cgan_generator.h5')


In [0]:
# Build the models, load CIFAR-10 and train. Order matters: d_model is
# compiled inside define_discriminator before define_gan sets
# d_model.trainable = False for the combined model.
latent_dim = 100  # length of the generator's noise vector
d_model = define_discriminator()
g_model = define_generator(latent_dim=latent_dim)
gan_model = define_gan(g_model=g_model, d_model=d_model)
dataset = load_real_samples()
train(g_model, d_model, gan_model, dataset, latent_dim=latent_dim)

In [0]:
# example of loading the generator model and generating images
from numpy import asarray
from numpy.random import randn
from numpy.random import randint
from keras.models import load_model
from matplotlib import pyplot
 
def generate_latent_points(latent_dim, n_samples, n_classes=10):
    """Return ``[noise, labels]`` for ``n_samples`` generator inputs.

    ``noise`` is standard normal with shape ``(n_samples, latent_dim)``;
    ``labels`` are uniform integers in ``[0, n_classes)``.
    NOTE(review): identical to the definition in an earlier cell.
    """
    flat_noise = randn(latent_dim * n_samples)
    labels = randint(0, n_classes, n_samples)
    return [flat_noise.reshape(n_samples, latent_dim), labels]

In [0]:
def _save_image_grid(images, epoch, n):
  """Draw the first ``n * n`` of ``images`` as a grid and save it to Drive."""
  for i in range(n * n):
    pyplot.subplot(n, n, 1 + i)
    pyplot.axis('off')
    pyplot.imshow(images[i])
  pyplot.savefig('drive/My Drive/sagan_grids/generated_plote%05d.png' % (epoch+1))
  pyplot.close()


def _fid_vs_cifar10_test(generated, n_real=1000):
  """FID between ``generated`` and ``n_real`` shuffled CIFAR-10 test images.

  Real images are scaled from [0, 255] to [-1, 1]; ``generated`` is used as
  given, so it should already be in [-1, 1] for the comparison to be fair.
  """
  model = InceptionV3(include_top=False, pooling='avg', input_shape=(299,299,3))
  (_,_),(images1, _) = load_data()
  images2 = generated
  shuffle(images1)
  images1 = images1[:n_real]
  print('Loaded', images1.shape, images2.shape)
  images1 = images1.astype('float32')
  images1 = (images1 - 127.5) / 127.5
  images2 = images2.astype('float32')
  # Upsample both sets to InceptionV3's 299x299 input size.
  images1 = scale_images(images1, (299,299,3))
  images2 = scale_images(images2, (299,299,3))
  print('Scaled', images1.shape, images2.shape)
  return calculate_fid(model, images1, images2)


def save_plot(examples, epoch, n=8):
  """Save an ``n x n`` grid of ``examples`` and print their FID vs CIFAR-10.

  Args:
    examples: generator outputs in [-1, 1] (the range this function maps to
      [0, 1] for display); at least ``n * n`` images.
    epoch: used to number the saved file (``epoch + 1``).
    n: grid side length.

  Raises:
    ValueError: if fewer than ``n * n`` examples are supplied.
  """
  if len(examples) < n * n:
    raise ValueError('save_plot needs at least %d examples for a %dx%d grid, got %d'
                     % (n * n, n, n, len(examples)))
  # Only the displayed copy is mapped to [0, 1]; the FID uses the raw
  # [-1, 1] values so both real and generated sets share one scale.
  _save_image_grid((examples + 1) / 2.0, epoch, n)
  fid = _fid_vs_cifar10_test(examples)
  print('FID: %.3f' % fid)
  

In [0]:
# Load the trained generator and draw a 10x10 grid of samples.
# The saved Lambda layers run `attention`, which references `conv` and `tf`;
# presumably that is why both are supplied as custom objects.
model = load_model('cgan_generator.h5',custom_objects={'conv':conv, 'tf':tf})
latent_points, _ = generate_latent_points(100, 100)
# Labels 0..9 repeated 10 times: in a row-major 10x10 grid each row holds
# classes 0..9, so each column is one class.
labels = asarray([x for _ in range(10) for x in range(10)])
X = model.predict([latent_points, labels])
# Pass the raw generator output: save_plot maps [-1, 1] -> [0, 1] for display
# itself, and its FID step also works on the raw output. The 100 samples
# fill an n=10 grid; epoch only numbers the saved file.
save_plot(X, epoch=0, n=10)