<a href="https://colab.research.google.com/github/Archangel212/GAN_jupyter_notebooks/blob/master/WGAN_BATIK_Dataset_64by64_UNet_architecture.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [34]:
# WGAN (weight-clipping critic) for generating 64x64 Batik images
import sys
sys.path.append('/content/drive/My Drive')

from numpy import expand_dims
from numpy import mean
from numpy import ones
from numpy.random import randn
from numpy.random import randint
from keras.datasets.mnist import load_data
from keras import backend
from keras.optimizers import RMSprop
from keras.models import Sequential,save_model,load_model
from keras.layers import Dense
from keras.layers import Dropout
from keras.layers import Reshape
from keras.layers import Flatten
from keras.layers import Conv2D
from keras.layers import Conv2DTranspose
from keras.layers import LeakyReLU
from keras.layers import BatchNormalization
from keras.initializers import RandomNormal
from keras.constraints import Constraint
from keras.utils import plot_model
from matplotlib import pyplot
import os
import time
import numpy as np
import re
import h5py
import pandas as pd
from utils.save_model_summary import save_model_summary
from utils.trim_csv import trim_csv
import math
import tensorflow as tf 

# Fix the global NumPy and TensorFlow RNG seeds.
np.random.seed(1)
tf.random.set_seed(2)

# Work inside the experiment folder on Google Drive: the checkpoint, plot and
# loss.csv paths used below are all relative to it.
work_dir = "/content/drive/My Drive/WGAN/BATIK_Dataset_64by64_UNet_architecture"
os.chdir(work_dir)
os.getcwd()

'/content/drive/My Drive/WGAN/BATIK_Dataset_64by64_UNet_architecture'

In [2]:
# clip model weights to a given hypercube
class ClipConstraint(Constraint):
	# set clip value when initialized
	def __init__(self, clip_value):
		self.clip_value = clip_value

	# clip model weights to hypercube
	def __call__(self, weights):
		return backend.clip(weights, -self.clip_value, self.clip_value)

	# get the config
	def get_config(self):
		return {'clip_value': self.clip_value}

def wasserstein_loss(y_true, y_pred):
    """Wasserstein critic loss: the batch mean of label * critic score.

    With labels of -1 ('real') and +1 ('fake'), minimising this pushes the
    scores of the two classes apart.
    """
    signed_scores = y_true * y_pred
    return backend.mean(signed_scores)

# define the standalone critic model
def define_critic(in_shape=(64, 64, 3), clip_value=0.01, learning_rate=0.00005):
    """Build and compile the WGAN critic.

    Architecture: one stride-1 3x3 conv input layer, then six stride-2 3x3
    convs that downsample 64 -> 32 -> 16 -> 8 -> 4 -> 2 -> 1 for a 64x64
    input. Each conv is followed by LeakyReLU(0.2) and BatchNormalization.
    A single linear Dense unit produces the critic score.

    Args:
        in_shape: input image shape (height, width, channels).
        clip_value: weight-clipping bound c. All conv kernels and the output
            Dense kernel are constrained to [-c, c].
        learning_rate: RMSprop learning rate.

    Returns:
        A Keras ``Sequential`` model compiled with ``wasserstein_loss``.
    """
    # weight initialization
    init = RandomNormal(stddev=0.02)
    # weight constraint shared by every clipped layer
    const = ClipConstraint(clip_value)
    model = Sequential()

    # input layer (no downsampling)
    model.add(Conv2D(64, (3, 3), padding='same', input_shape=in_shape,
                     kernel_initializer=init, kernel_constraint=const))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(axis=-1))
    # six stride-2 downsampling stages: 32x32, 16x16, 8x8, 4x4, 2x2, 1x1
    for filters in (64, 128, 256, 512, 512, 512):
        model.add(Conv2D(filters, (3, 3), strides=(2, 2), padding='same',
                         kernel_initializer=init, kernel_constraint=const))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(axis=-1))

    # scoring, linear activation
    model.add(Flatten())
    # WGAN clips *all* critic weights. The original version left this kernel
    # unconstrained, which leaves the critic's output scale unbounded. Saved
    # weights still load unchanged, but values outside [-c, c] are clipped on
    # the next update. NOTE: BatchNormalization gamma/beta remain unclipped.
    model.add(Dense(1, kernel_constraint=const))

    # `learning_rate` replaces the deprecated `lr` keyword (accepted since Keras 2.3).
    opt = RMSprop(learning_rate=learning_rate)
    model.compile(loss=wasserstein_loss, optimizer=opt)
    return model

