In [None]:
from numpy import expand_dims
from numpy import mean
from numpy import ones
from numpy.random import randn
from numpy.random import randint
from keras.datasets.mnist import load_data
from keras import backend
from keras.optimizers import RMSprop
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Reshape
from keras.layers import Flatten
from keras.layers import Conv2D
from keras.layers import Conv2DTranspose
from keras.layers import LeakyReLU
from keras.layers import BatchNormalization
from keras.initializers import RandomNormal
from keras.constraints import Constraint
from matplotlib import pyplot

class ClipConstraint(Constraint):
	"""Keras weight constraint that clamps every weight into [-clip_value, clip_value].

	Used in this file as ``kernel_constraint=`` on the critic layers, which is
	how the WGAN weight-clipping rule is enforced after each optimizer update.
	"""

	def __init__(self, clip_value):
		# Single bound used symmetrically as both the lower and upper limit.
		self.clip_value = clip_value

	def __call__(self, weights):
		bound = self.clip_value
		return backend.clip(weights, -bound, bound)

	def get_config(self):
		# Serialisable config so the constraint can be rebuilt when a model is reloaded.
		config = dict(clip_value=self.clip_value)
		return config

def wasserstein_loss(y_true, y_pred):
	"""Wasserstein loss: mean of (label * critic score).

	In this file real samples are labelled -1 and fake samples +1, so
	minimising this pushes scores of real samples up and fake samples down.
	"""
	weighted_scores = y_true * y_pred
	return backend.mean(weighted_scores)

def define_critic(in_shape=(28,28,1)):
	"""Build and compile the WGAN critic.

	Two stride-2 4x4 convolutions (64 filters, BatchNorm, LeakyReLU 0.2)
	followed by a single linear output unit (no activation: the critic emits
	an unbounded score, not a probability).

	Args:
		in_shape: shape of one input image.

	Returns:
		A Sequential model compiled with ``wasserstein_loss`` and RMSprop
		(learning rate 5e-5).
	"""
	init = RandomNormal(stddev=0.02)
	# Weight clipping: every kernel, including the output layer's, is clamped to
	# [-0.01, 0.01] after each update. Leaving the final Dense unclipped lets the
	# critic's output scale grow without bound, defeating the Lipschitz bound.
	# NOTE(review): BatchNormalization gamma/beta are not clipped here.
	const = ClipConstraint(0.01)
	model = Sequential()
	model.add(Conv2D(64, (4,4), strides=(2,2), padding='same', kernel_initializer=init, kernel_constraint=const, input_shape=in_shape))
	model.add(BatchNormalization())
	model.add(LeakyReLU(alpha=0.2))
	model.add(Conv2D(64, (4,4), strides=(2,2), padding='same', kernel_initializer=init, kernel_constraint=const))
	model.add(BatchNormalization())
	model.add(LeakyReLU(alpha=0.2))
	model.add(Flatten())
	model.add(Dense(1, kernel_initializer=init, kernel_constraint=const))
	# `learning_rate` replaces the deprecated `lr` alias (removed in Keras 3).
	opt = RMSprop(learning_rate=0.00005)
	model.compile(loss=wasserstein_loss, optimizer=opt)
	return model

def define_generator(latent_dim):
	"""Build the (uncompiled) generator mapping a latent vector to a 1-channel image.

	Args:
		latent_dim: length of the latent input vector.

	Returns:
		A Sequential model: Dense projection reshaped to a 7x7x128 map, two
		stride-2 transposed convolutions (7 -> 14 -> 28), then a 7x7 tanh
		convolution producing one channel with values in [-1, 1].
	"""
	weight_init = RandomNormal(stddev=0.02)
	model = Sequential()
	# Project the latent vector to a 7x7 map with 128 channels.
	model.add(Dense(7 * 7 * 128, kernel_initializer=weight_init, input_dim=latent_dim))
	model.add(LeakyReLU(alpha=0.2))
	model.add(Reshape((7, 7, 128)))
	# Two identical upsampling stages, each doubling the spatial size.
	for _ in range(2):
		model.add(Conv2DTranspose(128, (4,4), strides=(2,2), padding='same', kernel_initializer=weight_init))
		model.add(BatchNormalization())
		model.add(LeakyReLU(alpha=0.2))
	# tanh output matches the [-1, 1] scaling applied to the real images.
	model.add(Conv2D(1, (7,7), activation='tanh', padding='same', kernel_initializer=weight_init))
	return model

