In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import tensorflow as tf
tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_cluster(tpu)
tf.tpu.experimental.initialize_tpu_system(tpu)
tpu_strategy = tf.distribute.experimental.TPUStrategy(tpu)

# Global Settings

## Settings

In [1]:
# Basic Settings
raw_path = r"/content/drive/MyDrive/celeba/img_align_celeba/"
test_raw_path = r"/content/drive/MyDrive/celeba/test/"
cascade_path = r"/content/drive/MyDrive/celeba/cascade/"
data_path = r"/content/drive/MyDrive/celeba/data.pkl"
model_path = r"/content/drive/MyDrive/face generation (fl)/model/"
output_path = r"/content/drive/MyDrive/face generation (fl)/output/"
eye_size = 32
face_size = 128
epoch_size = 7500
batch_size = 64
# Optimizer
optimizer_name = "Adam"
from tensorflow.keras.optimizers import Adam as optimizer

## Functions

In [2]:
def normalize(img):
    return cv2.normalize(img.astype("float32"),None,0.0,1.0,cv2.NORM_MINMAX)
def restore(img_norm):
    return (img_norm * 255).astype("uint8")
def crop(raw_file_path):
    eye_list = []
    face_list = []
    raw = cv2.imread(raw_file_path)
    face = cv2.CascadeClassifier(cascade_path + "face_cascade.xml").detectMultiScale(cv2.cvtColor(raw,cv2.COLOR_BGR2GRAY))
    if len(face) == 1:
        face = raw[face[0][1]:face[0][1]+face[0][3],face[0][0]:face[0][0]+face[0][2]]
        face_processed = normalize(cv2.resize(face,(face_size,face_size)))
        for i in range(2):
            face_list.append(face_processed)
    else:
        raise RuntimeError("A singular face cannot be detected")
    eyes = cv2.CascadeClassifier(cascade_path + "eye_cascade.xml").detectMultiScale(cv2.cvtColor(face,cv2.COLOR_BGR2GRAY))
    if(len(eyes) == 2):
        for i in range(2):
            eye_list.append(normalize(cv2.resize(face[eyes[i][1]:eyes[i][1]+eyes[i][3],eyes[i][0]:eyes[i][0]+eyes[i][2]],(eye_size,eye_size))))
    else:
        raise RuntimeError("The number of eyes detected is not 2")
    return [eye_list,face_list]

# Data Handling

## Make Data File

In [None]:
import os, cv2, pickle
import numpy as np
from tqdm import tqdm
data = [[],[]]
data_n = 0
with tqdm(total = epoch_size) as phar:
    for raw_file_name in os.listdir(raw_path):
        try:
            data_item = crop(raw_path + raw_file_name)
            data[0] += data_item[0]
            data[1] += data_item[1]
        except:
            continue
        data_n += 1
        phar.update(1)
        if data_n == epoch_size:
            break
with open(data_path,"wb") as data_raw:
    pickle.dump([np.array(data[0]),np.array(data[1])],data_raw)

## Load Data File

In [3]:
import pickle
with open(data_path,"rb") as data_raw:
    data = pickle.load(data_raw)

# Models

## Base Model

In [4]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import *
from tensorflow.keras.models import *
from tensorflow.keras.losses import *
from tensorflow.keras.activations import *
class _model:
    @property
    def _dim(self):
        return int(face_size * 1.5)
    def load(self,path):
        self._main = load_model(path,custom_objects = catalog)
    def save(self,path):
        self._main.save(path)
    def train(self,input_batch,output_batch):
        return self._main.train_on_batch(input_batch,output_batch)
    def fit(self,*args,**kwargs):
        return self._main.fit(*args,**kwargs)
    def test(self,input_batch,output_batch):
        return self._main.evaluate(input_batch,output_batch)
    def predict(self,input_batch):
        return self._main.predict(input_batch)
class generator(_model):
    def _add_block(self,layer,filter,pooling = 0,kernel_size = (3,3)):
        if(pooling > 0):
            layer = Conv2D(filter,kernel_size,strides = pooling,padding = "same")(layer)
        elif(pooling < 0):
            layer = Conv2DTranspose(filter,kernel_size,strides = -pooling,padding = "same")(layer)
        else:
            layer = Conv2D(filter,kernel_size,padding = "same")(layer)
        layer = BatchNormalization()(layer)
        return LeakyReLU(0.3)(layer)
    def _add_dense_block(self,layer,units):
        layer = Dense(units)(layer)
        layer = BatchNormalization()(layer)
        return LeakyReLU(0.3)(layer)
    def __init__(self):
        # Encoder
        main_input = Input((eye_size,eye_size,3))
        dim = 32
        main = self._add_block(main_input,dim)
        for i in range(int(np.log2(eye_size) - 2)):
            dim *= 2
            main = self._add_block(main,dim,2)
            main = self._add_block(main,dim)
        main = Flatten()(main)
        main = self._add_dense_block(main,face_size)
        # Decoder
        dim = self._dim
        main = self._add_dense_block(main,16 * dim)
        main = Reshape((4,4,dim))(main)
        main = self._add_block(main,self._dim)
        main = self._add_block(main,self._dim)
        for i in range(int(np.log2(face_size) - 2)):
            dim //= 2
            main = self._add_block(main,dim,-2)
            main = self._add_block(main,dim)
        # Converter
        main = Conv2D(3,(3,3),padding = "same")(main)
        main = Activation("sigmoid")(main)
        self._main = Model(main_input,main)
        self._main.compile(optimizer = optimizer(),loss = "mae")

