<a href="https://colab.research.google.com/github/CleanPegasus/Aging-c-GAN/blob/master/Aging_cGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
!mkdir data

In [0]:
mkdir results

In [3]:
pwd

'/content'

In [0]:
cd data

/content/data


In [0]:
ls

In [0]:
!wget https://data.vision.ee.ethz.ch/cvl/rrothe/imdb-wiki/static/wiki_crop.tar

--2019-05-27 12:44:13--  https://data.vision.ee.ethz.ch/cvl/rrothe/imdb-wiki/static/wiki_crop.tar
Resolving data.vision.ee.ethz.ch (data.vision.ee.ethz.ch)... 129.132.52.162
Connecting to data.vision.ee.ethz.ch (data.vision.ee.ethz.ch)|129.132.52.162|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 811315200 (774M) [application/x-tar]
Saving to: ‘wiki_crop.tar’


2019-05-27 12:48:43 (2.87 MB/s) - ‘wiki_crop.tar’ saved [811315200/811315200]



In [0]:
!tar -xvf wiki_crop.tar

wiki_crop/00/23804200_1950-03-31_2013.jpg
wiki_crop/00/23836900_1988-10-31_2012.jpg
wiki_crop/00/23882000_1940-06-28_2009.jpg
wiki_crop/00/3386200_1955-02-21_1994.jpg
wiki_crop/00/43806600_1991-07-27_2014.jpg
wiki_crop/00/24817400_1984-10-25_2011.jpg
wiki_crop/00/34807600_1988-05-11_2013.jpg
wiki_crop/00/3482000_1971-03-20_2005.jpg
wiki_crop/00/2581400_1954-04-17_2007.jpg
wiki_crop/00/35833400_1966-04-06_2013.jpg
wiki_crop/00/6587200_1982-04-01_2009.jpg
wiki_crop/00/1687000_1932-12-12_1962.jpg
wiki_crop/00/26854400_1958-12-08_1987.jpg
wiki_crop/00/36883800_1990-01-22_2014.jpg
wiki_crop/00/36895700_1970-11-09_1985.jpg
wiki_crop/00/36898900_1977-12-23_2009.jpg
wiki_crop/00/868000_1977-04-25_2009.jpg
wiki_crop/00/37805500_1989-12-14_2011.jpg
wiki_crop/00/6789900_1956-02-03_1978.jpg
wiki_crop/00/878700_1920-05-20_2001.jpg
wiki_crop/00/9781600_1965-02-15_2008.jpg
wiki_crop/00/18814200_1983-08-29_2005.jpg
wiki_crop/00/18890000_1975-04-01_2009.jpg
wiki_crop/00/18890700_1918-10-07_1974.jpg
wik

In [0]:
import math
import os
import time
from datetime import datetime
from scipy.io import loadmat

In [0]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from keras import Input, Model
from keras.applications import InceptionResNetV2
from keras.callbacks import TensorBoard
from keras.layers import Conv2D, Flatten, Dense, BatchNormalization, Reshape, concatenate, LeakyReLU, Lambda, K, Conv2DTranspose, Activation, UpSampling2D, Dropout
from keras.optimizers import Adam
from keras.utils import to_categorical
from keras_preprocessing import image
from keras.models import load_model

In [0]:
def load_data(wiki_dir, dataset= 'wiki'):
  
  meta = loadmat(os.path.join(wiki_dir, "{}.mat".format(dataset)))
  
  full_path = meta[dataset][0, 0]["full_path"][0]
  
  dob = meta[dataset][0, 0]["dob"][0]
  
  photo_taken = meta[dataset][0, 0]["photo_taken"][0]
  
  age = [calculate_age(photo_taken[i], dob[i]) for i in range(len(dob))]
  
  images = []
  age_list = []
  
  for index, image_path in enumerate(full_path):
    
    images.append(image_path[0])
    age_list.append(age[index])
    
  return images, age_list

