In [None]:
# example of a wgan for generating handwritten digits
import numpy as np
from tensorflow.keras.datasets.mnist import load_data
from tensorflow.keras import backend
from tensorflow.keras.optimizers import RMSprop
from tensorflow.keras import models,layers
from tensorflow.keras.initializers import RandomNormal
from tensorflow.keras.constraints import Constraint
import matplotlib.pyplot as plt

In [None]:
# weight constraint that clamps every weight into a symmetric hypercube
class ClipConstraint(Constraint):
	"""Constraint that clamps each weight element into [-clip_value, clip_value].

	In this file an instance is passed as ``kernel_constraint`` to the critic's
	convolution layers.

	Args:
		clip_value: bound of the symmetric interval; stored unchanged on the
			instance and used for both the lower (negated) and upper limit.
	"""

	def __init__(self, clip_value):
		# remember the bound so __call__ and get_config can read it back
		self.clip_value = clip_value

	def __call__(self, weights):
		# clamp element-wise; the lower limit is the negated upper limit
		bound = self.clip_value
		return backend.clip(weights, -bound, bound)

	def get_config(self):
		# the only constructor argument, keyed by its parameter name
		config = {'clip_value': self.clip_value}
		return config

In [None]:
# wasserstein loss: mean of label * score
def wasserstein_loss(y_true, y_pred):
	"""Return the mean of the element-wise product of ``y_true`` and ``y_pred``.

	The sample helpers in this file label real images -1 and generated images
	+1, so the sign of ``y_true`` decides in which direction the score in
	``y_pred`` is pushed when this value is minimised.
	"""
	signed_scores = y_true * y_pred
	return backend.mean(signed_scores)

In [None]:
# define the standalone critic model
def define_critic(in_shape=(28,28,1), clip_value=0.01, learning_rate=0.00005):
	"""Build and compile the critic network.

	Two stride-2 convolution stages (each followed by batch normalisation and
	LeakyReLU) reduce the input, then the features are flattened and mapped to
	a single unbounded score by a linear Dense layer.

	Args:
		in_shape: shape of one input image, default (28, 28, 1).
		clip_value: bound handed to ClipConstraint for the convolution kernels.
		learning_rate: step size for the RMSprop optimizer.

	Returns:
		The Sequential model, compiled with wasserstein_loss and RMSprop.
	"""
	# weight initialization
	init = RandomNormal(stddev=0.02)
	# weight constraint applied to both convolution kernels
	const = ClipConstraint(clip_value)
	# define model
	model = models.Sequential()
	# downsample to 14x14
	model.add(layers.Conv2D(64, (4,4), strides=(2,2), padding='same', kernel_initializer=init, kernel_constraint=const, input_shape=in_shape))
	model.add(layers.BatchNormalization())
	model.add(layers.LeakyReLU(alpha=0.2))
	# downsample to 7x7
	model.add(layers.Conv2D(64, (4,4), strides=(2,2), padding='same', kernel_initializer=init, kernel_constraint=const))
	model.add(layers.BatchNormalization())
	model.add(layers.LeakyReLU(alpha=0.2))
	# scoring, linear activation
	model.add(layers.Flatten())
	# NOTE(review): this layer has no kernel_constraint, so its weights are not
	# clipped like the convolution kernels above - confirm this is intentional.
	model.add(layers.Dense(1))
	# compile model; `learning_rate` is the documented argument name, `lr` is
	# only a deprecated alias
	opt = RMSprop(learning_rate=learning_rate)
	model.compile(loss=wasserstein_loss, optimizer=opt)
	return model