## WGAN

### Main

In [None]:
class WGAN(keras.Model):
    def __init__(
        self,
        discriminator,
        generator,
        discriminator_extra_steps=3,
        gp_weight=10.0,
    ):
        super(WGAN, self).__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.d_steps = discriminator_extra_steps
        self.gp_weight = gp_weight

    def compile(self, d_optimizer, g_optimizer, d_loss_fn, g_loss_fn):
        super(WGAN, self).compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.d_loss_fn = d_loss_fn
        self.g_loss_fn = g_loss_fn

    def gradient_penalty(self, batch_size, real_images, fake_images, eye_images):
        """ Calculates the gradient penalty.

        This loss is calculated on an interpolated image
        and added to the discriminator loss.
        """
        # get the interplated image
        alpha = tf.random.normal([batch_size, 1, 1, 1], 0.0, 1.0)
        diff = fake_images - real_images
        interpolated = real_images + alpha * diff

        with tf.GradientTape() as gp_tape:
            gp_tape.watch(interpolated)
            # 1. Get the discriminator output for this interpolated image.
            pred = self.discriminator([interpolated,eye_images], training=True)

        # 2. Calculate the gradients w.r.t to this interpolated image.
        grads = gp_tape.gradient(pred, [interpolated])[0]
        # 3. Calcuate the norm of the gradients
        norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]))
        gp = tf.reduce_mean((norm - 1.0) ** 2)
        return gp

    def train_step(self, data):
        eye_images, real_images = data

        # Get the batch size
        batch_size = tf.shape(real_images)[0]

        # For each batch, we are going to perform the
        # following steps as laid out in the original paper.
        # 1. Train the generator and get the generator loss
        # 2. Train the discriminator and get the discriminator loss
        # 3. Calculate the gradient penalty
        # 4. Multiply this gradient penalty with a constant weight factor
        # 5. Add gradient penalty to the discriminator loss
        # 6. Return generator and discriminator losses as a loss dictionary.

        # Train discriminator first. The original paper recommends training
        # the discriminator for `x` more steps (typically 5) as compared to
        # one step of the generator. Here we will train it for 3 extra steps
        # as compared to 5 to reduce the training time.
        for i in range(self.d_steps):
            with tf.GradientTape() as tape:
                # Generate fake images from the latent vector
                fake_images = self.generator(eye_images, training=True)
                # Get the logits for the fake images
                fake_logits = self.discriminator([fake_images,eye_images], training=True)
                # Get the logits for real images
                real_logits = self.discriminator([real_images,eye_images], training=True)

                # Calculate discriminator loss using fake and real logits
                d_cost = self.d_loss_fn(real_img=real_logits, fake_img=fake_logits)
                # Calculate the gradient penalty
                gp = self.gradient_penalty(batch_size, real_images, fake_images, eye_images)
                # Add the gradient penalty to the original discriminator loss
                d_loss = d_cost + gp * self.gp_weight

            # Get the gradients w.r.t the discriminator loss
            d_gradient = tape.gradient(d_loss, self.discriminator.trainable_variables)
            # Update the weights of the discriminator using the discriminator optimizer
            self.d_optimizer.apply_gradients(
                zip(d_gradient, self.discriminator.trainable_variables)
            )

        # Train the generator now.
        with tf.GradientTape() as tape:
            # Generate fake images using the generator
            generated_images = self.generator(eye_images, training=True)
            # Get the discriminator logits for fake images
            gen_img_logits = self.discriminator([generated_images,eye_images], training=True)
            # Calculate the generator loss
            g_loss = self.g_loss_fn(gen_img_logits)

        # Get the gradients w.r.t the generator loss
        gen_gradient = tape.gradient(g_loss, self.generator.trainable_variables)
        # Update the weights of the generator using the generator optimizer
        self.g_optimizer.apply_gradients(
            zip(gen_gradient, self.generator.trainable_variables)
        )
        return {"d_loss": d_loss, "g_loss": g_loss}