In [0]:
def calculate_age(taken, dob):
  birth = datetime.fromordinal(max(int(dob) - 366, 1))
  
  if birth.month < 7:
    return taken - birth.year
  else:
    return taken - birth.year - 1

In [0]:
def build_encoder():
  
  input_layer = Input(shape = (64, 64, 3))
  
  enc = Conv2D(filters = 32, kernel_size = 5, strides = 2, padding = 'same')(input_layer)
  enc = LeakyReLU(alpha = 0.2)(enc)
  
  enc = Conv2D(filters = 64, kernel_size = 5, strides = 2, padding = 'same')(enc)
  enc = BatchNormalization()(enc)
  enc = LeakyReLU(alpha = 0.2)(enc)
  
  enc = Conv2D(filters = 128, kernel_size = 5, strides = 2, padding = 'same')(enc)
  enc = BatchNormalization()(enc)
  enc = LeakyReLU(alpha = 0.2)(enc)
  
  enc = Conv2D(filters = 256, kernel_size = 5, strides = 2, padding = 'same')(enc)
  enc = BatchNormalization()(enc)
  enc = LeakyReLU(alpha = 0.2)(enc)
  
  enc = Flatten()(enc)
  
  enc = Dense(4096)(enc)
  enc = BatchNormalization()(enc)
  enc = LeakyReLU(alpha = 0.2)(enc)
  
  enc = Dense(100)(enc)
  
  model = Model(inputs = [input_layer], outputs = [enc])
  
  model.summary()
  
  return model

