In [1]:
import os 
import cv2
import numpy as np
from PIL import Image
import albumentations as alb 
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow import keras
from tqdm import tqdm 
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
    Conv2D, Conv2DTranspose, Dense, BatchNormalization,
    GlobalAveragePooling2D, MaxPooling2D, LeakyReLU,
    Dropout, Input, Reshape, Conv1D, ReLU
)
from sklearn.model_selection import train_test_split

In [2]:
# Per-channel constants handed to alb.Normalize below.
mean = (0.5,) * 3
std = (0.5,) * 3

# Target image shape; the first two entries are the size passed to alb.Resize.
image_shape = (64, 64, 3)

# Preprocessing applied to every image: centre crop, resize, then normalise.
augmentation = alb.Compose(
    [
        alb.CenterCrop(160, 160),
        alb.Resize(*image_shape[:2], always_apply=True),
        alb.Normalize(mean, std, always_apply=True),
    ]
)

In [3]:
def data_loader(batch_image_path):
    """Load and preprocess one batch of images.

    Each entry of ``batch_image_path`` is joined onto the module-level
    ``faces_path``, opened with PIL, converted to RGB and run through the
    module-level ``augmentation`` pipeline.

    Args:
        batch_image_path: iterable of file names relative to ``faces_path``.

    Returns:
        A NumPy array built from the list of processed images, in input order.
    """
    def _load_one(file_name):
        pixels = np.array(Image.open(os.path.join(faces_path, file_name)).convert('RGB'))
        return augmentation(image=pixels)['image']

    return np.array([_load_one(file_name) for file_name in batch_image_path])

In [53]:
#https://github.com/thisisiron/spectral_normalization-tf2
class SpectralNormalization(tf.keras.layers.Wrapper):
    """Wrapper that divides the wrapped layer's kernel by an estimate of its
    largest singular value before every forward pass.

    The kernel is flattened to a 2-D matrix of shape ``(-1, last_dim)`` and the
    singular value is estimated with power iteration, using the persistent
    vectors ``u`` and ``v`` that are stored as non-trainable weights.

    Args:
        layer: a built-able Keras layer that exposes a ``kernel`` weight.
        iteration: number of power-iteration steps per forward pass.
        eps: small constant added to the norms to avoid division by zero.
        training: when False the power iteration is skipped and the stored
            ``u`` / ``v`` are used as they are.
        **kwargs: forwarded to ``tf.keras.layers.Wrapper``.

    Raises:
        ValueError: if ``layer`` is not a Keras ``Layer`` instance.
    """

    def __init__(self, layer, iteration=1, eps=1e-12, training=True, **kwargs):
        self.iteration = iteration
        self.eps = eps
        self.do_power_iteration = training
        if not isinstance(layer, tf.keras.layers.Layer):
            raise ValueError(
                'Please initialize `SpectralNormalization` layer with a '
                '`Layer` instance. You passed: {input}'.format(input=layer))
        super(SpectralNormalization, self).__init__(layer, **kwargs)

    def build(self, input_shape):
        """Build the wrapped layer and create the power-iteration vectors."""
        self.layer.build(input_shape)

        if not hasattr(self.layer, 'kernel'):
            raise AttributeError(
                '`SpectralNormalization` requires a layer with a `kernel` '
                'weight. You passed: {input}'.format(input=self.layer))

        # NOTE: this is a reference to the layer's own variable, not a copy.
        self.w = self.layer.kernel
        self.w_shape = self.w.shape.as_list()

        # Row count of the flattened kernel: product of every axis but the
        # last, so kernels of any rank (not only rank 4) are supported.
        flat_rows = 1
        for dim in self.w_shape[:-1]:
            flat_rows *= dim

        self.v = self.add_weight(shape=(1, flat_rows),
                                 initializer=tf.initializers.TruncatedNormal(stddev=0.02),
                                 trainable=False,
                                 name='sn_v',
                                 dtype=tf.float32)

        self.u = self.add_weight(shape=(1, self.w_shape[-1]),
                                 initializer=tf.initializers.TruncatedNormal(stddev=0.02),
                                 trainable=False,
                                 name='sn_u',
                                 dtype=tf.float32)

        super(SpectralNormalization, self).build()

    def call(self, inputs):
        """Normalise the kernel, run the wrapped layer, then call restore_weights."""
        self.update_weights()
        output = self.layer(inputs)
        self.restore_weights()  # Restore weights because of this formula "W = W - alpha * W_SN`"
        return output

    def update_weights(self):
        """Run power iteration and overwrite the kernel with ``kernel / sigma``."""
        w_reshaped = tf.reshape(self.w, [-1, self.w_shape[-1]])

        u_hat = self.u
        v_hat = self.v  # init v vector

        if self.do_power_iteration:
            for _ in range(self.iteration):
                v_ = tf.matmul(u_hat, tf.transpose(w_reshaped))
                v_hat = v_ / (tf.reduce_sum(v_**2)**0.5 + self.eps)

                u_ = tf.matmul(v_hat, w_reshaped)
                u_hat = u_ / (tf.reduce_sum(u_**2)**0.5 + self.eps)

        # sigma has shape (1, 1): v W u^T, the singular-value estimate.
        sigma = tf.matmul(tf.matmul(v_hat, w_reshaped), tf.transpose(u_hat))
        self.u.assign(u_hat)
        self.v.assign(v_hat)

        self.layer.kernel.assign(self.w / sigma)

    def restore_weights(self):
        """Assign ``self.w`` back onto the wrapped layer's kernel.

        NOTE(review): ``self.w`` is the very same variable as
        ``self.layer.kernel`` (see ``build``), so this assignment writes the
        kernel onto itself and the division done in ``update_weights`` is kept.
        Behaviour is left unchanged here; confirm whether a true restore from a
        separate copy is wanted.
        """
        self.layer.kernel.assign(self.w)

