In [11]:
# example of a wgan for generating handwritten digits
import numpy as np
from numpy import expand_dims
from numpy import mean
from numpy import ones
from numpy.random import rand
from numpy.random import randn
from numpy.random import randint
from keras.datasets.mnist import load_data
from keras import backend
from keras.optimizers import RMSprop
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Reshape
from keras.layers import Flatten
from keras.layers import Conv2D
from keras.layers import Conv2DTranspose
from keras.layers import LeakyReLU
from keras.layers import ReLU
from keras.layers import BatchNormalization
from keras.initializers import RandomNormal
from keras.constraints import Constraint
from matplotlib import pyplot
import tensorflow as tf
from tensorflow.keras import Model

# constants
# output dim of vgg16
image_features_dim = 4096
# output dim of doc2vec
text_features_dim = 4096
# output dim of word2vec
class_features_dim = 4096
# number of units in critic's layer 1
critic_units_layer1 = 4096
# number of units in generator layer1, 2
generator_units_layer1 = 4096
generator_units_layer2 = 4096
# number of units in regressor layer1, 2, 3
regressor_units_layer1 = 4096
regressor_units_layer2 = 4096
regressor_units_layer3 = 300
# size of the latent space
latent_dim = 4096 # for now it should be same as class_emb_dim

# clip model weights to a given hypercube
class ClipConstraint(Constraint):
	# set clip value when initialized
	def __init__(self, clip_value):
		self.clip_value = clip_value

	# clip model weights to hypercube
	def __call__(self, weights):
		return backend.clip(weights, -self.clip_value, self.clip_value)

	# get the config
	def get_config(self):
		return {'clip_value': self.clip_value}

# calculate wasserstein loss
def wasserstein_loss(y_true, y_pred):
  # y_true.shape = (batch_size x 1)
  # value of y_true is either 1 or -1 depending
  # upon fake or real sample. y_pred is output of 
  # the critic module
  return backend.mean(y_true * y_pred)

# generate points in latent space as input for the generator
def generate_latent_points(latent_dim, n_samples):
  # generate points in the latent space
  x_input = randn(latent_dim * n_samples)
  # reshape into a batch of inputs for the network
  x_input = x_input.reshape(n_samples, latent_dim)
  return x_input

# define the standalone critic model
def define_critic(in_shape=(image_features_dim, )):
	# weight initialization
	init = RandomNormal(stddev=0.02)
	# weight constraint
	const = ClipConstraint(0.01)
	# define model
	model = Sequential()
	# downsample to 14x14
	model.add(Dense(critic_units_layer1, kernel_initializer=init, kernel_constraint=const, input_shape=in_shape))
	model.add(LeakyReLU(alpha=0.2))
	# scoring, linear activation
	model.add(Dense(1))
	# compile model
	opt = RMSprop(lr=0.00005)       # TODO: change it to adam
	model.compile(loss=wasserstein_loss, optimizer=opt)
	return model

# define the standalone generator model
def define_generator(latent_dim):
	# weight initialization
	init = RandomNormal(stddev=0.02)
	# define model
	model = Sequential()
  # layer 1
	model.add(Dense(generator_units_layer1, kernel_initializer=init, input_dim=latent_dim))
	model.add(LeakyReLU(alpha=0.2))
	# layer 2
	model.add(Dense(generator_units_layer2, kernel_initializer=init))
	model.add(ReLU())
	return model

# define the standalone regressor model
def define_regressor(in_shape=image_features_dim):
  # weight initialization
  init = RandomNormal(stddev=0.02)
  # define model
  model = Sequential()
  # layer 1
  model.add(Dense(regressor_units_layer1, kernel_initializer=init, input_dim=in_shape))
  model.add(ReLU())
  # layer 2
  model.add(Dense(regressor_units_layer2, kernel_initializer=init))
  model.add(ReLU())
  # layer 3
  model.add(Dense(regressor_units_layer3, kernel_initializer=init))
  model.add(ReLU())
  # model.compile will come in the final overall model
  return model

class mymodel(Model):
  def __init__(self, latent_dim):
    super(mymodel, self).__init__()
    self.g1 = define_generator(latent_dim)
    self.g2 = define_generator(latent_dim)
    self.c1 = define_critic()
    self.c2 = define_critic()
    self.r1 = define_regressor()
    self.r2 = define_regressor()

  def call(self, _in):
    in1, in2 = _in
    g1c1 = self.g1(in1)
    g1c1 = self.c1(g1c1)

    g1r1 = self.g1(in1)
    g1r1 = self.r1(g1r1)

    g2r2 = self.g2(in2)
    g2r2 = self.r2(g2r2)

    g2c2 = self.g2(in2)
    g2c2 = self.c2(g2c2)

    return g1c1, g1r1, g2r2, g2c2

# define the combined generator and critic model, for updating the generator
def define_gan(generator, critic):
	# make weights in the critic not trainable
	for layer in critic.layers:
		if not isinstance(layer, BatchNormalization):
			layer.trainable = False
	# path 1
	model = Sequential()
	model.add(generator)
	model.add(critic)
 
	# compile model
	# opt = RMSprop(lr=0.00005)       #TODO: change it to adam
	# model.compile(loss=wasserstein_loss, optimizer=opt)
	return model

# load images
def load_real_samples():
  """ This function should return (nsamples x fetures) training data
      for both image and text modality and corresponding class embeddings: 
      (nsamples x class_emb_size)
  """
  return rand(1000, 4096), rand(1000, 4096), np.zeros((1000, 4096))

