In [None]:
import os 
import math
import cv2
import numpy as np
import pathlib
import tensorflow as tf
import tensorflow_addons as tfa
import matplotlib.pyplot as plt
from PIL import Image

from IPython.display import clear_output
from tensorflow.keras.layers import Conv2D, Dense, LeakyReLU, BatchNormalization, UpSampling2D, Add,\
    AveragePooling2D, Concatenate, Input, Lambda, Activation, LayerNormalization
from tensorflow.keras import Model, Sequential

In [None]:
# Load the TensorBoard notebook extension for this kernel session.
%load_ext tensorboard

In [None]:
# Number of images per training batch (used when batching `train_data`).
BATCH_SIZE = 32
# Default side length, in pixels, that `preprocess` resizes images to.
IMG_SIZE = 64
# Number of residual blocks per U-Net stage (down, middle and up).
N_BLOCK = 2
# Last epoch number of the training loop (inclusive).
EPOCHS = 100
# First epoch to run; overwritten when a checkpoint is restored.
CURRENT_EPOCH = 1
# A checkpoint is written whenever the epoch number is a multiple of this value.
SAVE_EVERY_N_EPOCH = 5

# Directory for the TensorBoard summary writer.
LOG_DIR = './results/logs/'
# Directory managed by the checkpoint manager.
CKPT_DIR = './results/models_weight'

In [None]:
# Collect the path of every `.jpg` file directly inside the dataset directory.
# NOTE(review): this is an absolute Kaggle input path; adjust it when running elsewhere.
inp_data_path = pathlib.Path('/kaggle/input/another-anime-face-dataset/animefaces256cleaner')
file_list = [str(path) for path in inp_data_path.glob('*.jpg')]

def preprocess(file_path, img_size=IMG_SIZE):
    """Load one JPEG file and return it as a float32 image scaled to [-1, 1].

    Args:
        file_path: Path of a JPEG file (a string or scalar string tensor).
        img_size: Side length, in pixels, of the square output image.

    Returns:
        A float32 tensor of shape ``(img_size, img_size, 3)`` with values in
        ``[-1, 1]``.
    """
    raw_bytes = tf.io.read_file(file_path)
    # Decoded as uint8 RGB with values in 0..255.
    imgs = tf.io.decode_jpeg(raw_bytes, channels=3)
    # The default (bilinear) resize returns float32 and keeps the 0..255 range,
    # so no further dtype conversion is needed; `convert_image_dtype` does not
    # rescale a float input and would be a no-op here.
    imgs = tf.image.resize(imgs, [img_size, img_size])

    # Map 0..255 onto -1..1, the range the diffusion model works in.
    return (imgs - 127.5) / 127.5


# Build the input pipelines: decode/resize every file, shuffle with a
# 500-element buffer and batch.
data_path = tf.data.Dataset.from_tensor_slices(file_list)
train_data = data_path.map(preprocess).shuffle(500).batch(BATCH_SIZE)
test_data = data_path.map(preprocess).shuffle(500).batch(16)

# Keep one training batch around; `img` stays in the model's [-1, 1] range.
img = next(iter(train_data))

# imshow expects float RGB data in [0, 1], so map [-1, 1] back to [0, 1]
# for display only (otherwise the negative half of the range is clipped).
plt.imshow(img[0] * 0.5 + 0.5)
plt.show()


In [None]:
# Length of the forward diffusion chain.
timesteps = 1000

# Fixed linear beta (variance) schedule over all timesteps.
beta = np.linspace(0.0001, 0.02, timesteps)

# Coefficients of the closed-form forward process (reparameterization trick):
#   x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
alpha = 1 - beta
# Running product of alpha, shifted right by one entry so that
# alpha_bar[0] == 1, i.e. timestep 0 carries no noise at all.
alpha_bar = np.concatenate(([1.], np.cumprod(alpha)[:-1]))
sqrt_alpha_bar = np.sqrt(alpha_bar)
# Despite its name, this holds sqrt(1 - alpha_bar), the noise coefficient.
one_minus_sqrt_alpha_bar = np.sqrt(1 - alpha_bar)