In [0]:
def build_gen():
  
  latent_dims = 100
  num_classes = 6
  
  input_z_noise = Input(shape = (latent_dims,))
  input_label = Input(shape = (num_classes,))
  
  x = concatenate([input_z_noise, input_label])
  
  x = Dense(2048, input_dim = latent_dims + num_classes)(x)
  x = LeakyReLU(alpha = 0.2)(x)
  x = Dropout(0.2)(x)
  
  x = Dense(256 * 8 * 8)(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(alpha = 0.2)(x)
  x = Dropout(0.2)(x)
  
  x = Reshape((8, 8, 256))(x)
  
  x = UpSampling2D(size = (2, 2))(x)
  x = Conv2D(filters = 128, kernel_size = 5, padding = 'same')(x)
  x = BatchNormalization(momentum = 0.8)(x)
  x = LeakyReLU(alpha = 0.2)(x)
  
  x = UpSampling2D(size = (2, 2))(x)
  x = Conv2D(filters = 64, kernel_size = 5, padding = 'same')(x)
  x = BatchNormalization(momentum = 0.8)(x)
  x = LeakyReLU(alpha = 0.2)(x)
  
  x = UpSampling2D(size = (2, 2))(x)
  x = Conv2D(filters = 3, kernel_size = 5, padding = 'same')(x)
  x = BatchNormalization(momentum = 0.8)(x)
  x = Activation('tanh')(x)
  
  model = Model(inputs = [input_z_noise, input_label], outputs = [x])
  
  model.summary()
  
  return model

In [0]:
def expand_label_input(x):
  x = K.expand_dims(x, axis = 1)
  x = K.expand_dims(x, axis=1)
  x = K.tile(x, [1, 32, 32, 1])
  return x

In [0]:
def build_disc():
  
  input_shape = (64, 64, 3)
  label_shape = (6,)
  
  image_input = Input(shape = input_shape)
  label_input = Input(shape = label_shape)
  
  x = Conv2D(filters = 64, kernel_size = 3, strides = 2, padding = 'same')(image_input)
  x = LeakyReLU(alpha = 0.2)(x)
  
  label_input1 = Lambda(expand_label_input)(label_input)
  
  x = concatenate([x, label_input1], axis = 3)
  
  x = Conv2D(128, kernel_size = 3, strides = 2, padding = 'same')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(alpha = 0.2)(x)
  
  x = Conv2D(256, kernel_size = 3, strides = 2, padding = 'same')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(alpha = 0.2)(x)
  
  x = Conv2D(512, kernel_size = 3, strides = 2, padding = 'same')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(alpha = 0.2)(x)
  
  x = Flatten()(x)
  x = Dense(1, activation = 'sigmoid')(x)
  
  model = Model(inputs = [image_input, label_input], outputs = [x])
  
  model.summary()
  
  return model

In [0]:
def age_to_category(age_list):
  
  age_list1 = []
  
  for age in age_list:
    
    if 0 < age <= 18:
      age_category = 0
    elif 18 < age <= 29:
      age_category = 1
    elif 29 < age <= 39:
      age_category = 2
    elif 39 < age <= 49:
      age_category = 3
    elif 49 < age <= 59:
      age_category = 4
    elif age >= 60:
      age_category = 5
      
    age_list1.append(age_category)
    
  return age_list1

In [0]:
def load_images(data_dir, image_paths, image_shape):
  
  images = []
  
  print("Loading Images")
  
  for i, image_path in enumerate(image_paths):
    
    loaded_image = image.load_img(os.path.join(data_dir, image_path), target_size = image_shape)
    loaded_image = image.img_to_array(loaded_image)
    #loaded_image = np.expand_dims(loaded_image, axis = 0)
    
    images.append(loaded_image)
    #if images is None:
    #  images = loaded_image
    #else:
     # images = np.concatenate([images, loaded_image], axis = 0)
     # print(images.shape)
  images = np.asarray(images)
  print(images.shape)
  print("Images Loaded")
  return images

In [0]:
def save_rgb_img(img, path):
  
  fig = plt.figure()
  ax = fig.add_subplot(1, 1, 1)
  ax.imshow(img)
  ax.axis("off")
  ax.set_title("Image")
  
  plt.savefig(path)
  plt.close

In [0]:
def write_log(callback, name, value, batch_no):
    summary = tf.Summary()
    summary_value = summary.value.add()
    summary_value.simple_value = value
    summary_value.tag = name
    callback.writer.add_summary(summary, batch_no)
    callback.writer.flush()

In [0]:
cd ..

/content


In [0]:
def train_gan():
  
  data_dir = "/content/data/"
  wiki_dir = os.path.join(data_dir, "wiki_crop")
  epochs = 500
  batch_size = 128
  image_shape = (64, 64, 3)
  z_shape = 100
  TRAIN_GAN = True
  TRAIN_ENCODER = False
  TRAIN_GAN_WITH_FR = False
  fr_image_shape = (192, 192, 3)
  
  disc_optimizer = Adam(lr = 0.0002, beta_1 = 0.5, beta_2 = 0.999, epsilon = 10e-8)
  gen_optimizer = Adam(lr = 0.0002, beta_1 = 0.5, beta_2 = 0.999, epsilon = 10e-8)
  adversarial_optimizer = Adam(lr = 0.0002, beta_1 = 0.5, beta_2 = 0.999, epsilon = 10e-8)
  
  if(os.path.isfile('generator.h5')):
    generator = load_model('generator.h5')
    print("Generator model loaded")
  else:
    generator = build_gen()
    generator.compile(loss = 'binary_crossentropy', optimizer = gen_optimizer)
    
  if(os.path.isfile('discriminator.h5')):
    discriminator = load_model('discriminator.h5')
    print("Discriminator model loaded")
  else:
    discriminator = build_disc()
    discriminator.compile(loss = 'binary_crossentropy', optimizer = disc_optimizer)
   
  
  discriminator.trainable = False
  
  input_z_noise = Input(shape = (100,))
  input_label = Input(shape = (6,))
  recons_image = generator([input_z_noise, input_label])
  valid = discriminator([recons_image, input_label])
  
  adversarial_model = Model(inputs = [input_z_noise, input_label], outputs = [valid])
  adversarial_model.compile(loss = ['binary_crossentropy'], optimizer = adversarial_optimizer)
  
  tensorboard = TensorBoard(log_dir = "logs/{}".format(time.time()))
  tensorboard.set_model(generator)
  tensorboard.set_model(discriminator)
  
  images, age_list = load_data(wiki_dir = wiki_dir, dataset = "wiki")
  #print(len(images))
  
  
  age_cat = age_to_category(age_list)
  
  final_age_cat = np.reshape(np.array(age_cat), [len(age_cat), 1])
  classes = len(set(age_cat))
  y = to_categorical(final_age_cat, num_classes = len(set(age_cat)))
  
  loaded_images = load_images(wiki_dir, images, (image_shape[0], image_shape[1]))
  print(len(loaded_images))
  
  real_labels = np.ones((batch_size, 1), dtype = np.float32)*0.9
  fake_labels = np.zeros((batch_size, 1), dtype = np.float32)*0.1
  
  for epoch in range(epochs):
    
    print("Epoch: {}".format(epoch))
    
    gen_losses = []
    disc_losses = []
    
    number_of_batches = int(len(loaded_images) / batch_size)
    
    for index in range(number_of_batches):
      print("Batch: {}".format(index + 1))
      
      images_batch = loaded_images[index * batch_size:(index + 1) * batch_size]
      images_batch = images_batch/ 127.5 - 1.0
      images_batch = images_batch.astype(np.float32)
                                        
      y_batch = y[index * batch_size : (index + 1) * batch_size]
      
      z_noise = np.random.normal(0, 1, size = (batch_size, z_shape))
      
      initial_recon_images = generator.predict_on_batch([z_noise, y_batch])
      
      disc_loss_real = discriminator.train_on_batch([images_batch, y_batch], real_labels)
      disc_loss_fake = discriminator.train_on_batch([initial_recon_images, y_batch], fake_labels)
      
      z_noise2 = np.random.normal(0, 1, size = (batch_size, z_shape))
      random_labels = np.random.randint(0, 6, batch_size).reshape(-1, 1)
      
      random_labels = to_categorical(random_labels, 6)
      
      gen_loss = adversarial_model.train_on_batch([z_noise2, random_labels], [1] * batch_size)
      
      disc_loss = 0.5 * np.add(disc_loss_real, disc_loss_fake)
      print("disc_loss: {}".format(disc_loss))
      print("gen_loss: {}".format(gen_loss))
      gen_losses.append(gen_loss)
      disc_losses.append(disc_loss)
      
      write_log(tensorboard, 'gen_loss', np.mean(gen_losses), epoch)
      write_log(tensorboard, 'disc_loss', np.mean(disc_losses), epoch)
      
    if epoch % 10 == 0:
        
      images_batch = loaded_images[0:batch_size]
      images_batch = images_batch / 127.5 - 1.0
      images_batch = images_batch.astype(np.float32)
        
      y_batch = y[0:batch_size]
      z_noise = np.random.normal(0, 1, size = (batch_size, z_shape))
        
      gen_images = generator.predict_on_batch([z_noise, y_batch])
        
      for i , image in enumerate(gen_images[:5]):
        save_rgb_img(image, path = "results/img_{}_{}.png".format(epoch, i))
          
      generator.save("generator.h5")
      discriminator.save("discriminator.h5")
      
      generator.save_weights("generator_weights.h5")
      discriminator.save_weights("discriminator_weights.h5")

In [0]:
def euclidean_distance_loss(y_true, y_pred):
  
  return K.sqrt(K.sum(K.square(y_pred - y_true), axis = 0))

In [0]:
train_gan()

Instructions for updating:
Colocations handled automatically by placer.
Instructions for updating:
Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            (None, 100)          0                                            
__________________________________________________________________________________________________
input_2 (InputLayer)            (None, 6)            0                                            
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     (None, 106)          0           input_1[0][0]                    
                                                                 input_2[0][0]                    
_____________________

  'Discrepancy between trainable weights and collected trainable'


disc_loss: 1.7710773944854736
gen_loss: 0.5672273635864258
Batch: 2
disc_loss: 0.7980452179908752
gen_loss: 1.2903497219085693
Batch: 3
disc_loss: 0.6451823711395264
gen_loss: 1.459495186805725
Batch: 4
disc_loss: 0.5589893460273743
gen_loss: 1.351181149482727
Batch: 5
disc_loss: 0.5628345012664795
gen_loss: 1.7605714797973633
Batch: 6
disc_loss: 0.6499626040458679
gen_loss: 2.348104476928711
Batch: 7
disc_loss: 0.5430289506912231
gen_loss: 1.2223234176635742
Batch: 8
disc_loss: 0.4432947039604187
gen_loss: 0.5130778551101685
Batch: 9
disc_loss: 0.39026448130607605
gen_loss: 0.64821857213974
Batch: 10
disc_loss: 0.37083691358566284
gen_loss: 1.684643268585205
Batch: 11
disc_loss: 0.4403573274612427
gen_loss: 3.594759225845337
Batch: 12
disc_loss: 0.27577877044677734
gen_loss: 1.8737359046936035
Batch: 13
disc_loss: 0.9882152080535889
gen_loss: 4.668109893798828
Batch: 14
disc_loss: 0.4760900139808655
gen_loss: 3.5107522010803223
Batch: 15
disc_loss: 0.6085302829742432
gen_loss: 1.93640

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


disc_loss: 1.3735499382019043
gen_loss: 1.0431760549545288


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Epoch: 1
Batch: 1
disc_loss: 0.6237996816635132
gen_loss: 1.290212869644165
Batch: 2
disc_loss: 0.2503295838832855
gen_loss: 1.5671608448028564
Batch: 3
disc_loss: 0.19176043570041656
gen_loss: 1.785330057144165
Batch: 4
disc_loss: 0.21913127601146698
gen_loss: 0.8672573566436768
Batch: 5
disc_loss: 0.4354645013809204
gen_loss: 0.20354902744293213
Batch: 6
disc_loss: 0.24349431693553925
gen_loss: 0.470014750957489
Batch: 7
disc_loss: 0.20879927277565002
gen_loss: 0.9439008235931396
Batch: 8
disc_loss: 1.4020930528640747
gen_loss: 1.3610010147094727
Batch: 9
disc_loss: 0.28436341881752014
gen_loss: 4.526869773864746
Batch: 10
disc_loss: 0.6683381795883179
gen_loss: 1.5515464544296265
Batch: 11
disc_loss: 0.5850963592529297
gen_loss: 3.2803287506103516
Batch: 12
disc_loss: 0.9689387679100037
gen_loss: 2.6845479011535645
Batch: 13
disc_loss: 0.5795278549194336
gen_loss: 2.5766615867614746
Batch: 14
disc_loss: 0.3641818165779114
gen_loss: 2.6173481941223145
Batch: 15
disc_loss: 0.670834183

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


disc_loss: 0.19008482992649078
gen_loss: 4.385580062866211
Epoch: 11
Batch: 1
disc_loss: 0.18754686415195465
gen_loss: 4.447078704833984
Batch: 2
disc_loss: 0.2046523243188858
gen_loss: 3.914524555206299
Batch: 3
disc_loss: 0.6132732033729553
gen_loss: 2.9295778274536133
Batch: 4
disc_loss: 0.20170831680297852
gen_loss: 4.4742350578308105
Batch: 5
disc_loss: 0.23919136822223663
gen_loss: 3.6852505207061768
Batch: 6
disc_loss: 0.26762405037879944
gen_loss: 2.294950485229492
Batch: 7
disc_loss: 0.287872314453125
gen_loss: 3.4013185501098633
Batch: 8
disc_loss: 0.22746260464191437
gen_loss: 3.5138940811157227
Batch: 9
disc_loss: 0.22136884927749634
gen_loss: 3.7127625942230225
Batch: 10
disc_loss: 0.20247162878513336
gen_loss: 4.526558876037598
Batch: 11
disc_loss: 0.1870410144329071
gen_loss: 4.737001419067383
Batch: 12
disc_loss: 0.18516692519187927
gen_loss: 4.711408615112305
Batch: 13
disc_loss: 1.082385540008545
gen_loss: 3.976172924041748
Batch: 14
disc_loss: 0.28479963541030884
gen

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


disc_loss: 0.18380358815193176
gen_loss: 4.778287887573242
Epoch: 21
Batch: 1
disc_loss: 0.18982172012329102
gen_loss: 4.6282758712768555
Batch: 2
disc_loss: 0.1889122575521469
gen_loss: 4.504956245422363
Batch: 3
disc_loss: 0.4653010368347168
gen_loss: 4.319638252258301
Batch: 4
disc_loss: 0.1824832409620285
gen_loss: 5.700412750244141
Batch: 5
disc_loss: 0.1949896365404129
gen_loss: 5.336858749389648
Batch: 6
disc_loss: 0.4119771718978882
gen_loss: 3.3156731128692627
Batch: 7
disc_loss: 0.21667571365833282
gen_loss: 5.08674430847168
Batch: 8
disc_loss: 0.1877104640007019
gen_loss: 6.096464157104492
Batch: 9
disc_loss: 0.20651720464229584
gen_loss: 4.367743968963623
Batch: 10
disc_loss: 0.21221961081027985
gen_loss: 4.777295112609863
Batch: 11
disc_loss: 0.18440179526805878
gen_loss: 4.506147384643555
Batch: 12
disc_loss: 0.18932196497917175
gen_loss: 4.942234039306641
Batch: 13
disc_loss: 0.32330790162086487
gen_loss: 3.7906367778778076
Batch: 14
disc_loss: 0.20472797751426697
gen_lo

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


disc_loss: 0.17743636667728424
gen_loss: 6.248987197875977


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Epoch: 31
Batch: 1
disc_loss: 0.1823984980583191
gen_loss: 5.451777935028076
Batch: 2
disc_loss: 0.18851371109485626
gen_loss: 4.579700469970703
Batch: 3
disc_loss: 0.254617303609848
gen_loss: 4.3652448654174805
Batch: 4
disc_loss: 0.18647654354572296
gen_loss: 5.214822292327881
Batch: 5
disc_loss: 0.18080458045005798
gen_loss: 5.986536979675293
Batch: 6
disc_loss: 0.21494822204113007
gen_loss: 3.733570098876953
Batch: 7
disc_loss: 0.24172675609588623
gen_loss: 5.403424263000488
Batch: 8
disc_loss: 0.17981690168380737
gen_loss: 6.803946495056152
Batch: 9
disc_loss: 0.21028684079647064
gen_loss: 4.007420539855957
Batch: 10
disc_loss: 0.210800901055336
gen_loss: 5.003836154937744
Batch: 11
disc_loss: 0.18839845061302185
gen_loss: 5.592023849487305
Batch: 12
disc_loss: 0.17277666926383972
gen_loss: 5.018877983093262
Batch: 13
disc_loss: 0.3088081181049347
gen_loss: 5.511256217956543
Batch: 14
disc_loss: 0.17850132286548615
gen_loss: 7.4453582763671875
Batch: 15
disc_loss: 0.21463832259178

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


disc_loss: 0.1702689379453659
gen_loss: 5.037568092346191
Epoch: 41
Batch: 1
disc_loss: 0.17147231101989746
gen_loss: 5.358397960662842
Batch: 2
disc_loss: 0.181526318192482
gen_loss: 4.511602401733398
Batch: 3
disc_loss: 0.44783222675323486
gen_loss: 7.426383018493652
Batch: 4
disc_loss: 0.21452903747558594
gen_loss: 8.67414665222168
Batch: 5
disc_loss: 0.20652449131011963
gen_loss: 6.24636173248291
Batch: 6
disc_loss: 0.20362791419029236
gen_loss: 4.460695266723633
Batch: 7
disc_loss: 0.20585455000400543
gen_loss: 4.976334571838379
Batch: 8
disc_loss: 0.19429510831832886
gen_loss: 7.422221660614014
Batch: 9
disc_loss: 0.18008996546268463
gen_loss: 6.226298809051514
Batch: 10
disc_loss: 0.19372524321079254
gen_loss: 5.749688148498535
Batch: 11
disc_loss: 0.17475251853466034
gen_loss: 5.901141166687012
Batch: 12
disc_loss: 0.1699080765247345
gen_loss: 6.357729911804199
Batch: 13
disc_loss: 1.4193147420883179
gen_loss: 12.749130249023438
Batch: 14
disc_loss: 1.0176595449447632
gen_loss:

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


disc_loss: 0.1693207174539566
gen_loss: 5.3936896324157715
Epoch: 51
Batch: 1
disc_loss: 0.16848528385162354
gen_loss: 4.652586460113525
Batch: 2
disc_loss: 0.1723324954509735
gen_loss: 5.712769508361816
Batch: 3
disc_loss: 0.18238623440265656
gen_loss: 4.370518684387207
Batch: 4
disc_loss: 0.17613011598587036
gen_loss: 4.717142105102539
Batch: 5
disc_loss: 0.17194229364395142
gen_loss: 4.775448322296143
Batch: 6
disc_loss: 0.18561166524887085
gen_loss: 5.016483306884766
Batch: 7
disc_loss: 0.2257772535085678
gen_loss: 3.869210720062256
Batch: 8
disc_loss: 0.19737376272678375
gen_loss: 4.500642776489258
Batch: 9
disc_loss: 0.17872396111488342
gen_loss: 5.219288349151611
Batch: 10
disc_loss: 0.16962450742721558
gen_loss: 5.7617974281311035
Batch: 11
disc_loss: 0.17059756815433502
gen_loss: 5.13431453704834
Batch: 12
disc_loss: 0.17194700241088867
gen_loss: 5.299771308898926
Batch: 13
disc_loss: 0.22323726117610931
gen_loss: 4.341366291046143
Batch: 14
disc_loss: 0.18638171255588531
gen_

In [0]:
!zip -r /content/results.zip /content/results

updating: content/results/ (stored 0%)
updating: content/results/img_0_2.png (deflated 8%)
updating: content/results/img_10_3.png (deflated 8%)
updating: content/results/img_20_3.png (deflated 13%)
updating: content/results/img_30_4.png (deflated 9%)
updating: content/results/img_0_1.png (deflated 8%)
updating: content/results/img_20_0.png (deflated 11%)
updating: content/results/img_0_4.png (deflated 8%)
updating: content/results/img_20_1.png (deflated 9%)
updating: content/results/img_30_3.png (deflated 44%)
updating: content/results/img_30_0.png (deflated 9%)
updating: content/results/img_20_4.png (deflated 9%)
updating: content/results/img_10_4.png (deflated 18%)
updating: content/results/img_20_2.png (deflated 9%)
updating: content/results/img_10_1.png (deflated 7%)
updating: content/results/img_30_1.png (deflated 9%)
updating: content/results/img_10_2.png (deflated 9%)
updating: content/results/img_30_2.png (deflated 8%)
updating: content/results/img_10_0.png (deflated 8%)
updati

In [0]:
def train_encoder():
  
  rpochs = 500
  encoder = build_encoder()
  encoder.compile(loss = euclidean_distance_loss, optimizer = 'adam')
  
  generator = load_model("generator.h5")
  
  z_i = np.random.normal(0, 1, size = (1000, z_shape))
  
  y = np.random.randint(low = 0, high = 6, size = (1000,), dtype = np.int64)
  num_classes = len(set(y))
  y = np.reshape(np.array(y), [len(y), 1])
  y = to_categorical(y, num_classes = num_classes)
  
  for epoch in range(epochs):
    
    print("Epoch: ", epoch)
    
    encoder_loss = []
    
    number_of_batches = int(z_i.shape[0]/batch_size)
    print("Number of batches: ", number_of_batches)
    
    for index in range(number_of_batches):
      
      print("Batch: ", index+1)
      
      z_batch = z_i[index*batch_size:(index+1)*batch_size]
      y_batch = y[index*batch_size:(index+1)*batch_size]
      
      generated_images = generator.predict_on_batch([z_batch, y_batch])
      
      encoder_loss = encoder.train_on_batch([generated_images, z_batch])
      
      write_log(tensorboard, "encoder_loss", np.mean(encoder_losses), epoch)
      
    #encoder.save_weights("encoder_weights.h5")
    encoder.save("encoder.h5")

In [0]:
train_encoder()

In [0]:
def build_img_resizer():
  
  input_layer = Input(shape = (64, 64, 3))
  
  resized_images = Lambda(lambda x: K.resize_images(x, height_factor = 3, width_factor = 3, data_format = 'channels_last'))(input_layer)
  
  model = Model(inputs = [input_layer], outputs = [resized_images])
  
  return model

In [0]:
def build_fr_model(input_shape):
    
    resent_model = InceptionResNetV2(include_top=False, weights='imagenet', input_shape=input_shape, pooling='avg')
    
    image_input = resent_model.input
    x = resent_model.layers[-1].output
    out = Dense(128)(x)
    
    embedder_model = Model(inputs=[image_input], outputs=[out])

    input_layer = Input(shape=input_shape)

    x = embedder_model(input_layer)
    output = Lambda(lambda x: K.l2_normalize(x, axis=-1))(x)

    model = Model(inputs=[input_layer], outputs=[output])
    
    return model

In [0]:
def lat_vec_opt():
  
  fr_image_shape = (192, 192, 3)
  epochs = 500
  
  encoder = load_model("encoder.h5")
  
  generator = load_model("generator.h5")
  
  image_resizer = build_img_resizer()
  image_resizer.compile(loss = ['binary_crossentropy', optimzer = "adam"])
  
  fr_model = build_fr_model(input_shape = fr_image_shape)
  fr_model.compile(loss = ['binary_crossentropy', optimizer = "adam"])
  
  fr_model.trainable = False
  
  input_image = Input(shape = (64, 64, 3))
  input_label = Input(shape = (6,))
  
  latent0 = encoder(input_image)
  gen_images = generator([latent0, input_label])
  
  resized_images = Lambda(lambda x: K.resize_images(gen_images, height_factor = 3, width_factor = 3, data_format = 'channels_last'))(gen_images)
  embeddings = fr_model(resized_images)
  
  fr_adversarial_model = Model(inputs = [input_image, input_label], outputs = [embeddings])
  
  optimizer = adversarial_optimizer
  
  fr_adversarial_model.compile(loss = euclidean_distance_loss, optimizer = optimizer)
  
  for epoch in range(epochs):
    
    print("Epoch: ", epoch)
    
    number_of_batches = int(len(loaded_images)/batch_size)
    print("Number of batches", number_of_batches)
    
    for index in range(number_of_batches):
      
      print("Batch: ", index+1)
      
      images_batch = loaded_images[index * batch_size:(index+1) * batch_size]
      
      images_batch = images_batch/255.0
      
      images_batch = images_batch.astype(float32)
      
      y_batch = y[index * batch_size:(index + 1) * batch_size]
      
      images_batch_resized = image_resizer.predict_on_batch(images_batch)
      real_embeddings = fr_model.predict_on_batch(images_batch_resized)
      
      reconstruction_loss = fr_adversarial_model.train_on_batch([images_batch, y_batch], real_embeddings)
      
      write_log(tensorboard, "reconstructio_loss", reconstruction_loss, index)
      
      generator.save_model("generator.h5")
      encoder.save_model("encoder.h5")
  
  
  

In [0]:
ls

[0m[01;34mdata[0m/             generator.h5  [01;34mresults[0m/     [01;34msample_data[0m/
discriminator.h5  [01;34mlogs[0m/         results.zip


In [0]:

from google.colab import files
files.download("results.zip")

In [0]:
files.download("generator.h5")

In [0]:
files.download("discriminator.h5")

In [0]:
files.download("encoder.h5")