def w_loss_g(fake_img):
    return -tf.reduce_mean(fake_img)
def w_loss_d(real_img,fake_img):
    real_loss = tf.reduce_mean(real_img)
    fake_loss = tf.reduce_mean(fake_img)
    return fake_loss - real_loss
def s_adam():
    return Adam(learning_rate=0.0002,beta_1=0.5,beta_2=0.9)
catalog = {
    "w_loss_g": w_loss_g,
    "w_loss_d": w_loss_d,
    "s_adam": s_adam,
}

NameError: ignored

In [None]:
class discriminator(_model):
    def _add_block(self,layer,filter,kernel_size = (3,3),strides = 1):
        if(strides < 1):
            layer = Conv2DTranspose(filter,kernel_size,strides = int(1 / strides),padding = "same")(layer)
        else:
            layer = Conv2D(filter,kernel_size,strides = int(strides),padding = "same")(layer)
        return LeakyReLU(0.3)(layer)
    def __init__(self):
        main_input = Input((face_size,face_size,3))
        cond_input = Input((eye_size,eye_size,3))
        cond = self._add_block(cond_input,3,(3,3),eye_size / face_size)
        main = concatenate([main_input,cond])
        dim = 16
        main = self._add_block(main,dim,(3,3))
        main = self._add_block(main,dim,(3,3))
        for i in range(int(np.log2(face_size) - 2)):
            dim *= 2
            main = self._add_block(main,dim,(3,3),2)
            main = self._add_block(main,dim,(3,3))
        main = Flatten()(main)
        main = Dense(1)(main)
        self._main = Model([main_input,cond_input],main)
class combined(_model):
    def __init__(self,generator,discriminator):
        self._main = WGAN(discriminator = discriminator._main,generator = generator._main)
        self._main.compile(
            d_optimizer = s_adam(),
            g_optimizer = s_adam(),
            d_loss_fn = w_loss_d,
            g_loss_fn = w_loss_g
        )
    def save(self,path):
        self._main.generator.save(path + "_generator.h5")
        self._main.discriminator.save(path + "_discriminator.h5")

### Load Model

In [None]:
g = generator()
g.load(model_path + "wgan_generator.h5")
d = discriminator()
d.load(model_path + "wgan_discriminator.h5")
c = combined(g,d)

### New Model

In [None]:
g = generator()
d = discriminator()
c = combined(g,d)

## Feature Loss

### Main

In [5]:
import tensorflow.keras.backend as K
from tensorflow.keras.applications import VGG19
class vgg(_model):
    @staticmethod
    def preprocess(img):
        return img * 255 - [103.939,116.779,123.68]
    def __init__(self):
        main_input = Input((face_size,face_size,3))
        main = Lambda(self.preprocess)(main_input)
        vgg = VGG19(input_shape = (face_size,face_size,3),include_top = False)
        vgg_output = []
        for l in vgg.layers:
            vgg_output.append(l.output)
        vgg = Model(vgg.input,vgg_output)
        vgg = vgg(main)
        self._main = Model(main_input,vgg)
        self._main.compile(optimizer = optimizer(),loss = "mae")
        self._main.trainable = False
class combined_loss(Loss):
    def __init__(self,mae_weight = 0.85,sdl_weight = 0.15,**kwargs):
        super().__init__(**kwargs)
        self.mae_weight = mae_weight
        self.sdl_weight = sdl_weight
    @staticmethod
    def _vector_std(matrix):
        return K.mean(K.mean((cosine_similarity(K.mean(matrix,axis = 0),matrix) + 1) / 2,axis = 0))
    @classmethod
    def vector_std_loss(cls,y_true,y_pred):
        return K.abs(cls._vector_std(y_true) - cls._vector_std(y_pred))
    def call(self,y_true,y_pred):
        y_true = K.constant(y_true)
        y_pred = K.constant(y_pred)
        return (
            MeanAbsoluteError()(y_true,y_pred) * self.mae_weight +
            self.vector_std_loss(y_true,y_pred) * self.sdl_weight
        )
    def get_config(self):
        return {
            "mae_weight": self.mae_weight,
            "sdl_weight": self.sdl_weight,
        }
    @classmethod
    def from_config(cls,config):
        return cls(**config)
def weight_function_linear(sum,layer_n,focus):
    layer_n -= 1
    max = 2 * sum / layer_n
    def weight_function(i):
        if 0 <= i <= focus:
            return max * i / focus
        elif focus <= i <= layer_n:
            return (max / (layer_n - focus)) * (layer_n - i)
    return weight_function