# select real samples
def generate_real_samples(image_dataset, text_dataset, CE, n_samples):
  """
    This function should return three things:
    Xv : n_samples * features
    Xt : n_samples * features
    y : n_samples * 1
    CE: n_samples * class_embedding_dimension
  """
  # choose random instances
  ix = randint(0, image_dataset.shape[0], n_samples)
  # select images, text
  Xv = image_dataset[ix]
  Xt = text_dataset[ix]
  # generate class labels, -1 for 'real'
  y = -ones((n_samples, 1))
  ce = CE[ix]
  return Xv, Xt, y, ce

# use the generator to generate n fake examples, with class labels
def generate_fake_samples(generator1, generator2, ce, latent_dim, n_samples):
  # generate points in latent space for image generator
  x_input_image = generate_latent_points(latent_dim, n_samples)
  x_input_image += ce # adding the noise with class embeddings
  # generate points in latent space for text generator
  x_input_text = generate_latent_points(latent_dim, n_samples)
  x_input_text += ce # noise + class_emb
  # predict outputs
  Xv = generator1.predict(x_input_image)
  Xt = generator2.predict(x_input_text)
  # create class labels with 1.0 for 'fake'
  y = ones((n_samples, 1))
  return Xv, Xt, y

loss_objects = [wasserstein_loss, tf.nn.l2_loss, tf.keras.losses.KLD]
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
overall_model = mymodel(latent_dim)

def train_step(inputs, y_true): # (inputs = (n1+ce, n2+ce))
  with tf.GradientTape() as tape:
    (g1c1, g1r1, g2r2, g2c2) = overall_model(inputs)
    regressor_loss = loss_objects[1](g1r1, g2r2)
    generator1_loss = loss_objects[0](y_true, g1c1)
    generator2_loss = loss_objects[0](y_true, g2c2)
    losses = [regressor_loss, generator1_loss, generator2_loss]
  
  gradients = tape.gradient(losses, overall_model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, overall_model.trainable_variables))
  return (g1c1, g1r1, g2r2, g2c2)

# train the generator and critic
def train(g1, g2, c1, c2, r1, r2, image_dataset, text_dataset, CE, latent_dim, n_epochs=10, n_batch=64, n_critic=5):
  bat_per_epo = int(image_dataset.shape[0] / n_batch)
  n_steps = bat_per_epo * n_epochs
  half_batch = int(n_batch / 2)
  c11_hist, c12_hist, c21_hist, c22_hist= list(), list(), list(), list()
  g1_loss, g2_loss = list(), list()

  for i in range(n_steps):
    # stores the loss values for both critics individually for real and fake data
    c11_tmp, c12_tmp, c21_tmp, c22_tmp = list(), list(), list(), list()
    for _ in range(n_critic):
      # training both critics on real data
      Xv_real, Xt_real, y_real, ce = generate_real_samples(image_dataset, text_dataset, CE, half_batch)
      c_loss1 = c1.train_on_batch(Xv_real, y_real)
      c11_tmp.append(c_loss1)
      c_loss1 = c2.train_on_batch(Xt_real, y_real)
      c21_tmp.append(c_loss1)

      # training both critics on fake data
      Xv_fake, Xt_fake, y_fake = generate_fake_samples(g1, g2, ce, latent_dim, half_batch) #TODO: how/where to use class_embs
      c_loss1 = c1.train_on_batch(Xv_fake, y_fake)
      c12_tmp.append(c_loss1)
      c_loss1 = c2.train_on_batch(Xt_fake, y_fake)
      c22_tmp.append(c_loss1)
    # store critic loss
    c11_hist.append(mean(c11_tmp))
    c12_hist.append(mean(c12_tmp))
    c21_hist.append(mean(c21_tmp))
    c22_hist.append(mean(c22_tmp))

    # prepare points in latent space as input for the generator
    _, _, _, ce = generate_real_samples(image_dataset, text_dataset, CE, n_batch)
    X_text = generate_latent_points(latent_dim, n_batch) + ce
    X_image = generate_latent_points(latent_dim, n_batch) + ce
    y_overall = -ones((n_batch, 1))

    g1c1, g1r1, g2r2, g2c2 = train_step((X_image, X_text), y_overall)
    g1_loss.append(wasserstein_loss(y_overall, g1c1))
    g2_loss.append(wasserstein_loss(y_overall, g2c2))

    # summarize loss on this batch
    print('>%d, c11=%.3f, c12=%.3f, c21=%.3f, c22=%.3f, g1=%.3f, g2=%.3f' % (i+1, c11_hist[-1], c12_hist[-1], c21_hist[-1], c22_hist[-1], g1_loss[-1], g2_loss[-1]))
  # line plots of loss
  # plot_history(c1_hist, c2_hist, g_hist)

# load image data
image_dataset, text_dataset, CE = load_real_samples()
# train model
train(overall_model.g1, overall_model.g2, overall_model.c1, overall_model.c2, overall_model.r1, overall_model.r2, 
      image_dataset, text_dataset, CE, latent_dim)

>1, c11=-1.583, c12=4.456, c21=-1.157, c22=4.336, g1=2.233, g2=2.497
>2, c11=-74.667, c12=68.157, c21=-69.447, c22=122.518, g1=-52.122, g2=-93.090
>3, c11=-140.961, c12=160.923, c21=-122.569, c22=361.266, g1=-130.202, g2=-288.309
>4, c11=-204.341, c12=294.038, c21=-169.601, c22=777.966, g1=-252.758, g2=-646.075
>5, c11=-266.157, c12=472.868, c21=-213.421, c22=1415.161, g1=-415.138, g2=-1211.467
>6, c11=-327.043, c12=699.173, c21=-255.197, c22=2268.455, g1=-635.805, g2=-2001.619


KeyboardInterrupt: ignored