# define the standalone generator model
def define_generator(latent_dim):
    """Build the (uncompiled) generator mapping a latent vector to a 64x64x3 image.

    A Dense layer seeds a 2x2x512 feature map. Five stride-2 4x4 transposed
    convolutions then upsample it 2 -> 4 -> 8 -> 16 -> 32 -> 64, each followed
    by LeakyReLU(0.2) and BatchNormalization. A final 3x3 tanh conv emits
    3 channels in [-1, 1].
    """
    weight_init = RandomNormal(stddev=0.02)
    net = Sequential()

    # foundation for a 2x2 image with 512 channels
    net.add(Dense(512 * 2 * 2, kernel_initializer=weight_init, input_dim=latent_dim))
    net.add(LeakyReLU(alpha=0.2))
    net.add(BatchNormalization(axis=-1))
    net.add(Reshape((2, 2, 512)))

    # upsample: 4x4, 8x8, 16x16, 32x32, 64x64
    for n_filters in (512, 512, 256, 128, 64):
        net.add(Conv2DTranspose(n_filters, (4, 4), strides=(2, 2), padding='same',
                                kernel_initializer=weight_init))
        net.add(LeakyReLU(alpha=0.2))
        net.add(BatchNormalization(axis=-1))

    # output 64x64x3
    net.add(Conv2D(3, (3, 3), activation='tanh', padding='same',
                   kernel_initializer=weight_init))
    return net

# define the combined generator and critic model, for updating the generator
def define_gan(generator, critic, learning_rate=0.00005):
    """Stack generator -> critic into a model that trains only the generator.

    Args:
        generator: generator model (latent vector -> image).
        critic: critic model (image -> score).
        learning_rate: RMSprop learning rate for the generator update.

    Returns:
        A compiled ``Sequential`` model using ``wasserstein_loss``.
    """
    # Freeze the critic inside the combined model. `critic` was already
    # compiled while trainable, so (under Keras 2 semantics) its own
    # train_on_batch still updates it. This is what produces the
    # 'Discrepancy between trainable weights...' warning in the run log.
    critic.trainable = False
    model = Sequential()
    model.add(generator)
    model.add(critic)
    # `learning_rate` replaces the deprecated `lr` keyword (accepted since Keras 2.3).
    opt = RMSprop(learning_rate=learning_rate)
    model.compile(loss=wasserstein_loss, optimizer=opt)
    return model


def load_dataset(ds_path):
    """Read the entire "Batik" dataset from the HDF5 file at ``ds_path``.

    The data is copied into an in-memory NumPy array, so it stays valid after
    the file is closed.
    """
    with h5py.File(ds_path, "r") as handle:
        images = np.array(handle["Batik"], copy=True)
    return images


	
# load images
def load_real_samples(ds_path="/content/drive/My Drive/Batik_Datasets/pinterest_version/batik_dataset_64by64.hdf5"):
    """Load the Batik image dataset and rescale pixel values to [-1, 1].

    Args:
        ds_path: path to the HDF5 file holding the "Batik" dataset. It defaults
            to the previously hard-coded Google Drive location.

    Returns:
        float32 array scaled by (x - 127.5) / 127.5. This assumes the stored
        pixels are in [0, 255] (presumably uint8 — TODO confirm against the
        dataset file), which matches the tanh range of the generator output.
    """
    # load the Batik dataset (the old comment wrongly said cifar10)
    trainX = load_dataset(ds_path)
    # convert from unsigned ints to floats
    X = trainX.astype('float32')
    # scale from [0,255] to [-1,1]
    X = (X - 127.5) / 127.5
    return X