def add_noise(x_0, t):
    """Apply the forward diffusion process to `x_0` at timestep(s) `t`.

    Args:
        x_0: Clean image or batch of images.
        t: Integer index (or indices) into the precomputed schedule tables.

    Returns:
        A ``(noisy_img, noise)`` tuple, where ``noise`` is the Gaussian sample
        that was mixed into ``x_0``.
    """
    def _coefficient(table):
        # Look up one value per timestep and shape it as (N, 1, 1, 1) so it
        # broadcasts over the height, width and channel axes.
        return np.reshape(np.take(table, t), [-1, 1, 1, 1])

    noise = np.random.normal(size=x_0.shape)
    signal_scale = _coefficient(sqrt_alpha_bar)
    noise_scale = _coefficient(one_minus_sqrt_alpha_bar)
    return signal_scale * x_0 + noise_scale * noise, noise


# Show the same training image after 10, 100, 300 and 600 forward steps.
fig = plt.figure(figsize=(15, 30))

for index, i in enumerate([10, 100, 300, 600]):
    noisy_im, noise = add_noise(img[0], np.array([i,]))
    # add_noise returns a leading axis of size 1; drop it for plotting.
    noisy_im = np.squeeze(noisy_im)
    plt.subplot(1, 4, index+1)
    plt.axis('off')
    # Images live in [-1, 1] and the added noise can push values further out;
    # map to [0, 1] and clip so imshow renders the full range correctly.
    plt.imshow(np.clip(noisy_im * 0.5 + 0.5, 0.0, 1.0))

plt.savefig('image.png')
plt.show()

In [None]:
class LinearAttention(tf.keras.layers.Layer):
    """Multi-head linear attention over the positions of a feature map.

    The layer takes a channels-last feature map ``(batch, height, width,
    channels)``, derives queries, keys and values with a bias-free 1x1
    convolution, attends across the ``height * width`` positions, projects
    the result to ``dim`` channels, layer-normalizes it and adds the input
    back. Because of that residual connection ``dim`` has to equal the
    number of input channels.

    Args:
        dim: Number of output channels of the final projection.
        heads: Number of attention heads.
        dim_head: Number of channels per head.
    """
    def __init__(self, dim, heads=4, dim_head=32):
        super(LinearAttention, self).__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        self.dim_head = dim_head
        self.hidden_dim = dim_head * heads

        self.to_qkv = Conv2D(filters=self.hidden_dim * 3, kernel_size=1, strides=1, use_bias=False)

        self.to_out = Sequential([
            Conv2D(filters=dim, kernel_size=1, strides=1),
            LayerNormalization()
        ])

    def _split_heads(self, t, n):
        """Turn ``(batch, h, w, heads * dim_head)`` into ``(batch, heads, dim_head, n)``.

        The tensor is channels-last, so the channel axis must be split first
        and the position axis moved to the end with a transpose; reshaping
        straight to ``(heads, dim_head, n)`` would mix positions and channels.
        """
        t = tf.reshape(t, [-1, n, self.heads, self.dim_head])
        return tf.transpose(t, [0, 2, 3, 1])

    def call(self, x, training=True):
        residual = x
        # Static spatial size of the feature map.
        _, h, w, _ = x.shape
        n = h * w

        qkv = self.to_qkv(x)
        q, k, v = [
            self._split_heads(t, n)
            for t in tf.split(qkv, num_or_size_splits=3, axis=-1)
        ]

        # Queries are normalized over the feature axis, keys over positions.
        q = tf.nn.softmax(q, axis=-2)
        k = tf.nn.softmax(k, axis=-1)
        q = q * self.scale

        # Aggregate values into a (dim_head x dim_head) context per head, then
        # read it out per position.
        context = tf.einsum('b h d n, b h e n -> b h d e', k, v)
        out = tf.einsum('b h d e, b h d n -> b h e n', context, q)

        # Undo _split_heads: (batch, heads, dim_head, n) -> (batch, h, w, hidden_dim).
        out = tf.transpose(out, [0, 3, 1, 2])
        out = tf.reshape(out, [-1, h, w, self.hidden_dim])
        out = self.to_out(out, training=training)

        return out + residual