In [None]:
# define the standalone generator model
def define_generator(latent_dim):
	"""Build the (uncompiled) generator network.

	A Dense layer expands the latent vector to a 7x7x128 feature map, two
	stride-2 transposed convolutions upsample it to 28x28, and a final 7x7
	convolution with tanh produces a single-channel image.

	Args:
		latent_dim: length of the latent input vector.

	Returns:
		The Sequential generator model; it is not compiled here.
	"""
	# all kernels share the same small-variance normal initializer
	init = RandomNormal(stddev=0.02)
	# spatial size and depth of the initial feature map
	base_side, base_depth = 7, 128
	net = models.Sequential()
	# project the latent vector and reshape it into the 7x7 foundation
	net.add(layers.Dense(base_depth * base_side * base_side, kernel_initializer=init, input_dim=latent_dim))
	net.add(layers.LeakyReLU(alpha=0.2))
	net.add(layers.Reshape((base_side, base_side, base_depth)))
	# two identical upsampling stages: 7x7 -> 14x14 -> 28x28
	for _ in range(2):
		net.add(layers.Conv2DTranspose(128, (4,4), strides=(2,2), padding='same', kernel_initializer=init))
		net.add(layers.BatchNormalization())
		net.add(layers.LeakyReLU(alpha=0.2))
	# collapse to one channel; tanh keeps pixel values in [-1, 1]
	net.add(layers.Conv2D(1, (7,7), activation='tanh', padding='same', kernel_initializer=init))
	return net

In [None]:
# define the combined generator and critic model, for updating the generator
def define_gan(generator, critic, learning_rate=0.00005):
	"""Stack generator and critic into one model used to update the generator.

	Side effect: sets ``critic.trainable = False`` on the object passed in, so
	that this stacked model only exposes the generator's weights as trainable.

	NOTE(review): the critic is compiled on its own in define_critic before
	this flag is flipped; whether its own train_on_batch still updates its
	weights afterwards depends on the Keras version - confirm on the version
	in use.

	Args:
		generator: model mapping latent vectors to images.
		critic: model mapping images to a score.
		learning_rate: step size for the RMSprop optimizer.

	Returns:
		The Sequential model generator -> critic, compiled with
		wasserstein_loss and RMSprop.
	"""
	# make weights in the critic not trainable
	critic.trainable = False
	# connect them: generator output feeds the critic
	model = models.Sequential()
	model.add(generator)
	model.add(critic)
	# compile model; `learning_rate` is the documented argument name, `lr` is
	# only a deprecated alias
	opt = RMSprop(learning_rate=learning_rate)
	model.compile(loss=wasserstein_loss, optimizer=opt)
	return model

In [None]:
# load images
def load_real_samples(digit=7):
	"""Load the MNIST training images and prepare them for the critic.

	Args:
		digit: class label to keep (default 7, the value that used to be
			hard-coded). Pass None to keep the images of every class.

	Returns:
		float32 array with a trailing channel axis added and pixel values
		rescaled from [0, 255] to [-1, 1].
	"""
	# load dataset; only the training split is used
	(trainX, trainy), (_, _) = load_data()
	# select all of the examples for the requested class
	if digit is not None:
		trainX = trainX[trainy == digit]
	# expand to 3d, e.g. add channels, and convert from ints to floats
	X = np.expand_dims(trainX, axis=-1).astype('float32')
	# scale from [0,255] to [-1,1]
	X = (X - 127.5) / 127.5
	return X


In [None]:
# draw a random batch of real images with their labels
def generate_real_samples(dataset, n_samples):
	"""Pick ``n_samples`` random rows of ``dataset`` (with replacement).

	Args:
		dataset: array indexed along its first axis.
		n_samples: number of rows to draw.

	Returns:
		Tuple ``(X, y)``: the selected rows and an ``(n_samples, 1)`` array of
		-1.0 labels marking them as 'real'.
	"""
	# random row indices in [0, number of rows)
	picks = np.random.randint(0, dataset.shape[0], n_samples)
	# 'real' is encoded as -1
	labels = np.full((n_samples, 1), -1.0)
	return dataset[picks], labels

# sample a batch of latent vectors for the generator
def generate_latent_points(latent_dim, n_samples):
	"""Draw standard-normal latent vectors.

	Args:
		latent_dim: length of each latent vector.
		n_samples: number of vectors to draw.

	Returns:
		Array of shape ``(n_samples, latent_dim)`` filled with samples from
		``np.random.randn``.
	"""
	# request the batch shape directly instead of reshaping a flat draw
	return np.random.randn(n_samples, latent_dim)

# use the generator to generate n fake examples, with class labels
def generate_fake_samples(generator, latent_dim, n_samples):
	"""Generate ``n_samples`` images with the generator and label them 'fake'.

	Args:
		generator: model exposing ``predict``; fed latent vectors.
		latent_dim: length of each latent vector.
		n_samples: number of images to generate.

	Returns:
		Tuple ``(X, y)``: the generator output and an ``(n_samples, 1)`` array
		of 1.0 labels marking the images as 'fake'.
	"""
	# generate points in latent space
	x_input = generate_latent_points(latent_dim, n_samples)
	# predict outputs; verbose=0 because this runs many times per training
	# step and a per-call progress bar would bury the loss log
	X = generator.predict(x_input, verbose=0)
	# create class labels with 1.0 for 'fake'
	y = np.ones((n_samples, 1))
	return X, y



