In [73]:
import os 
import cv2
import numpy as np
from PIL import Image
import albumentations as alb 
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow import keras
from tqdm import tqdm 
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
    Conv2D, Conv2DTranspose, Dense, BatchNormalization,
    GlobalAveragePooling2D, MaxPooling2D, LeakyReLU,
    Dropout, Input, Reshape
)
from sklearn.model_selection import train_test_split

In [74]:
# With mean = std = 0.5 and max_pixel_value=255, Normalize maps pixels from
# [0, 255] to [-1, 1], the same range as the generator's tanh output.
mean = (0.5, 0.5, 0.5)
std = (0.5, 0.5, 0.5)
image_shape = (64, 64, 3)
augmentation = alb.Compose([
    # NOTE(review): a 55x55 random crop keeps only a small patch of each source
    # image; confirm this is intended (a larger or centered face crop is more
    # typical when training a face GAN).
    alb.RandomCrop(55, 55),
    # p=1.0 replaces the deprecated `always_apply=True` flag.
    alb.Resize(image_shape[0], image_shape[1], p=1.0),
    # Contrast-only jitter: drop-in for the removed `RandomContrast()`
    # (same default limit of 0.2, brightness left untouched).
    alb.RandomBrightnessContrast(brightness_limit=0.0, contrast_limit=0.2),
    alb.RandomBrightnessContrast(),
    alb.HueSaturationValue(),
    alb.CLAHE(),
    # `ImageCompression` replaces the removed `JpegCompression` (same defaults).
    alb.ImageCompression(),
    alb.Normalize(mean, std, max_pixel_value=255.0, p=1.0),
])

In [75]:
def data_loader(batch_image_path, root=None, transform=None):
    """Load, transform and stack a batch of images from disk.

    Args:
        batch_image_path: iterable of file names; each is joined onto ``root``.
        root: directory containing the files. Defaults to the notebook-level
            ``faces_path`` (looked up at call time).
        transform: albumentations-style callable invoked as
            ``transform(image=arr)['image']``. Defaults to the notebook-level
            ``augmentation`` pipeline.

    Returns:
        np.ndarray stacking the transformed images along a new leading axis
        (one entry per path, in input order).
    """
    if root is None:
        root = faces_path
    if transform is None:
        transform = augmentation
    batch_images = []
    for path in batch_image_path:
        full_path = os.path.join(root, path)
        # Context manager closes the file handle deterministically; convert()
        # returns an independent copy, so it stays valid after the with-block.
        with Image.open(full_path) as img_file:
            img = np.array(img_file.convert('RGB'))
        batch_images.append(transform(image=img)['image'])
    return np.array(batch_images)

In [76]:
def residual_block(x):
    """Upsampling residual block used by the generator.

    Applies three 3x3 Conv2D -> LeakyReLU -> Dropout(0.3) stages, adds a 1x1
    projection of the block input so channel counts match, then upsamples 2x
    with a strided Conv2DTranspose back to the input's channel count.

    Args:
        x: 4-D Keras tensor; channels are read from the last axis
           (assumes Keras' default channels-last layout).

    Returns:
        Keras tensor with height/width doubled and the input's channel count.
    """
    shortcut = x
    for _ in range(3):
        # NOTE(review): filters are recomputed from the *current* x, so the
        # width compounds (C -> 2C -> 4C -> 8C). Confirm this growth is
        # intended; a fixed 2*C is more common and much cheaper.
        x = Conv2D(x.shape[-1] * 2, kernel_size=3, padding='same')(x)
        x = LeakyReLU()(x)
        x = Dropout(0.3)(x)
    # 1x1 projection so the shortcut has the same channel count as x.
    projected = Conv2D(x.shape[-1], kernel_size=1)(shortcut)
    # Keras Add layer rather than tf.math.add: raw TF ops on symbolic
    # KerasTensors are rejected when building Keras 3 functional models.
    x = keras.layers.Add()([projected, x])
    x = Conv2DTranspose(shortcut.shape[-1], kernel_size=3, strides=2, padding='same')(x)
    x = LeakyReLU()(x)
    x = Dropout(0.3)(x)
    return x

In [77]:
def generator(inp_shape, image_shape):
    """Build the generator: latent vector -> image in [-1, 1].

    Args:
        inp_shape: latent dimensionality as an int (e.g. 100) or a shape tuple.
        image_shape: (height, width[, channels]) of the generated image.
            Height and width must be divisible by 8 because the network
            upsamples by 2 three times (one Conv2DTranspose plus two
            residual blocks). Channels default to 3 when omitted.

    Returns:
        Uncompiled keras Model whose output uses a tanh activation.

    Raises:
        ValueError: if height or width is not divisible by 8.
    """
    # `Input(shape=100)` is a bare int, not a tuple; Keras 3 rejects it.
    if isinstance(inp_shape, (int, np.integer)):
        inp_shape = (int(inp_shape),)
    height, width = image_shape[0], image_shape[1]
    if height % 8 or width % 8:
        # Otherwise the output silently ends up smaller than image_shape.
        raise ValueError(f'image height/width must be divisible by 8, got {image_shape}')
    channels = image_shape[2] if len(image_shape) > 2 else 3
    base_h, base_w = height // 8, width // 8

    inp_ = Input(shape=inp_shape)
    # Project the latent vector onto a single-channel (H/8, W/8) map.
    x = Dense(base_h * base_w)(inp_)
    x = Reshape((base_h, base_w, 1))(x)
    # First 2x upsampling -> (H/4, W/4, 32).
    x = Conv2DTranspose(32, kernel_size=3, strides=2, padding='same')(x)
    x = LeakyReLU()(x)
    x = Dropout(0.3)(x)
    # Each residual block upsamples 2x more -> (H, W, 32) after two blocks.
    for _ in range(2):
        x = residual_block(x)
    x = Conv2D(channels, kernel_size=3, padding='same', activation='tanh')(x)
    return Model(inp_, x)