In [None]:
def resBlock(inp, filter, t):
    """Residual block with two 3x3 convolutions, modulated by the time embedding.

    Args:
        inp: Input feature map.
        filter: Number of output channels.
        t: Time embedding; it is projected to a per-channel scale and shift
            that are applied after the first normalization.

    Returns:
        The sum of the convolutional branch and the shortcut branch.
    """
    # Shortcut branch: identity when the channel count already matches,
    # otherwise a 1x1 convolution projecting to `filter` channels.
    if inp.shape[-1] != filter:
        shortcut = Conv2D(filter, 1, 1, padding='same')(inp)
    else:
        shortcut = inp

    # First convolution; group norm is the common choice in the literature.
    h = Conv2D(filter, 3, 1, padding='same')(inp)
    h = tfa.layers.GroupNormalization(8, epsilon=1e-05)(h)

    # Project the time embedding to a scale and a shift per channel.
    modulation = Dense(filter*2)(t)
    scale, shift = tf.split(modulation, num_or_size_splits=2, axis=-1)
    h = h * (scale + 1) + shift
    h = Activation('swish')(h)

    # Second convolution.
    h = Conv2D(filter, 3, 1, padding='same')(h)
    h = tfa.layers.GroupNormalization(8, epsilon=1e-05)(h)
    h = Activation('swish')(h)

    return Add()([h, shortcut])

def downBlock(x, skips, filter, n_block, t):
    """Encoder stage: residual blocks, attention, then 2x average pooling.

    Args:
        x: Input feature map.
        skips: List that receives the output of every residual block
            (mutated in place).
        filter: Number of channels of the stage.
        n_block: Number of residual blocks.
        t: Time embedding forwarded to every residual block.

    Returns:
        The pooled feature map.
    """
    features = x
    for _ in range(n_block):
        features = resBlock(features, filter, t)
        skips.append(features)

    # The attention layer is added for better global coherence.
    attended = LinearAttention(filter)(features)
    return AveragePooling2D()(attended)
    
def upBlock(x, skips, filter, n_block, t):
    """Decoder stage: 2x upsampling, skip-connected residual blocks, attention.

    Args:
        x: Input feature map.
        skips: List of encoder outputs; one entry is popped from the end for
            every residual block (mutated in place).
        filter: Number of channels of the stage.
        n_block: Number of residual blocks.
        t: Time embedding forwarded to every residual block.

    Returns:
        The feature map after the attention layer.
    """
    features = UpSampling2D()(x)
    for _ in range(n_block):
        merged = Concatenate()([features, skips.pop()])
        features = resBlock(merged, filter, t)

    return LinearAttention(filter)(features)

# Time embedding for the discrete time schedule.
class SinusoidalPosEmb(tf.keras.layers.Layer):
    """Sinusoidal timestep embedding followed by two swish dense layers.

    The output is reshaped to ``(batch, 1, 1, dim)``.

    Args:
        dim: Size of the embedding.
        max_positions: Base of the geometric frequency progression.
    """
    def __init__(self, dim=32, max_positions=10000):
        super(SinusoidalPosEmb, self).__init__()
        self.dim = dim
        self.max_positions = max_positions
        self.dense1 = Dense(self.dim, activation='swish')
        self.dense2 = Dense(self.dim, activation='swish')

    def call(self, x, training=True):
        steps = tf.cast(x, tf.float32)

        # Half of the channels carry sines, the other half cosines.
        n_freq = self.dim // 2
        log_step = math.log(self.max_positions) / (n_freq - 1)
        freqs = tf.exp(tf.range(n_freq, dtype=tf.float32) * -log_step)
        angles = steps * freqs[None, :]

        emb = tf.concat([tf.sin(angles), tf.cos(angles)], axis=-1)
        emb = tf.reshape(emb, [-1, 1, 1, self.dim])
        return self.dense2(self.dense1(emb))

