<a href="https://colab.research.google.com/github/CleanPegasus/Aging-c-GAN/blob/master/Aging_cGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
!mkdir data

In [0]:
mkdir results

In [0]:
from google.colab import files
files.upload()

In [0]:
pwd

In [0]:
cd data

In [0]:
ls

In [0]:
!wget https://data.vision.ee.ethz.ch/cvl/rrothe/imdb-wiki/static/wiki_crop.tar

In [0]:
!tar -xvf wiki_crop.tar

In [0]:
import math
import os
import time
from datetime import datetime
from scipy.io import loadmat

In [2]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from keras import Input, Model
from keras.applications import InceptionResNetV2
from keras.callbacks import TensorBoard
from keras.layers import Conv2D, Flatten, Dense, BatchNormalization, Reshape, concatenate, LeakyReLU, Lambda, K, Conv2DTranspose, Activation, UpSampling2D, Dropout
from keras.optimizers import Adam
from keras.utils import to_categorical
from keras_preprocessing import image
from keras.models import load_model

Using TensorFlow backend.


In [0]:
def load_data(wiki_dir, dataset= 'wiki'):
  
  meta = loadmat(os.path.join(wiki_dir, "{}.mat".format(dataset)))
  
  full_path = meta[dataset][0, 0]["full_path"][0]
  
  dob = meta[dataset][0, 0]["dob"][0]
  
  photo_taken = meta[dataset][0, 0]["photo_taken"][0]
  
  age = [calculate_age(photo_taken[i], dob[i]) for i in range(len(dob))]
  
  images = []
  age_list = []
  
  for index, image_path in enumerate(full_path):
    
    images.append(image_path[0])
    age_list.append(age[index])
    
  return images, age_list

In [0]:
def calculate_age(taken, dob):
  birth = datetime.fromordinal(max(int(dob) - 366, 1))
  
  if birth.month < 7:
    return taken - birth.year
  else:
    return taken - birth.year - 1

In [0]:
def build_encoder():
  
  input_layer = Input(shape = (64, 64, 3))
  
  enc = Conv2D(filters = 32, kernel_size = 5, strides = 2, padding = 'same')(input_layer)
  enc = LeakyReLU(alpha = 0.2)(enc)
  
  enc = Conv2D(filters = 64, kernel_size = 5, strides = 2, padding = 'same')(enc)
  enc = BatchNormalization()(enc)
  enc = LeakyReLU(alpha = 0.2)(enc)
  
  enc = Conv2D(filters = 128, kernel_size = 5, strides = 2, padding = 'same')(enc)
  enc = BatchNormalization()(enc)
  enc = LeakyReLU(alpha = 0.2)(enc)
  
  enc = Conv2D(filters = 256, kernel_size = 5, strides = 2, padding = 'same')(enc)
  enc = BatchNormalization()(enc)
  enc = LeakyReLU(alpha = 0.2)(enc)
  
  enc = Flatten()(enc)
  
  enc = Dense(4096)(enc)
  enc = BatchNormalization()(enc)
  enc = LeakyReLU(alpha = 0.2)(enc)
  
  enc = Dense(100)(enc)
  
  model = Model(inputs = [input_layer], outputs = [enc])
  
  model.summary()
  
  return model