In [78]:
def inception_block(x):
    """Inception-style block: four parallel branches concatenated on the last axis.

    Every convolution uses the input's channel count C, and all ops use
    stride 1 with 'same' padding, so the output keeps the input's spatial size
    and has 4*C channels.

    Args:
        x: 4-D Keras tensor (channels read from the last axis).

    Returns:
        Keras tensor with 4x the input's channel count.
    """
    channels = x.shape[-1]

    # Branch a: 1x1 conv.
    a = Conv2D(channels, kernel_size=1, padding='same')(x)
    a = LeakyReLU()(a)

    # Branch b: 1x1 conv -> 3x3 conv.
    b = Conv2D(channels, kernel_size=1, padding='same')(x)
    b = LeakyReLU()(b)
    b = Conv2D(channels, kernel_size=3, padding='same')(b)
    b = LeakyReLU()(b)

    # Branch c: 1x1 conv -> 5x5 conv.
    c = Conv2D(channels, kernel_size=1, padding='same')(x)
    c = LeakyReLU()(c)
    c = Conv2D(channels, kernel_size=5, padding='same')(c)
    c = LeakyReLU()(c)

    # Branch d: 3x3 max-pool (stride 1) -> 1x1 conv.
    d = MaxPooling2D(pool_size=3, strides=1, padding='same')(x)
    d = Conv2D(channels, kernel_size=1, padding='same')(d)
    d = LeakyReLU()(d)

    # Keras Concatenate layer rather than tf.concat: raw TF ops on symbolic
    # KerasTensors are rejected when building Keras 3 functional models.
    return keras.layers.Concatenate(axis=-1)([a, b, c, d])

In [79]:
def discriminator(image_shape):
    """Build the discriminator: image -> probability that the image is real.

    Args:
        image_shape: shape of one input image, e.g. (64, 64, 3).

    Returns:
        Uncompiled keras Model ending in a single sigmoid unit.
    """
    def bn_act_dropout(tensor):
        # BatchNorm -> LeakyReLU -> Dropout(0.3); shared by the stem and
        # both inception stages.
        tensor = BatchNormalization()(tensor)
        tensor = LeakyReLU()(tensor)
        return Dropout(0.3)(tensor)

    image_in = Input(shape=image_shape)

    # Stem: dilated 3x3 conv (stride 1, 'same' padding, so no downsampling).
    stem = Conv2D(32, kernel_size=3, padding='same', dilation_rate=2)(image_in)
    features = bn_act_dropout(stem)

    # Two inception stages, each followed by the shared normalization tail.
    for _stage in range(2):
        features = bn_act_dropout(inception_block(features))

    # Head: global pooling -> small dense layer -> real/fake probability.
    pooled = GlobalAveragePooling2D()(features)
    hidden = Dense(100)(pooled)
    hidden = LeakyReLU()(hidden)
    hidden = Dropout(0.5)(hidden)
    score = Dense(1, activation='sigmoid')(hidden)
    return Model(image_in, score)

In [80]:
def gan(generator, discriminator):
    """Stack the generator and a frozen discriminator into one model.

    The returned model is what the generator is trained through. It maps a
    latent vector to the discriminator's real/fake score.

    Side effect: sets ``discriminator.trainable = False``. Compile the
    discriminator on its own *before* calling this; tf.keras respects the
    trainable state captured at compile time when that model is trained
    directly.

    Args:
        generator: keras Model taking a single latent-vector input.
        discriminator: keras Model scoring generator outputs.

    Returns:
        Uncompiled keras Model with the same input shape as ``generator``.
    """
    discriminator.trainable = False
    # Take the latent shape from the generator itself instead of the
    # notebook-level `z_dims` global, so the two can never disagree.
    latent_shape = tuple(generator.inputs[0].shape[1:])
    inp_ = Input(shape=latent_shape)
    x = generator(inp_)
    x = discriminator(x)
    return Model(inp_, x)

In [81]:
# Latent-vector size fed to the generator, and (H, W, C) of generated images.
z_dims, image_shape = 100, (64, 64, 3)