In [None]:
# generate samples and save as a plot and save the model
def summarize_performance(step, g_model, latent_dim, n_samples=100):
	"""Save a grid of generated images and a snapshot of the generator.

	The grid is sized from ``n_samples`` (near-square, 10x10 for the default
	of 100), so every generated image is drawn and values below 100 no longer
	index past the end of the batch.

	Args:
		step: zero-based training step; ``step + 1`` is used in the file names.
		g_model: generator model; used for sampling and then saved.
		latent_dim: length of each latent vector.
		n_samples: number of images to generate and plot.

	Side effects:
		Writes 'generated_plot_NNNN.png' and 'model_NNNN.h5' to the working
		directory and prints both file names.
	"""
	# prepare fake examples
	X, _ = generate_fake_samples(g_model, latent_dim, n_samples)
	# scale from [-1,1] to [0,1]
	X = (X + 1) / 2.0
	# near-square grid large enough to hold every sample
	n_cols = max(1, int(np.ceil(np.sqrt(n_samples))))
	n_rows = max(1, int(np.ceil(n_samples / n_cols)))
	# draw on a dedicated figure so any figure already open is left alone
	fig = plt.figure()
	for i in range(n_samples):
		# define subplot
		ax = fig.add_subplot(n_rows, n_cols, 1 + i)
		# turn off axis
		ax.axis('off')
		# plot raw pixel data
		ax.imshow(X[i, :, :, 0], cmap='gray_r')
	# save plot to file
	filename1 = 'generated_plot_%04d.png' % (step+1)
	fig.savefig(filename1)
	plt.close(fig)
	# save the generator model
	filename2 = 'model_%04d.h5' % (step+1)
	g_model.save(filename2)
	print('>Saved: %s and %s' % (filename1, filename2))

# draw the three loss curves on one chart and write it to disk
def plot_history(d1_hist, d2_hist, g_hist):
	"""Plot critic and generator loss histories and save them as a PNG.

	Args:
		d1_hist: critic loss on real samples, plotted as 'crit_real'.
		d2_hist: critic loss on fake samples, plotted as 'crit_fake'.
		g_hist: generator loss, plotted as 'gen'.

	Side effects:
		Writes 'plot_line_plot_loss.png' to the working directory and closes
		the current pyplot figure.
	"""
	# pair each series with its legend label, in drawing order
	labelled_series = (
		(d1_hist, 'crit_real'),
		(d2_hist, 'crit_fake'),
		(g_hist, 'gen'),
	)
	for series, tag in labelled_series:
		plt.plot(series, label=tag)
	plt.legend()
	plt.savefig('plot_line_plot_loss.png')
	plt.close()