In [None]:
def U_NET(img_size):
    """Build the noise-prediction U-Net.

    Args:
        img_size: Shape of one input image, without the batch axis.

    Returns:
        A Keras ``Model`` that takes ``[noisy_img, time]`` and outputs a
        3-channel map.
    """
    noisy_img = Input(img_size)
    t_in = Input((1,))

    t_emb = SinusoidalPosEmb()(t_in)
    features = Conv2D(32, 1, 1, padding='same')(noisy_img)
    skips = []

    # Encoder: each stage halves the resolution and stores its skip tensors.
    for width in (64, 128, 128):
        features = downBlock(features, skips, width, N_BLOCK, t_emb)

    # Bottleneck.
    for _ in range(N_BLOCK):
        features = resBlock(features, 256, t_emb)

    # Decoder: mirrors the encoder and consumes the skips in reverse order.
    for width in (128, 128, 64):
        features = upBlock(features, skips, width, N_BLOCK, t_emb)
    features = resBlock(features, 32, t_emb)

    # The output projection starts with a zero kernel.
    out = Conv2D(3, 1, 1, padding='same', kernel_initializer="zeros")(features)
    return Model([noisy_img, t_in], out)

# Build the model for the resolution the data pipeline produces (IMG_SIZE)
# instead of a separately hard-coded size.
u_net = U_NET((IMG_SIZE, IMG_SIZE, 3))
#u_net.summary()

In [None]:
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
loss_fn = tf.keras.losses.MeanSquaredError()

# Track the optimizer together with the model so that a resumed run continues
# with Adam's accumulated state instead of restarting it from scratch.
ckpt = tf.train.Checkpoint(u_net=u_net, optimizer=optimizer)

summary_Writer = tf.summary.create_file_writer(LOG_DIR)
ckpt_manager = tf.train.CheckpointManager(ckpt, CKPT_DIR, max_to_keep=5)

# Resume from the newest checkpoint, if there is one.
if ckpt_manager.latest_checkpoint:
    ckpt.restore(ckpt_manager.latest_checkpoint)
    # The checkpoint path ends in "-<n>"; one checkpoint is written every
    # SAVE_EVERY_N_EPOCH epochs, so checkpoint n was saved after epoch
    # n * SAVE_EVERY_N_EPOCH.
    latest_epoch = int(ckpt_manager.latest_checkpoint.split('-')[-1])
    restored_epoch = latest_epoch * SAVE_EVERY_N_EPOCH
    CURRENT_EPOCH = restored_epoch + 1
    # Report the epoch that was restored, not the next one to run.
    print ('Latest checkpoint of epoch {} restored!!'.format(restored_epoch))

In [None]:
# Directory that receives the sample grids written after each epoch.
OUTPUT_PATH = r'./models/git_out'

# makedirs also creates the missing parent ('./models'), which os.mkdir would
# fail on, and exist_ok makes the cell safe to re-run.
os.makedirs(OUTPUT_PATH, exist_ok=True)

def ddim(x_t, pred_noise, t, step_size):
    """Take one deterministic DDIM step from timestep `t` to `t - step_size`.

    Args:
        x_t: Current noisy samples.
        pred_noise: Noise predicted by the model for `x_t`.
        t: Integer timestep index (or one index per sample) into `alpha_bar`.
        step_size: Number of timesteps to move back.

    Returns:
        The samples at timestep ``max(t - step_size, 0)``.
    """
    # Clamp the target index: a negative index would silently wrap around to
    # the end of the schedule, whereas index 0 (alpha_bar == 1) is the
    # noise-free end of the chain.
    t_prev = np.maximum(t - step_size, 0)

    alpha_t_bar = np.reshape(np.take(alpha_bar, t), [-1, 1, 1, 1])
    alpha_t_minus_one = np.reshape(np.take(alpha_bar, t_prev), [-1, 1, 1, 1])

    # Estimate the clean image implied by the predicted noise ...
    pred = (x_t - ((1 - alpha_t_bar) ** 0.5) * pred_noise)/ (alpha_t_bar ** 0.5)
    # ... and re-noise it to the earlier timestep with the same noise.
    pred = (alpha_t_minus_one ** 0.5) * pred
    pred = pred + ((1 - alpha_t_minus_one) ** 0.5) * pred_noise
    return pred