In [82]:
# The discriminator is compiled on its own first, while it is still trainable.
# `learning_rate=` replaces the deprecated `lr=` keyword, which newer Keras rejects.
discriminator_model = discriminator(image_shape)
discriminator_model.compile(
    loss='binary_crossentropy',
    optimizer=Adam(learning_rate=1e-4),
    metrics=['accuracy'],
)

generator_model = generator(z_dims, image_shape)

# gan() freezes the discriminator inside the combined model, so training
# gan_model updates only the generator's weights.
gan_model = gan(generator_model, discriminator_model)
gan_model.compile(
    loss='binary_crossentropy',
    optimizer=Adam(learning_rate=1e-4),
    metrics=['accuracy'],
)

In [83]:
def image_generation(model=None, grid_size=7):
    """Sample a square grid of images from the generator and display it.

    Args:
        model: generator to sample from. Defaults to the notebook-level
            ``generator_model`` (looked up at call time).
        grid_size: rows and columns of the grid; grid_size**2 images are drawn.
    """
    if model is None:
        model = generator_model
    # Read the latent size from the model instead of hard-coding 100.
    latent_dim = model.inputs[0].shape[-1]
    z = np.random.normal(0, 1, (grid_size * grid_size, latent_dim))
    # The generator ends in tanh ([-1, 1]); map to [0, 1] for imshow.
    images = np.clip(model.predict(z, verbose=0) * 0.5 + 0.5, 0.0, 1.0)

    # squeeze=False keeps a 2-D axes array even when grid_size == 1.
    fig, axes = plt.subplots(grid_size, grid_size, figsize=(10, 10), squeeze=False)
    for ax, img in zip(axes.flat, images):
        ax.imshow(img)
        ax.axis('off')
    plt.show()

In [84]:
def train(batch_size, epochs, paths=None):
    """Alternate discriminator and generator updates for ``epochs`` epochs.

    For each batch:
      1. Train the discriminator on real images (label 0.9, one-sided label
         smoothing) and on an equal number of generated images (label 0).
      2. Train the generator through ``gan_model`` (discriminator frozen
         there) on fresh noise with the 0.9 "real" target.

    After each epoch, print the last batch's metrics and show a sample grid.

    Args:
        batch_size: images per step. A trailing partial batch is skipped.
        epochs: number of passes over the data.
        paths: file names to train on. Defaults to the notebook-level
            ``image_path``. A private copy is shuffled, so the caller's list
            is never mutated.

    Raises:
        ValueError: if ``batch_size`` is not positive, or there are fewer
            paths than ``batch_size`` (no full batch to train on).
    """
    if batch_size <= 0:
        raise ValueError(f'batch_size must be positive, got {batch_size}')
    paths = list(image_path if paths is None else paths)
    steps = len(paths) // batch_size  # drop the trailing partial batch
    if steps == 0:
        # Without this, the metrics printed below would be undefined (NameError).
        raise ValueError(f'need at least batch_size={batch_size} images, got {len(paths)}')

    # Read the latent size from the generator instead of hard-coding 100.
    latent_dim = generator_model.inputs[0].shape[-1]
    # Every batch is full, so the label arrays are loop-invariant.
    real_labels = np.ones((batch_size, 1)) * 0.9
    fake_labels = np.zeros((batch_size, 1))

    for epoch in range(epochs):
        # Shuffle before every epoch, including the first, so batches do not
        # follow the directory-listing order.
        np.random.shuffle(paths)
        for step in tqdm(range(steps)):
            batch_paths = paths[step * batch_size:(step + 1) * batch_size]
            batch_images = data_loader(batch_paths)

            # Discriminator update: a real batch, then a generated batch.
            z = np.random.normal(0, 1, (batch_size, latent_dim))
            # predict_on_batch avoids predict()'s per-call setup overhead.
            generated_images = generator_model.predict_on_batch(z)
            real_loss, real_acc = discriminator_model.train_on_batch(batch_images, real_labels)
            fake_loss, fake_acc = discriminator_model.train_on_batch(generated_images, fake_labels)
            loss = (real_loss + fake_loss) / 2
            accuracy = (real_acc + fake_acc) / 2

            # Generator update on fresh noise.
            z = np.random.normal(0, 1, (batch_size, latent_dim))
            gan_loss, _ = gan_model.train_on_batch(z, real_labels)

        # Reported metrics come from the epoch's last batch.
        print(f'EPOCH {epoch} COMPLETE |  DISCRIMINATOR-LOSS = {loss} | DISCRIMINATOR-ACC {accuracy} | GAN-LOSS {gan_loss}')
        image_generation()

In [85]:
epochs = 100
batch_size = 32
faces_path = 'img_align_celeba'
# Keep only image files, so stray entries (e.g. .DS_Store) never reach
# Image.open. Sort them so the starting order does not depend on os.listdir's
# arbitrary ordering.
image_path = sorted(
    name for name in os.listdir(faces_path)
    if name.lower().endswith(('.jpg', '.jpeg', '.png'))
)

In [86]:
# Launch training with the run configuration defined above.
train(batch_size=batch_size, epochs=epochs)

  0%|          | 4/6331 [02:03<54:05:57, 30.78s/it]


KeyboardInterrupt: 