In [None]:
# train the generator and critic
def train(g_model, c_model, gan_model, dataset, latent_dim, n_epochs=20, n_batch=64, n_critic=5):
	"""Run the alternating critic / generator training loop.

	Each step updates the critic ``n_critic`` times (one half-batch of real
	images, then one half-batch of generated images per update) and then
	updates the generator once through ``gan_model`` on a full batch of latent
	points labelled -1.

	Args:
		g_model: generator, used to produce fake samples for the critic.
		c_model: compiled critic.
		gan_model: compiled generator -> critic stack.
		dataset: array of real images, indexed along its first axis.
		latent_dim: length of each latent vector.
		n_epochs: number of passes, in units of ``len(dataset) / n_batch`` steps.
		n_batch: generator batch size; the critic sees half of it per update.
		n_critic: critic updates per generator update.

	Side effects:
		Prints one loss line per step, calls summarize_performance at the end
		of every epoch and plot_history once after the last step.
	"""
	# steps that make up one pass over the data, and the total step count
	steps_per_epoch = int(dataset.shape[0] / n_batch)
	total_steps = steps_per_epoch * n_epochs
	# the critic is trained on half-size batches
	critic_batch = int(n_batch / 2)
	# per-step loss histories
	real_losses, fake_losses, gen_losses = [], [], []
	# iterate over training steps, counted from 1 to match the log output
	for step in range(1, total_steps + 1):
		# losses of the critic updates performed within this step
		round_real, round_fake = [], []
		for _ in range(n_critic):
			# critic update on randomly selected real images
			X_real, y_real = generate_real_samples(dataset, critic_batch)
			round_real.append(c_model.train_on_batch(X_real, y_real))
			# critic update on freshly generated images
			X_fake, y_fake = generate_fake_samples(g_model, latent_dim, critic_batch)
			round_fake.append(c_model.train_on_batch(X_fake, y_fake))
		# keep the average critic loss of this step
		real_losses.append(np.mean(round_real))
		fake_losses.append(np.mean(round_fake))
		# generator update: latent points labelled -1, the label used for real images
		latent_batch = generate_latent_points(latent_dim, n_batch)
		inverted_labels = -np.ones((n_batch, 1))
		g_loss = gan_model.train_on_batch(latent_batch, inverted_labels)
		gen_losses.append(g_loss)
		# summarize loss on this step
		print('>%d, c1=%.3f, c2=%.3f g=%.3f' % (step, real_losses[-1], fake_losses[-1], g_loss))
		# snapshot images and model at the end of every epoch
		if step % steps_per_epoch == 0:
			summarize_performance(step - 1, g_model, latent_dim)
	# line plots of loss
	plot_history(real_losses, fake_losses, gen_losses)



In [None]:
# dimensionality of the generator's latent input
latent_dim = 50
# fetch and preprocess the training images, then report their shape
dataset = load_real_samples()
print(dataset.shape)
# build the critic, the generator and the stacked model that trains the generator
critic = define_critic()
generator = define_generator(latent_dim)
gan_model = define_gan(generator, critic)
# run training with the default epochs, batch size and critic updates per step
train(g_model=generator, c_model=critic, gan_model=gan_model, dataset=dataset, latent_dim=latent_dim)

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
(6265, 28, 28, 1)
>1, c1=-2.038, c2=0.020 g=-0.000
>2, c1=-6.433, c2=0.073 g=-0.008
>3, c1=-9.513, c2=0.089 g=-0.014
>4, c1=-11.844, c2=0.118 g=-0.019
>5, c1=-13.771, c2=0.146 g=-0.022
>6, c1=-15.868, c2=0.177 g=-0.029
>7, c1=-17.666, c2=0.248 g=-0.041
>8, c1=-18.468, c2=0.351 g=-0.052
>9, c1=-19.809, c2=0.422 g=-0.067
>10, c1=-20.846, c2=0.517 g=-0.085
>11, c1=-21.851, c2=0.641 g=-0.111
>12, c1=-22.854, c2=0.734 g=-0.131
>13, c1=-23.784, c2=0.818 g=-0.157
>14, c1=-24.568, c2=0.904 g=-0.187
>15, c1=-25.386, c2=0.936 g=-0.215
>16, c1=-25.917, c2=1.030 g=-0.253
>17, c1=-26.578, c2=1.067 g=-0.292
>18, c1=-27.498, c2=1.076 g=-0.334
>19, c1=-28.273, c2=1.122 g=-0.383
>20, c1=-28.273, c2=1.071 g=-0.441
>21, c1=-28.900, c2=1.030 g=-0.505
>22, c1=-29.805, c2=0.949 g=-0.564
>23, c1=-30.152, c2=0.830 g=-0.640
>24, c1=-30.238, c2=0.659 g=-0.701
>25, c1=-30.752, c2=0.377 g=-0.807
>26, c1=-31.926, c2=0.045 g

In [None]:
# sample the trained generator and show the results in a 10x5 grid
n_images = 50
# draw latent points exactly as during training (standard normal, latent_dim
# wide); uniform np.random.rand noise is not the distribution the generator
# was trained on
noise_input = generate_latent_points(latent_dim, n_images)
prediction = generator.predict(noise_input)
print(prediction.shape)
plt.figure(figsize=(20,15))
for i in range(n_images):
	plt.subplot(10,5,i+1)
	# imshow rescales the [-1, 1] tanh output to the colormap range
	plt.imshow(prediction[i,:,:,0],cmap='gray')
	plt.axis('off')
plt.show()