In [47]:
class _LearnableScale(tf.keras.layers.Layer):
    """Multiply the input by a single trainable scalar tracked by the model."""

    def __init__(self, initial_value=0.5, **kwargs):
        super().__init__(**kwargs)
        self.initial_value = initial_value

    def build(self, input_shape):
        # Created through add_weight so the scalar is part of the model's
        # trainable weights; a bare tf.Variable is not owned by any layer.
        self.gamma = self.add_weight(
            name='gamma',
            shape=(),
            initializer=tf.keras.initializers.Constant(self.initial_value),
            trainable=True)
        super().build(input_shape)

    def call(self, inputs):
        return self.gamma * inputs

    def get_config(self):
        config = super().get_config()
        config['initial_value'] = self.initial_value
        return config


def self_attention_block(x):
    """Self-attention over the spatial positions of a feature map.

    Three 1x1 convolutions produce query, key and value maps. Each map is
    flattened to ``(batch, height*width, channels)``, every position attends to
    every other position through a softmax over ``query @ key^T``, and the
    attended values are reshaped back to the input's spatial layout and scaled
    by a trainable scalar initialised to 0.5.

    Args:
        x: a 4-D Keras tensor ``(batch, height, width, channels)`` whose
           height, width and channels are statically known.

    Returns:
        A Keras tensor with the same shape as ``x``.

    NOTE(review): the SAGAN formulation also adds the input back
    (``gamma * out + x``); the original code had no residual, so none is added
    here. Confirm which is intended.
    """
    _, height, width, channels = x.shape
    positions = height * width

    # Previously these projections were computed and then overwritten by
    # transposes of `x`; they are now actually used.
    query = Conv2D(channels, kernel_size=1)(x)
    key = Conv2D(channels, kernel_size=1)(x)
    value = Conv2D(channels, kernel_size=1)(x)

    query = Reshape((positions, channels))(query)
    key = Reshape((positions, channels))(key)
    value = Reshape((positions, channels))(value)

    # (batch, positions, positions): row i holds position i's attention weights.
    score = tf.matmul(query, key, transpose_b=True)
    score = tf.nn.softmax(score, axis=-1)

    # (batch, positions, channels) -> (batch, height, width, channels)
    out = tf.matmul(score, value)
    out = Reshape((height, width, channels))(out)

    return _LearnableScale(0.5)(out)


In [65]:
def upsample_block(x):
    """Spectrally-normalised stride-2 transposed convolution, then BatchNorm and ReLU.

    The transposed convolution has four times as many filters as ``x`` has
    channels.
    """
    filters = x.shape[-1] * 4
    deconv = Conv2DTranspose(filters, kernel_size=4, strides=2, padding='same')
    features = SpectralNormalization(deconv)(x)
    features = BatchNormalization()(features)
    return ReLU()(features)