def noisy_labels(y, p_flip):
    """Flip the sign of a random fraction of labels, in place (label noising).

    Args:
        y: NumPy label array; rows are selected along axis 0.
        p_flip: fraction of rows to flip. ``int(p_flip * len(y))`` distinct
            rows are chosen without replacement.

    Returns:
        ``y`` itself (modified in place) for call-chaining convenience.
    """
    n_labels = y.shape[0]
    n_select = int(p_flip * n_labels)
    # Passing the population size directly samples the same indices as the
    # old list-of-all-indices form, without building that Python list.
    flip_ix = np.random.choice(n_labels, size=n_select, replace=False)
    # invert the labels in place (for +/-1 labels this swaps real <-> fake)
    y[flip_ix] = -1 * y[flip_ix]
    return y

# select real samples
def generate_real_samples(dataset, n_samples, label_noising=True, p_flip=0.05):
    """Draw ``n_samples`` random images (with replacement) labelled -1 ('real').

    When ``label_noising`` is true, a ``p_flip`` fraction of the labels is
    sign-flipped via ``noisy_labels``.
    """
    picks = randint(0, dataset.shape[0], n_samples)
    labels = -ones((n_samples, 1))
    if not label_noising:
        return dataset[picks], labels
    return dataset[picks], noisy_labels(labels, p_flip)

# generate points in latent space as input for the generator
def generate_latent_points(latent_dim, n_samples):
    """Return an (n_samples, latent_dim) array of standard-normal latent vectors."""
    return randn(latent_dim * n_samples).reshape((n_samples, latent_dim))

# use the generator to generate n fake examples, with class labels
def generate_fake_samples(generator, latent_dim, n_samples, label_noising=True, p_flip=0.05):
    """Generate ``n_samples`` images from random latent points, labelled +1 ('fake').

    When ``label_noising`` is true, a ``p_flip`` fraction of the labels is
    sign-flipped via ``noisy_labels``.
    """
    images = generator.predict(generate_latent_points(latent_dim, n_samples))
    labels = ones((n_samples, 1))
    return images, (noisy_labels(labels, p_flip) if label_noising else labels)