def define_gan(generator, critic):
	"""Stack generator and critic into the composite model used to train the generator.

	All critic layers except BatchNormalization are marked non-trainable, so
	updates through this model change the generator's weights.

	Args:
		generator: the generator model.
		critic: the (already compiled) critic model.

	Returns:
		A Sequential(generator, critic) compiled with ``wasserstein_loss`` and
		RMSprop (learning rate 5e-5).
	"""
	for layer in critic.layers:
		# BatchNormalization layers are intentionally left trainable.
		if not isinstance(layer, BatchNormalization):
			layer.trainable = False
	model = Sequential()
	model.add(generator)
	model.add(critic)
	# `learning_rate` replaces the deprecated `lr` alias (removed in Keras 3).
	opt = RMSprop(learning_rate=0.00005)
	model.compile(loss=wasserstein_loss, optimizer=opt)
	return model

def load_real_samples(digit=7):
	"""Load the MNIST training images of one digit, scaled to [-1, 1].

	Args:
		digit: class label to keep (default 7, the previous hard-coded value).

	Returns:
		float32 array of shape (n, height, width, 1) with pixels mapped from
		[0, 255] to [-1, 1] to match the generator's tanh output.
	"""
	(trainX, trainy), (_, _) = load_data()
	X = trainX[trainy == digit]
	# Add a trailing channel axis for Conv2D.
	X = expand_dims(X, axis=-1).astype('float32')
	return (X - 127.5) / 127.5

def generate_real_samples(dataset, n_samples):
	"""Draw n_samples random rows (with replacement) from dataset.

	Returns:
		(samples, labels) where labels is an (n_samples, 1) array of -1,
		the "real" label used with ``wasserstein_loss``.
	"""
	picks = randint(0, dataset.shape[0], n_samples)
	labels = -ones((n_samples, 1))
	return dataset[picks], labels

def generate_latent_points(latent_dim, n_samples):
	"""Return an (n_samples, latent_dim) array of standard-normal latent vectors."""
	# randn with two dimensions fills the array in the same C order as
	# drawing a flat vector and reshaping it.
	return randn(n_samples, latent_dim)

def generate_fake_samples(generator, latent_dim, n_samples):
	"""Generate n_samples images with the generator.

	Returns:
		(images, labels) where labels is an (n_samples, 1) array of +1,
		the "fake" label used with ``wasserstein_loss``.
	"""
	latent = generate_latent_points(latent_dim, n_samples)
	images = generator.predict(latent)
	return images, ones((n_samples, 1))

def _grid_shape(n_samples):
	"""Return (rows, cols) of a near-square subplot grid holding n_samples plots (min 1x1)."""
	import math
	n_cols = max(1, math.ceil(math.sqrt(n_samples)))
	n_rows = max(1, math.ceil(n_samples / n_cols))
	return n_rows, n_cols


def summarize_performance(step, g_model, latent_dim, n_samples=100):
	"""Plot a grid of generated images and save the generator for this step.

	Writes ``generated_plot_%04d.png`` and ``model_%04d.h5`` (numbered step+1)
	to the working directory and prints both filenames.

	Args:
		step: zero-based training step; files are numbered step+1.
		g_model: generator model to sample from and save.
		latent_dim: size of the latent input vector.
		n_samples: number of images to plot. The grid is sized from this
			value (10x10 for the default 100), so values other than 100 no
			longer index past the generated batch.
	"""
	X, _ = generate_fake_samples(g_model, latent_dim, n_samples)
	# Map tanh output from [-1, 1] to [0, 1] for display.
	X = (X + 1) / 2.0
	n_rows, n_cols = _grid_shape(n_samples)
	for i in range(n_samples):
		pyplot.subplot(n_rows, n_cols, 1 + i)
		pyplot.axis('off')
		pyplot.imshow(X[i, :, :, 0], cmap='gray_r')
	filename1 = 'generated_plot_%04d.png' % (step+1)
	pyplot.savefig(filename1)
	pyplot.close()
	filename2 = 'model_%04d.h5' % (step+1)
	g_model.save(filename2)
	print('>Saved: %s and %s' % (filename1, filename2))