In [0]:
def build_gen():
  
  latent_dims = 100
  num_classes = 6
  
  input_z_noise = Input(shape = (latent_dims,))
  input_label = Input(shape = (num_classes,))
  
  x = concatenate([input_z_noise, input_label])
  
  x = Dense(2048, input_dim = latent_dims + num_classes)(x)
  x = LeakyReLU(alpha = 0.2)(x)
  x = Dropout(0.2)(x)
  
  x = Dense(256 * 8 * 8)(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(alpha = 0.2)(x)
  x = Dropout(0.2)(x)
  
  x = Reshape((8, 8, 256))(x)
  
  x = UpSampling2D(size = (2, 2))(x)
  x = Conv2D(filters = 128, kernel_size = 5, padding = 'same')(x)
  x = BatchNormalization(momentum = 0.8)(x)
  x = LeakyReLU(alpha = 0.2)(x)
  
  x = UpSampling2D(size = (2, 2))(x)
  x = Conv2D(filters = 64, kernel_size = 5, padding = 'same')(x)
  x = BatchNormalization(momentum = 0.8)(x)
  x = LeakyReLU(alpha = 0.2)(x)
  
  x = UpSampling2D(size = (2, 2))(x)
  x = Conv2D(filters = 3, kernel_size = 5, padding = 'same')(x)
  x = BatchNormalization(momentum = 0.8)(x)
  x = Activation('tanh')(x)
  
  model = Model(inputs = [input_z_noise, input_label], outputs = [x])
  
  model.summary()
  
  return model

In [0]:
def expand_label_input(x):
  x = K.expand_dims(x, axis = 1)
  x = K.expand_dims(x, axis=1)
  x = K.tile(x, [1, 32, 32, 1])
  return x

In [0]:
def build_disc():
  
  input_shape = (64, 64, 3)
  label_shape = (6,)
  
  image_input = Input(shape = input_shape)
  label_input = Input(shape = label_shape)
  
  x = Conv2D(filters = 64, kernel_size = 3, strides = 2, padding = 'same')(image_input)
  x = LeakyReLU(alpha = 0.2)(x)
  
  label_input1 = Lambda(expand_label_input)(label_input)
  
  x = concatenate([x, label_input1], axis = 3)
  
  x = Conv2D(128, kernel_size = 3, strides = 2, padding = 'same')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(alpha = 0.2)(x)
  
  x = Conv2D(256, kernel_size = 3, strides = 2, padding = 'same')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(alpha = 0.2)(x)
  
  x = Conv2D(512, kernel_size = 3, strides = 2, padding = 'same')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(alpha = 0.2)(x)
  
  x = Flatten()(x)
  x = Dense(1, activation = 'sigmoid')(x)
  
  model = Model(inputs = [image_input, label_input], outputs = [x])
  
  model.summary()
  
  return model

In [0]:
def age_to_category(age_list):
  
  age_list1 = []
  
  for age in age_list:
    
    if 0 < age <= 18:
      age_category = 0
    elif 18 < age <= 29:
      age_category = 1
    elif 29 < age <= 39:
      age_category = 2
    elif 39 < age <= 49:
      age_category = 3
    elif 49 < age <= 59:
      age_category = 4
    elif age >= 60:
      age_category = 5
      
    age_list1.append(age_category)
    
  return age_list1

In [0]:
def load_images(data_dir, image_paths, image_shape):
  
  images = []
  
  print("Loading Images")
  
  for i, image_path in enumerate(image_paths):
    
    loaded_image = image.load_img(os.path.join(data_dir, image_path), target_size = image_shape)
    loaded_image = image.img_to_array(loaded_image)
    #loaded_image = np.expand_dims(loaded_image, axis = 0)
    
    images.append(loaded_image)
    #if images is None:
    #  images = loaded_image
    #else:
     # images = np.concatenate([images, loaded_image], axis = 0)
     # print(images.shape)
  images = np.asarray(images)
  print(images.shape)
  print("Images Loaded")
  return images

In [0]:
def save_rgb_img(img, path):
  
  fig = plt.figure()
  ax = fig.add_subplot(1, 1, 1)
  ax.imshow(img)
  ax.axis("off")
  ax.set_title("Image")
  
  plt.savefig(path)
  plt.close

In [0]:
def write_log(callback, name, value, batch_no):
    summary = tf.Summary()
    summary_value = summary.value.add()
    summary_value.simple_value = value
    summary_value.tag = name
    callback.writer.add_summary(summary, batch_no)
    callback.writer.flush()

In [0]:
def train():
  
  data_dir = "/content/data/"
  wiki_dir = os.path.join(data_dir, "wiki_crop")
  epochs = 500
  batch_size = 128
  image_shape = (64, 64, 3)
  z_shape = 100
  TRAIN_GAN = True
  TRAIN_ENCODER = False
  TRAIN_GAN_WITH_FR = False
  fr_image_shape = (192, 192, 3)
  
  disc_optimizer = Adam(lr = 0.0002, beta_1 = 0.5, beta_2 = 0.999, epsilon = 10e-8)
  gen_optimizer = Adam(lr = 0.0002, beta_1 = 0.5, beta_2 = 0.999, epsilon = 10e-8)
  adversarial_optimizer = Adam(lr = 0.0002, beta_1 = 0.5, beta_2 = 0.999, epsilon = 10e-8)
  
  if(os.path.isfile('generator.h5')):
    generator = load_model('generator.h5')
    print("Generator model loaded")
  else:
    generator = build_gen()
    generator.compile(loss = 'binary_crossentropy', optimizer = gen_optimizer)
    
  if(os.path.isfile('discriminator.h5')):
    discriminator = load_model('discriminator.h5')
    print("Discriminator model loaded")
  else:
    discriminator = build_disc()
    discriminator.compile(loss = 'binary_crossentropy', optimizer = disc_optimizer)
   
  
  discriminator.trainable = False
  
  input_z_noise = Input(shape = (100,))
  input_label = Input(shape = (6,))
  recons_image = generator([input_z_noise, input_label])
  valid = discriminator([recons_image, input_label])
  
  adversarial_model = Model(inputs = [input_z_noise, input_label], outputs = [valid])
  adversarial_model.compile(loss = ['binary_crossentropy'], optimizer = adversarial_optimizer)
  
  tensorboard = TensorBoard(log_dir = "logs/{}".format(time.time()))
  tensorboard.set_model(generator)
  tensorboard.set_model(discriminator)
  
  images, age_list = load_data(wiki_dir = wiki_dir, dataset = "wiki")
  #print(len(images))
  
  
  age_cat = age_to_category(age_list)
  
  final_age_cat = np.reshape(np.array(age_cat), [len(age_cat), 1])
  classes = len(set(age_cat))
  y = to_categorical(final_age_cat, num_classes = len(set(age_cat)))
  
  loaded_images = load_images(wiki_dir, images, (image_shape[0], image_shape[1]))
  print(len(loaded_images))
  
  real_labels = np.ones((batch_size, 1), dtype = np.float32)*0.9
  fake_labels = np.zeros((batch_size, 1), dtype = np.float32)*0.1
  
  for epoch in range(epochs):
    
    print("Epoch: {}".format(epoch))
    
    gen_losses = []
    disc_losses = []
    
    number_of_batches = int(len(loaded_images) / batch_size)
    
    for index in range(number_of_batches):
      print("Batch: {}".format(index + 1))
      
      images_batch = loaded_images[index * batch_size:(index + 1) * batch_size]
      images_batch = images_batch/ 127.5 - 1.0
      images_batch = images_batch.astype(np.float32)
                                        
      y_batch = y[index * batch_size : (index + 1) * batch_size]
      
      z_noise = np.random.normal(0, 1, size = (batch_size, z_shape))
      
      initial_recon_images = generator.predict_on_batch([z_noise, y_batch])
      
      disc_loss_real = discriminator.train_on_batch([images_batch, y_batch], real_labels)
      disc_loss_fake = discriminator.train_on_batch([initial_recon_images, y_batch], fake_labels)
      
      z_noise2 = np.random.normal(0, 1, size = (batch_size, z_shape))
      random_labels = np.random.randint(0, 6, batch_size).reshape(-1, 1)
      
      random_labels = to_categorical(random_labels, 6)
      
      gen_loss = adversarial_model.train_on_batch([z_noise2, random_labels], [1] * batch_size)
      
      disc_loss = 0.5 * np.add(disc_loss_real, disc_loss_fake)
      print("disc_loss: {}".format(disc_loss))
      print("gen_loss: {}".format(gen_loss))
      gen_losses.append(gen_loss)
      disc_losses.append(disc_loss)
      
      write_log(tensorboard, 'gen_loss', np.mean(gen_losses), epoch)
      write_log(tensorboard, 'disc_loss', np.mean(disc_losses), epoch)
      
    if epoch % 10 == 0:
        
      images_batch = loaded_images[0:batch_size]
      images_batch = images_batch / 127.5 - 1.0
      images_batch = images_batch.astype(np.float32)
        
      y_batch = y[0:batch_size]
      z_noise = np.random.normal(0, 1, size = (batch_size, z_shape))
        
      gen_images = generator.predict_on_batch([z_noise, y_batch])
        
      for i , image in enumerate(gen_images[:5]):
        save_rgb_img(image, path = "results/img_{}_{}.png".format(epoch, i))
          
      generator.save("generator.h5")
      discriminator.save("discriminator.h5")

In [0]:
if __name__ == '__main__':
  train()

Instructions for updating:
Colocations handled automatically by placer.
Instructions for updating:
Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.
Generator model loaded
Instructions for updating:
Use tf.cast instead.
Discriminator model loaded
Loading Images
(62328, 64, 64, 3)
Images Loaded
62328
Epoch: 0
Batch: 1


  'Discrepancy between trainable weights and collected trainable'


disc_loss: 0.16359633207321167
gen_loss: 3.6356282234191895
Batch: 2
disc_loss: 0.1670791655778885
gen_loss: 5.273960590362549
Batch: 3
disc_loss: 0.3782108426094055
gen_loss: 15.925105094909668
Batch: 4
disc_loss: 0.1689850240945816
gen_loss: 16.11809539794922
Batch: 5
disc_loss: 0.42580386996269226
gen_loss: 8.325092315673828
Batch: 6
disc_loss: 0.2234198898077011
gen_loss: 6.543938159942627
Batch: 7
disc_loss: 0.19114595651626587
gen_loss: 5.333652496337891
Batch: 8
disc_loss: 0.21656978130340576
gen_loss: 5.8662800788879395
Batch: 9
disc_loss: 0.19958043098449707
gen_loss: 6.360107421875
Batch: 10
disc_loss: 0.18609441816806793
gen_loss: 6.6468634605407715
Batch: 11
disc_loss: 0.17804716527462006
gen_loss: 5.956684112548828
Batch: 12
disc_loss: 0.16910633444786072
gen_loss: 5.891122341156006
Batch: 13
disc_loss: 0.563351035118103
gen_loss: 3.2383041381835938
Batch: 14
disc_loss: 0.24520467221736908
gen_loss: 4.613022327423096
Batch: 15
disc_loss: 0.19427284598350525
gen_loss: 5.260

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


disc_loss: 0.16385523974895477
gen_loss: 5.896026134490967


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Epoch: 1
Batch: 1
disc_loss: 0.16318124532699585
gen_loss: 4.576414108276367
Batch: 2
disc_loss: 0.16308622062206268
gen_loss: 6.477211952209473
Batch: 3
disc_loss: 0.16548606753349304
gen_loss: 6.074612617492676
Batch: 4
disc_loss: 0.16440315544605255
gen_loss: 6.68824577331543
Batch: 5
disc_loss: 0.1662428081035614
gen_loss: 5.9410247802734375
Batch: 6
disc_loss: 0.18010655045509338
gen_loss: 5.776150226593018
Batch: 7
disc_loss: 0.1669531613588333
gen_loss: 4.064803600311279
Batch: 8
disc_loss: 0.1746930480003357
gen_loss: 4.884640216827393
Batch: 9
disc_loss: 0.16529235243797302
gen_loss: 5.499580383300781
Batch: 10
disc_loss: 0.16351741552352905
gen_loss: 6.995335578918457
Batch: 11
disc_loss: 0.16329166293144226
gen_loss: 5.420575141906738
Batch: 12
disc_loss: 0.16292482614517212
gen_loss: 4.484548568725586
Batch: 13
disc_loss: 0.17341123521327972
gen_loss: 4.903903961181641
Batch: 14
disc_loss: 0.1757175177335739
gen_loss: 2.696258544921875
Batch: 15
disc_loss: 0.182635858654975

In [0]:
!zip -r /content/results.zip /content/results

updating: content/results/ (stored 0%)
updating: content/results/img_60_0.png (deflated 13%)
updating: content/results/img_40_0.png (deflated 28%)
updating: content/results/img_0_2.png (deflated 46%)
updating: content/results/img_10_3.png (deflated 9%)
updating: content/results/img_20_3.png (deflated 9%)
updating: content/results/img_50_2.png (deflated 10%)
updating: content/results/img_30_4.png (deflated 7%)
updating: content/results/img_60_1.png (deflated 9%)
updating: content/results/img_40_4.png (deflated 9%)
updating: content/results/img_0_1.png (deflated 10%)
updating: content/results/img_20_0.png (deflated 9%)
updating: content/results/img_60_2.png (deflated 15%)
updating: content/results/img_0_4.png (deflated 7%)
updating: content/results/img_20_1.png (deflated 9%)
updating: content/results/img_30_3.png (deflated 7%)
updating: content/results/img_40_2.png (deflated 9%)
updating: content/results/img_30_0.png (deflated 7%)
updating: content/results/img_20_4.png (deflated 21%)
upd

In [0]:
ls

[0m[01;34mdata[0m/             generator.h5  [01;34mresults[0m/     result.zip
discriminator.h5  [01;34mlogs[0m/         results.zip  [01;34msample_data[0m/


In [0]:

files.download("results.zip")

In [0]:
files.download("generator.h5")

----------------------------------------
Exception happened during processing of request from ('::ffff:127.0.0.1', 56458, 0, 0)
Traceback (most recent call last):
  File "/usr/lib/python3.6/socketserver.py", line 317, in _handle_request_noblock
    self.process_request(request, client_address)
  File "/usr/lib/python3.6/socketserver.py", line 348, in process_request
    self.finish_request(request, client_address)
  File "/usr/lib/python3.6/socketserver.py", line 361, in finish_request
    self.RequestHandlerClass(request, client_address, self)
  File "/usr/lib/python3.6/socketserver.py", line 721, in __init__
    self.handle()
  File "/usr/lib/python3.6/http/server.py", line 418, in handle
    self.handle_one_request()
  File "/usr/lib/python3.6/http/server.py", line 406, in handle_one_request
    method()
  File "/usr/lib/python3.6/http/server.py", line 639, in do_GET
    self.copyfile(f, self.wfile)
  File "/usr/lib/python3.6/http/server.py", line 800, in copyfile
    shutil.copyfil

In [0]:
files.download("discriminator.h5")