In [66]:
def generator(z_dims):
    """Build the generator model.

    A latent vector is projected to a 2x2x2 tensor, passed through four
    ``upsample_block`` stages (with ``self_attention_block`` after the third
    and fourth) and a final stride-2 transposed convolution with 3 output
    channels and ``tanh`` activation.

    Args:
        z_dims: latent size, either an int or an already-formed shape tuple.

    Returns:
        A ``Model`` mapping the latent input to the generated image tensor.
    """
    # `(z_dims)` is just the int itself, not a 1-tuple; build a real shape.
    latent_shape = (int(z_dims),) if np.isscalar(z_dims) else tuple(z_dims)

    inp_ = Input(shape=latent_shape)
    x = Dense(2*2*2)(inp_)
    x = Reshape((2, 2, 2))(x)
    for _ in range(3):
        x = upsample_block(x)
    x = self_attention_block(x)
    x = upsample_block(x)
    x = self_attention_block(x)
    x = Conv2DTranspose(3, kernel_size=4, strides=2, padding='same', activation='tanh')(x)
    model = Model(inp_, x)
    return model

In [67]:
def downsample_block(x):
    """Spectrally-normalised stride-2 convolution, then BatchNorm and LeakyReLU.

    The convolution has four times as many filters as ``x`` has channels.
    """
    filters = x.shape[-1] * 4
    conv = Conv2D(filters, kernel_size=4, strides=2, padding='same')
    features = SpectralNormalization(conv)(x)
    features = BatchNormalization()(features)
    return LeakyReLU()(features)

In [68]:
def discriminator(img_shape):
    """Build the discriminator (critic) model.

    The image goes through four ``downsample_block`` stages (with
    ``self_attention_block`` after the third and fourth), then a 4x4
    convolution with a single filter whose output is reshaped to one
    unbounded score per row.

    Args:
        img_shape: shape of one input image, without the batch axis.

    Returns:
        A ``Model`` mapping an image batch to a ``(-1, 1)`` tensor of raw
        scores.
    """
    inp_ = Input(shape=img_shape)
    x = inp_
    for _ in range(3):
        x = downsample_block(x)
    x = self_attention_block(x)
    x = downsample_block(x)
    x = self_attention_block(x)
    x = Conv2D(1, kernel_size=4)(x)
    x = tf.reshape(x, (-1, 1))
    # No softmax here: a softmax over a single unit is identically 1.0, which
    # made the output constant and the gradients zero. This notebook compiles
    # the model with the hinge loss, which works on raw scores.
    model = Model(inp_, x)
    return model

In [69]:
def gan(generator, discriminator):
    """Stack generator and discriminator into the model used to train the generator.

    Args:
        generator: model mapping a latent batch to an image batch.
        discriminator: model mapping an image batch to scores. It must already
            be compiled if it is also trained on its own (see note below).

    Returns:
        A ``Model`` mapping a latent batch to discriminator scores.

    Side effect:
        Sets ``discriminator.trainable = False``.
    """
    # Freeze the discriminator inside the stacked model so the generator step
    # only updates generator weights. Per the Keras documentation, `trainable`
    # values are captured when `compile()` is called, so a discriminator that
    # was compiled before this call still trains through its own model.
    discriminator.trainable = False

    # Take the latent shape from the generator itself instead of reading a
    # module-level `z_dims`; `input_shape[1:]` drops the batch axis.
    inp_ = Input(shape=generator.input_shape[1:])
    x = generator(inp_)
    x = discriminator(x)
    model = Model(inp_, x)
    return model

In [70]:
# Latent-vector length and image shape used when the models are built.
z_dims, image_shape = 100, (64, 64, 3)

In [73]:
# The discriminator is compiled before `gan(...)` is called, so it is compiled
# on its own first. It uses a larger learning rate than the stacked model.
# `learning_rate` replaces the deprecated `lr` alias.
discriminator_model = discriminator(image_shape)
discriminator_model.compile(loss='hinge', optimizer=Adam(learning_rate=4e-4, beta_1=0.0, beta_2=0.9), metrics=['accuracy'])

generator_model = generator(z_dims)
gan_model = gan(generator_model, discriminator_model)
# `metrics=['accuracy']` is kept because `train` unpacks (loss, accuracy)
# from `train_on_batch`.
gan_model.compile(loss='hinge', optimizer=Adam(learning_rate=1e-4, beta_1=0.0, beta_2=0.9), metrics=['accuracy'])