def plot_history(d1_hist, d2_hist, g_hist):
	"""Save a line plot of the per-step loss histories to plot_line_plot_loss.png.

	Args:
		d1_hist: critic loss on real batches per step.
		d2_hist: critic loss on fake batches per step.
		g_hist: generator (composite model) loss per step.
	"""
	curves = (
		(d1_hist, 'crit_real'),
		(d2_hist, 'crit_fake'),
		(g_hist, 'gen'),
	)
	for values, name in curves:
		pyplot.plot(values, label=name)
	pyplot.legend()
	pyplot.savefig('plot_line_plot_loss.png')
	pyplot.close()

def train(g_model, c_model, gan_model, dataset, latent_dim, n_epochs=1, n_batch=64, n_critic=5):
	"""Run the WGAN training loop.

	Each step performs n_critic critic updates (one real and one fake
	half-batch each), then one generator update through gan_model on a full
	batch of latent points labelled -1. Losses are printed every step. At the
	end of every epoch, summarize_performance is called. A loss plot is saved
	when training finishes.

	Args:
		g_model: generator model.
		c_model: compiled critic model.
		gan_model: compiled generator+critic composite.
		dataset: array of real training images.
		latent_dim: size of the latent input vector.
		n_epochs: number of passes over the dataset.
		n_batch: batch size for the generator update (critic uses halves).
		n_critic: critic updates per generator update.
	"""
	steps_per_epoch = int(dataset.shape[0] / n_batch)
	total_steps = steps_per_epoch * n_epochs
	half = int(n_batch / 2)
	real_hist, fake_hist, gen_hist = [], [], []
	for step in range(total_steps):
		real_losses, fake_losses = [], []
		for _ in range(n_critic):
			x_real, y_real = generate_real_samples(dataset, half)
			real_losses.append(c_model.train_on_batch(x_real, y_real))
			x_fake, y_fake = generate_fake_samples(g_model, latent_dim, half)
			fake_losses.append(c_model.train_on_batch(x_fake, y_fake))
		real_hist.append(mean(real_losses))
		fake_hist.append(mean(fake_losses))
		# Generator step: fake inputs carry the "real" label (-1).
		z = generate_latent_points(latent_dim, n_batch)
		g_loss = gan_model.train_on_batch(z, -ones((n_batch, 1)))
		gen_hist.append(g_loss)
		print('>%d, c1=%.3f, c2=%.3f g=%.3f' % (step+1, real_hist[-1], fake_hist[-1], g_loss))
		if (step + 1) % steps_per_epoch == 0:
			summarize_performance(step, g_model, latent_dim)
	plot_history(real_hist, fake_hist, gen_hist)

# Script entry: build the critic, generator and composite model, load the
# real training images and run training with the default settings.
latent_dim = 50  # length of the generator's latent input vector
critic = define_critic()
generator = define_generator(latent_dim)
# Composite model used for generator updates (critic layers frozen inside it).
gan_model = define_gan(generator, critic)
dataset = load_real_samples()
print(dataset.shape)
train(generator, critic, gan_model, dataset, latent_dim)

(6265, 28, 28, 1)
>1, c1=-1.936, c2=-0.065 g=0.957
>2, c1=-6.094, c2=0.014 g=-0.366
>3, c1=-9.343, c2=0.095 g=-1.433
>4, c1=-11.181, c2=0.163 g=-2.264
>5, c1=-13.765, c2=0.236 g=-3.486
>6, c1=-15.765, c2=0.303 g=-4.151
>7, c1=-16.637, c2=0.362 g=-5.168
>8, c1=-18.555, c2=0.423 g=-5.909
>9, c1=-19.361, c2=0.460 g=-7.048
>10, c1=-21.033, c2=0.516 g=-7.945
>11, c1=-21.381, c2=0.602 g=-9.004
>12, c1=-22.507, c2=0.686 g=-10.237
>13, c1=-23.633, c2=0.760 g=-11.010
>14, c1=-24.448, c2=0.864 g=-11.761


KeyboardInterrupt: ignored