# generate samples and save as a plot and save the model
def summarize_performance(epoch, c_model, g_model, gan_model, latent_dim, n_samples=100):
    """Save a grid of generated images and checkpoint every model for ``epoch``.

    Writes:
      * figure_plots/generated_plot_<epoch>.png  (up to a 7x7 grid of samples)
      * model_checkpoints/critic_model_weights.h5  (overwritten each call)
      * model_checkpoints/generator_model_<epoch>.h5  (full model)
      * model_checkpoints/GAN_model_weights.h5  (overwritten each call)
      * model_checkpoints/GAN_optimizer_weights.npy  (object array of the GAN
        optimizer's weights; load with ``np.load(..., allow_pickle=True).tolist()``)
    """
    # prepare clean fake examples (labels are not needed here)
    X, _ = generate_fake_samples(g_model, latent_dim, n_samples, label_noising=False)
    # scale from [-1,1] to [0,1]
    X = (X + 1) / 2.0

    # Plot at most a 7x7 grid; never index past the samples actually generated.
    n_grid = 7
    n_plot = min(n_grid * n_grid, len(X))
    # Draw on a fresh figure rather than whatever figure happens to be current.
    pyplot.figure()
    for i in range(n_plot):
        pyplot.subplot(n_grid, n_grid, 1 + i)
        pyplot.axis('off')
        pyplot.imshow(X[i])
    figure_plot = os.path.join('figure_plots', 'generated_plot_%03d.png' % (epoch))
    pyplot.savefig(figure_plot)
    pyplot.close()

    # save critic model
    critic_model_weights_fname = os.path.join('model_checkpoints', 'critic_model_weights.h5')
    c_model.save_weights(critic_model_weights_fname)

    generator_fname = os.path.join('model_checkpoints', 'generator_model_%03d.h5' % (epoch))
    save_model(g_model, generator_fname)

    gan_model_weights_fname = os.path.join('model_checkpoints', 'GAN_model_weights.h5')
    gan_model.save_weights(gan_model_weights_fname)

    gan_optimizer_weights_fname = os.path.join('model_checkpoints', 'GAN_optimizer_weights.npy')
    gan_optimizer_weights = gan_model.optimizer.get_weights()
    # The optimizer weights are arrays of different shapes. np.save on a plain
    # ragged list raises ValueError on NumPy >= 1.24, so pack them into an
    # explicit 1-D object array. Element-wise assignment avoids NumPy trying
    # to broadcast arrays that share a leading dimension.
    packed = np.empty(len(gan_optimizer_weights), dtype=object)
    for idx, weight in enumerate(gan_optimizer_weights):
        packed[idx] = weight
    np.save(gan_optimizer_weights_fname, packed, allow_pickle=True)

    print(f'Model_epoch_{epoch} saved')

# create a line plot of loss for the gan and save to file
def plot_history(df_hist):
    """Plot critic-real, critic-fake and generator losses; save to plot_line_plot_loss.png.

    ``df_hist`` must provide the columns critic_loss_real, critic_loss_fake
    and gan_loss, plotted against their row index.
    """
    curves = (
        ("critic_loss_real", 'critic_real'),
        ("critic_loss_fake", 'critic_fake'),
        ("gan_loss", 'gen'),
    )
    for column, curve_label in curves:
        pyplot.plot(df_hist[column], label=curve_label)
    pyplot.xlabel("number of iterations")
    pyplot.ylabel("loss")
    pyplot.legend()
    pyplot.savefig('plot_line_plot_loss.png')
    pyplot.close()