class combined(_model):
    class _base_model(Model):
        def __init__(self,inputs,outputs,**kwargs):
            super().__init__(inputs,outputs,**kwargs)
        def train_step(self,data):
            x,y = data
            return super().train_step((x,[y] + self.layers[2](y)))
        def evaluate(self,x,y,**kwargs):
            return super().evaluate(x,[y] + self.layers[2](y),**kwargs)
        def compile(self,optimizer,**kwargs):
            kwargs["run_eagerly"] = True
            return super().compile(optimizer,**kwargs)
        
    def __init__(self,generator,vgg,mae_loss_weight = 0.85,sdl_loss_weight = 0.15,loss_weight_balance = True,data_sample = None,g_loss_weight = 0.1,vgg_loss_weight = 0.9,loss_weight_function = weight_function_linear,loss_weight_args = {"focus": 10}):
        # Combine the generator and the vgg model
        main_input = Input((eye_size,eye_size,3))
        main = generator._main(main_input)
        main_vgg = vgg._main(main)
        self._main = self._base_model(main_input,[main] + main_vgg)
        # Estimate the scale of losses
        if loss_weight_balance == True:
            loss_list = []
            self._main.compile(optimizer = optimizer(),loss = "mae")
            mae_loss_base = self.test(*data_sample)
            self._main.compile(optimizer = optimizer(),loss = combined_loss.vector_std_loss)
            sdl_loss_base = self.test(*data_sample)
            for i in range(len(mae_loss_base)):
                loss_list.append(combined_loss(mae_loss_weight / mae_loss_base[i],sdl_loss_weight / sdl_loss_base[i]))
        else:
            loss_list = combined_loss(mae_loss_weight,sdl_loss_weight)
        # Determine the weightings of the losses
        loss_weight = []
        loss_weight.append(g_loss_weight)
        layer_n = len(main_vgg)
        loss_weight_function = loss_weight_function(vgg_loss_weight,layer_n,**loss_weight_args)
        for i in range(layer_n):
            loss_weight.append(loss_weight_function(i))
        self._main.compile(optimizer = optimizer(),loss = loss_list,loss_weights = loss_weight)
        print("Loss weights: " + str(loss_weight))
catalog = {
    "combined_loss": combined_loss,
    "_base_model": combined._base_model,
    "preprocess": vgg.preprocess,
    optimizer_name: optimizer,
}

### New Model

In [6]:
g = generator()
v = vgg()
c = combined(g,v,mae_loss_weight = 0.85,sdl_loss_weight = 0.15,g_loss_weight = 0,vgg_loss_weight = 1,loss_weight_balance = True,data_sample = [data[0][:32],data[1][:32]])

Loss weights: [0, 0.0, 0.009523809523809523, 0.019047619047619046, 0.02857142857142857, 0.03809523809523809, 0.047619047619047616, 0.05714285714285714, 0.06666666666666667, 0.07619047619047618, 0.08571428571428572, 0.09523809523809523, 0.08658008658008658, 0.07792207792207792, 0.06926406926406926, 0.06060606060606061, 0.05194805194805195, 0.04329004329004329, 0.03463203463203463, 0.025974025974025976, 0.017316017316017316, 0.008658008658008658, 0.0]


### Load Model

In [13]:
c = load_model(model_path + "main.h5",custom_objects = catalog)

In [17]:
c.save_weights("weight.h5")
with open("optimizer.pkl","wb") as f:
    pickle.dump(c.optimizer.get_weights(),f)

In [7]:
with open("optimizer.pkl","rb") as f:
    opt_weights = pickle.load(f)
grad_vars = c._main.trainable_weights
zero_grads = [tf.zeros_like(w) for w in grad_vars]
c._main.optimizer.apply_gradients(zip(zero_grads, grad_vars))
c._main.optimizer.set_weights(opt_weights)
c._main.load_weights("weight.h5")

# Training

In [None]:
import os, cv2
def test(generator):
    for raw_file_name in os.listdir(test_raw_path):
        try:
            data_item = crop(test_raw_path + raw_file_name)
        except:
            continue
        eye = np.array(data_item[0])
        face = np.array(data_item[1])
        generator.test(eye,face)
        face_pred = generator.predict(eye)
        for i in range(2):
            cv2.imwrite(output_path + raw_file_name + "_eye_" + str(i) + ".jpg",restore(eye[i]))
            cv2.imwrite(output_path + raw_file_name + "_face_" + str(i) + ".jpg",restore(face[i]))
            cv2.imwrite(output_path + raw_file_name + "_pred_" + str(i) + ".jpg",restore(face_pred[i]))
for i in range(13):
    print("Epoch " + str(i + 1) + ":")
    c.fit(data[0],data[1],batch_size = 64)
    test(g)
    c.save(model_path + "main.h5")

Epoch 1:
Epoch 2:
Epoch 3:
Epoch 4:
Epoch 5:
Epoch 6:
Epoch 7:
Epoch 8:
Epoch 9:
Epoch 10:
Epoch 11:

In [None]:
test()