# Number of sampling steps used at inference time.
inference_timesteps = 200
# Distance, in training timesteps, between two consecutive sampling steps.
inf_step = timesteps // inference_timesteps
# The training timesteps visited while sampling.
inference_range = range(0, timesteps, timesteps // inference_timesteps)

def generate_save_img(epoch, path=OUTPUT_PATH, save=True):
    """Sample a 4x4 grid of images with DDIM, show it and optionally save it.

    Args:
        epoch: Epoch number; only used to build the output file name.
        path: Directory the PNG is written to when `save` is true.
        save: Whether to write the figure to disk in addition to showing it.
    """
    n_rows = n_cols = 4
    n_samples = n_rows * n_cols

    # Start from pure noise with the image shape the model was built for.
    sample_shape = (n_samples,) + tuple(u_net.input_shape[0][1:])
    x = tf.random.normal(sample_shape)

    for i in reversed(range(inference_timesteps)):
        t_now = inference_range[i]
        t = np.repeat(t_now, n_samples)

        pred_noise = u_net([x, t])
        x = ddim(x, pred_noise, t, inf_step)

        # The update above has just moved the samples to timestep 0.
        if t_now == inf_step:
            break

    # Map [-1, 1] to [0, 1] and clip, since imshow expects float RGB in [0, 1].
    images = tf.clip_by_value(x * 0.5 + 0.5, 0.0, 1.0).numpy()

    fig, axes = plt.subplots(n_rows, n_cols)
    for ax, image in zip(axes.flat, images):
        ax.imshow(image)
        ax.axis('off')

    if save:
        # Make sure the target directory exists, also for a custom `path`.
        os.makedirs(path, exist_ok=True)
        fig.savefig(os.path.join(path, 'image_at_epoch_{:04d}.png'.format(epoch)))
    plt.show()
    # Release the figure; this function is called once per epoch.
    plt.close(fig)


In [None]:
def train_step(inp_img):
    """Run one optimization step of the noise-prediction objective.

    Args:
        inp_img: Batch of clean training images.

    Returns:
        The mean-squared-error loss of this step.
    """
    t_size = inp_img.shape[0]
    # One uniformly drawn timestep per image in the batch.
    t = tf.random.uniform(shape=[t_size,], minval=0, maxval=timesteps, dtype=tf.int32)
    noisy_img, noise = add_noise(inp_img, t)

    with tf.GradientTape() as tape:
        pre_noise = u_net([noisy_img, t])
        # Keras losses take (y_true, y_pred): the sampled noise is the target,
        # the model output is the prediction.
        loss = loss_fn(noise, pre_noise)

    gradients = tape.gradient(loss, u_net.trainable_variables)
    optimizer.apply_gradients(zip(gradients, u_net.trainable_variables))
    return loss


In [None]:
import time

# Main training loop: train, log, checkpoint periodically, sample a grid.
for epoch in range(CURRENT_EPOCH, EPOCHS+1):
    start = time.time()
    print('Start of epoch {}'.format(epoch))

    losses = []
    for step, data in enumerate(train_data):
        loss = train_step(data)
        losses.append(loss)

        # One progress dot every 100 steps.
        if not step % 100:
            print('.', end='')

        # Cap the number of steps taken per epoch.
        if step > 1000:
            break

    print('\n Epoch {} finished ~ ~ ~ ~ ~'.format(epoch))

    # Average loss of the epoch, shared by the summary and the printout.
    mean_loss = np.mean(losses)
    with summary_Writer.as_default():
        tf.summary.scalar('loss', mean_loss, step=epoch)

    if epoch % SAVE_EVERY_N_EPOCH == 0:
        # Clear the accumulated cell output before reporting the checkpoint.
        clear_output(wait=True)

        ckpt_save_path = ckpt_manager.save()
        print ('Saving checkpoint for epoch {} at {}'.format(epoch, ckpt_save_path))

    print ('epoch {} loss is {} \n'.format(epoch, mean_loss))
    print ('Time taken for epoch {} is {} sec\n'.format(epoch, time.time()-start))
    generate_save_img(epoch)