# train the generator and critic
def train(g_model, c_model, gan_model, dataset, latent_dim, csv_path=None, n_epochs=10, n_batch=64, n_critic=5, initial_epoch=0):
    """Run WGAN training from ``initial_epoch`` up to ``n_epochs``.

    Each generator step first performs ``n_critic`` critic updates. Each
    critic update uses half a batch of real samples and half a batch of
    generated samples, with label noise from the sample helpers. The step
    then does one generator update through ``gan_model`` with target label -1.
    One history row [epoch, "k/N" step, mean real critic loss, mean fake
    critic loss, generator loss] is recorded per generator step.

    At the end of epoch 1 and of every even epoch, the models are
    checkpointed (``summarize_performance``) and the loss plot is redrawn.
    When ``csv_path`` is given, the full history is also rewritten there.

    Args:
        csv_path: CSV holding previous history rows (columns epoch, steps,
            critic_loss_real, critic_loss_fake, gan_loss). History is read
            from and written back to it. If None, history starts empty and
            is not persisted (previously this crashed in ``pd.read_csv``).
        initial_epoch: number of already-completed epochs to skip.

    Raises:
        ValueError: if the dataset holds fewer than ``n_batch`` samples. In
            that case there would be zero steps per epoch and training would
            silently do nothing.
    """
    history_columns = ["epoch", "steps", "critic_loss_real", "critic_loss_fake", "gan_loss"]

    # calculate the number of batches per training epoch
    bat_per_epo = int(dataset.shape[0] / n_batch)
    if bat_per_epo == 0:
        raise ValueError(
            "dataset has %d samples, fewer than n_batch=%d" % (dataset.shape[0], n_batch))
    initial_step = initial_epoch * bat_per_epo

    print(dataset.shape, "bat_per_epo: ", bat_per_epo)
    n_steps = bat_per_epo * n_epochs
    half_batch = int(n_batch / 2)

    # previously recorded loss history (rows appended below)
    if csv_path is not None:
        list_hist = pd.read_csv(csv_path, index_col=None).values.tolist()
    else:
        list_hist = []

    for i in range(initial_step, n_steps):
        step = i + 1
        epoch = math.ceil(step / bat_per_epo)
        # 1-based position inside the current epoch, e.g. "225/225"
        steps = "%d/%d" % (step % bat_per_epo or bat_per_epo, bat_per_epo)

        # update the critic more than the generator
        c1_tmp, c2_tmp = list(), list()
        for _ in range(n_critic):
            X_real, y_real = generate_real_samples(dataset, half_batch)
            c1_tmp.append(c_model.train_on_batch(X_real, y_real))
            X_fake, y_fake = generate_fake_samples(g_model, latent_dim, half_batch)
            c2_tmp.append(c_model.train_on_batch(X_fake, y_fake))

        # generator update: label fakes as 'real' (-1) through the frozen critic
        X_gan = generate_latent_points(latent_dim, n_batch)
        y_gan = -ones((n_batch, 1))
        g_loss = gan_model.train_on_batch(X_gan, y_gan)

        c1_mean, c2_mean = mean(c1_tmp), mean(c2_tmp)
        list_hist.append([epoch, steps, c1_mean, c2_mean, g_loss])
        print('>%d, %s, c1=%.3f, c2=%.3f g=%.3f' % (epoch, steps, c1_mean, c2_mean, g_loss))

        # checkpoint after epoch 1 and after every second epoch
        if step % (bat_per_epo * 2) == 0 or step == bat_per_epo:
            summarize_performance(epoch, c_model, g_model, gan_model, latent_dim)
            df_hist = pd.DataFrame(list_hist, columns=history_columns)
            plot_history(df_hist)
            if csv_path is not None:
                df_hist.to_csv(csv_path, mode="w", index=False)


In [3]:

def train_model(initial_epoch=0, latent_dim=250, n_epochs=1200):
    """Build or restore the WGAN models, then train (or resume training).

    If ``model_checkpoints`` contains no ``generator_model_<epoch>.h5`` file,
    fresh critic/generator/GAN models are built. Otherwise:
      * loss.csv is trimmed back to the newest checkpointed epoch;
      * the generator is loaded from that checkpoint;
      * the critic and GAN weights and the GAN optimizer state are restored;
      * training resumes from that epoch.

    Args:
        initial_epoch: starting epoch for a fresh run. It is replaced by the
            checkpoint epoch when resuming.
        latent_dim: size of the generator's latent input.
        n_epochs: total epoch count to train up to (previously hard-coded 1200).
    """
    csv_path = "loss.csv"
    if not os.path.exists(csv_path):
        # start the history file with only the header row
        with open(csv_path, 'a+') as f:
            f.write('epoch,steps,critic_loss_real,critic_loss_fake,gan_loss')
    os.makedirs('figure_plots', exist_ok=True)
    os.makedirs('model_checkpoints', exist_ok=True)
    os.makedirs('model_summaries', exist_ok=True)

    # Newest epoch with a saved generator checkpoint (0 if none). Match the
    # exact name summarize_performance writes, so no other file containing
    # digits can be mistaken for a checkpoint.
    checkpoint_re = re.compile(r'^generator_model_(\d+)\.h5$')
    saved_epochs = [
        int(match.group(1))
        for match in map(checkpoint_re.match, os.listdir('model_checkpoints'))
        if match
    ]
    last_epoch = max(saved_epochs, default=0)

    if last_epoch == 0:
        print("Start training ... ")
        c_model = define_critic()
        g_model = define_generator(latent_dim)
        gan_model = define_gan(g_model, c_model)
    else:
        # drop history rows recorded after the checkpoint we resume from
        trim_csv(csv_path, last_epoch)

        model_paths = ["critic_model_weights.h5", "GAN_model_weights.h5", "GAN_optimizer_weights.npy"]
        model_filenames = [os.path.join("model_checkpoints", p) for p in model_paths]

        g_model = load_model(os.path.join("model_checkpoints", "generator_model_%03d.h5" % (last_epoch)))
        c_model = define_critic()
        c_model.load_weights(model_filenames[0])

        gan_model = define_gan(g_model, c_model)
        gan_model.load_weights(model_filenames[1])
        # Private Keras 2 API; presumably needed so the optimizer's weight
        # variables exist before set_weights below — TODO confirm on upgrade.
        gan_model._make_train_function()

        gan_model_optimizer_weights = np.load(model_filenames[2], allow_pickle=True).tolist()
        gan_model.optimizer.set_weights(gan_model_optimizer_weights)

        initial_epoch = last_epoch
        print(f"Last trained model at epoch : {last_epoch}")
        print("Resume Training ...")

    # load image data
    dataset = load_real_samples()

    models = [g_model, c_model, gan_model]
    filenames = ['Generator_model.png', 'Critic_model.png', 'GAN_model.png']
    for (model, fn) in zip(models, filenames):
        plot_model(model, to_file=fn, show_shapes=True, show_layer_names=True)
        # NOTE(review): the summary is also given the .png name; confirm that
        # save_model_summary handles/rewrites the extension.
        save_model_summary(model, os.path.join('model_summaries', fn))

    start = time.time()
    train(g_model, c_model, gan_model, dataset, latent_dim, n_epochs=n_epochs,
          csv_path=csv_path, initial_epoch=initial_epoch)
    print(f"Elapsed time: {time.time() - start} seconds")