In [89]:
def image_generation(grid_size=7):
    """Sample the generator and show the results in a square grid.

    Args:
        grid_size: number of rows and columns of the grid (default 7, i.e.
            49 images).

    Reads the module-level ``generator_model``.
    """
    # Use the generator's own latent size rather than a hard-coded 100.
    latent_dim = generator_model.input_shape[-1]
    z = np.random.normal(0, 1, (grid_size * grid_size, latent_dim))
    images = generator_model.predict(z, verbose=0)

    # Map the tanh output range [-1, 1] to [0, 1]; the clip only guards
    # against float round-off outside imshow's valid range.
    images = np.clip(images * 0.5 + 0.5, 0.0, 1.0)

    # squeeze=False keeps `axis` two-dimensional even when grid_size is 1.
    fig, axis = plt.subplots(grid_size, grid_size, figsize=(10, 10), squeeze=False)
    for ax, image in zip(axis.flat, images):
        ax.imshow(image)
    plt.show()
    # Release the figure; this function is called once per epoch.
    plt.close(fig)

In [90]:
def train(batch_size, epochs):
    """Alternate discriminator and generator updates over the image set.

    For every batch: the discriminator is trained on real images and on
    freshly generated images, then the stacked ``gan_model`` is trained on new
    latent vectors with the "real" targets. After each epoch the file list is
    shuffled in place, the last batch's losses are printed and a sample grid
    is shown.

    Args:
        batch_size: number of images per step (must be positive).
        epochs: number of passes over ``image_path``.

    Raises:
        ValueError: if ``batch_size`` is not positive or ``image_path`` is empty.

    Reads the module-level ``image_path``, ``generator_model``,
    ``discriminator_model`` and ``gan_model``.
    """
    if batch_size <= 0:
        raise ValueError(f'batch_size must be positive, got {batch_size}')
    if len(image_path) == 0:
        raise ValueError('image_path is empty; there is nothing to train on')

    # Use the generator's own latent size rather than a hard-coded 100.
    latent_dim = generator_model.input_shape[-1]

    for epoch in range(epochs):
        # Ceiling division: the final, shorter batch is trained on as well.
        # With floor division the partial-batch branch was unreachable.
        steps = -(-len(image_path) // batch_size)
        for step in tqdm(range(steps)):
            start = step * batch_size
            # Slicing past the end simply yields the shorter tail batch.
            batch_image_path = image_path[start:start + batch_size]
            batch_images = data_loader(batch_image_path)
            current_batch = batch_images.shape[0]

            z = np.random.normal(0, 1, (current_batch, latent_dim))

            # Keras' hinge loss is max(0, 1 - y_true * y_pred) and expects
            # targets of -1 / +1. A fake target of 0 makes the loss constant,
            # so fakes use -1. Real targets keep the original 0.9 smoothing.
            real_labels = np.ones((current_batch, 1))*0.9
            fake_labels = -np.ones((current_batch, 1))

            generated_image = generator_model.predict(z, verbose=0)

            real_images_loss, accuracy_real = discriminator_model.train_on_batch(batch_images, real_labels)
            fake_images_loss, accuracy_fake = discriminator_model.train_on_batch(generated_image, fake_labels)

            loss = (real_images_loss + fake_images_loss)/2
            accuracy = (accuracy_real + accuracy_fake)/2

            # Fresh latent vectors for the generator step.
            z = np.random.normal(0, 1, (current_batch, latent_dim))

            gan_loss, _ = gan_model.train_on_batch(z, real_labels)

        np.random.shuffle(image_path)

        # The values reported are those of the last batch of the epoch.
        print(f'EPOCH {epoch} COMPLETE |  DISCRIMINATOR-LOSS = {loss} | DISCRIMINATOR-ACC {accuracy} | GAN-LOSS {gan_loss}' )
        image_generation()

In [91]:
# Directory holding the training images, relative to the working directory.
faces_path = 'img_align_celeba'

# os.listdir returns entries in arbitrary order and includes everything in the
# directory. Sort for a reproducible starting order and keep only regular,
# non-hidden files so a stray sub-directory or dotfile cannot crash the loader
# in the middle of an epoch. The result stays a list because `train` shuffles
# it in place.
image_path = sorted(
    name for name in os.listdir(faces_path)
    if not name.startswith('.') and os.path.isfile(os.path.join(faces_path, name))
)

In [92]:
# Training schedule: number of passes over the data and images per step.
epochs, batch_size = 100, 16

In [None]:
# Start training with the hyperparameters configured above.
train(batch_size=batch_size, epochs=epochs)