In [33]:
# Train from scratch, or resume from the newest generator checkpoint if one exists.
train_model(initial_epoch=0, latent_dim=250)

csv trimed at epoch > 8




Last trained model at epoch : 8
Resume Training ...


  'Discrepancy between trainable weights and collected trainable'
  'Discrepancy between trainable weights and collected trainable'


(14400, 64, 64, 3) bat_per_epo:  225
>9, 1/225, c1=-1.027, c2=-1.093 g=-0.011
>9, 2/225, c1=2.894, c2=-0.236 g=-0.011
>9, 3/225, c1=-0.436, c2=-1.742 g=-0.012
>9, 4/225, c1=0.664, c2=-1.013 g=-0.012
>9, 5/225, c1=-1.705, c2=-0.009 g=-0.012
>9, 6/225, c1=-0.880, c2=2.394 g=-0.012
>9, 7/225, c1=0.752, c2=-1.621 g=-0.012
>9, 8/225, c1=-1.392, c2=2.445 g=-0.012
>9, 9/225, c1=0.519, c2=-2.208 g=-0.012
>9, 10/225, c1=-1.534, c2=-0.011 g=-0.012
>9, 11/225, c1=1.774, c2=-0.338 g=-0.012
>9, 12/225, c1=-1.901, c2=-0.814 g=-0.012
>9, 13/225, c1=-1.155, c2=-0.961 g=-0.012
>9, 14/225, c1=1.891, c2=3.502 g=-0.012
>9, 15/225, c1=0.262, c2=-1.994 g=-0.012
>9, 16/225, c1=0.744, c2=0.429 g=-0.012
>9, 17/225, c1=-0.434, c2=0.427 g=-0.012
>9, 18/225, c1=0.177, c2=0.590 g=-0.012
>9, 19/225, c1=-0.430, c2=2.786 g=-0.012
>9, 20/225, c1=-0.792, c2=1.907 g=-0.012
>9, 21/225, c1=1.207, c2=0.173 g=-0.012
>9, 22/225, c1=0.118, c2=0.844 g=-0.012
>9, 23/225, c1=0.512, c2=-0.889 g=-0.012
>9, 24/225, c1=-1.590, c2=-0

KeyboardInterrupt: ignored

In [None]:
# Inspect the persisted loss history as a list of rows.
df_hist = pd.read_csv("loss.csv", index_col=None)
list_hist = df_hist.to_numpy().tolist()
list_hist

In [None]:
# Locate the newest generator checkpoint and the companion files written by
# summarize_performance. Only generator files carry an epoch number. The
# previous version applied re.findall(...)[0] to *every* file and raised
# IndexError on un-numbered files such as 'critic_model_weights.h5'.
generator_ckpt_re = re.compile(r'^generator_model_(\d+)\.h5$')
last_epoch = max(
    (int(m.group(1))
     for m in map(generator_ckpt_re.match, os.listdir('model_checkpoints'))
     if m),
    default=0,
)
last_step = last_epoch  # previously computed by the identical expression

# Critic/GAN weights and optimizer state are overwritten in place at each
# checkpoint, so only the generator file name is epoch-specific.
model_paths = ["critic_model_weights.h5", "GAN_model_weights.h5", "GAN_optimizer_weights.npy"]
model_filenames = [os.path.join("model_checkpoints", "generator_model_%03d.h5" % (last_epoch))]
model_filenames += [os.path.join("model_checkpoints", p) for p in model_paths]
print(model_filenames, last_epoch)

In [None]:
# Sample a 7x7 grid of images from a saved generator checkpoint.
latent_dim = 250
n_samples = 100
checkpoint_path = os.path.join('model_checkpoints', 'generator_model_015.h5')
g_model = load_model(checkpoint_path)
X, _ = generate_fake_samples(g_model, latent_dim, n_samples)
X = (X + 1) / 2.0  # tanh output [-1, 1] -> [0, 1] for imshow
print(X.shape)
for i in range(7 * 7):
    pyplot.subplot(7, 7, i + 1)
    pyplot.axis('off')
    pyplot.imshow(X[i])

In [None]:
import math  # only needed by the commented-out migration snippets below
df = pd.read_csv(filepath_or_buffer='loss.csv', index_col=None)
# One-off loss.csv migrations, kept for reference (run manually if needed):
#   - split a global step counter into 'epoch' + "k/225" 'steps' columns
#   - truncate the history after a given epoch / drop the last row
# print(df[df.columns].loc[df['steps']<=225])
# df.insert(0, 'epoch', '')
# df['epoch'] = df['steps'].map(lambda x: math.ceil(x/225))
# df['steps'] = df['steps'].map(lambda x: "%d/%d" % ((x%225) if (x%225)!=0 else 225, 225))
# df.to_csv('loss.csv', index=False)
# df = df.loc[df['epoch'] <= 12 ]
# df = df.drop(df.index[-1])
# df.to_csv('loss.csv', index=False)
df

In [None]:
# os.chdir('figure_plots')
# print(os.getcwd())
# for plot in os.listdir('.'):
#   steps = int(re.findall('[0-9]+',plot)[0])
#   new_fname = "generated_plot_%03d.png" % (steps/225)
#   os.rename(plot, new_fname)
#   print(plot,new_fname)


In [None]:
# Newest epoch with a saved generator checkpoint (0 when there is none).
# Matching the exact checkpoint name avoids picking up other digit-bearing
# files, and max(default=0) avoids the ValueError the old code raised on an
# empty checkpoint folder.
generator_ckpt_re = re.compile(r'^generator_model_(\d+)\.h5$')
generator_models = [f for f in os.listdir('model_checkpoints') if generator_ckpt_re.match(f)]
last_epoch = max((int(generator_ckpt_re.match(f).group(1)) for f in generator_models), default=0)
last_epoch

In [None]:
# a = [a*225 for a in range(1,10)]
# for i in a:
#   if i % (225*2) == 0:
#     print(i)