[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/HSinger04/VOGUE-Reimplementation/blob/main/cryu854/TryOn.ipynb)

In [None]:
# Clone the VOGUE reimplementation repository into /content (IPython magics, Colab only).
%cd /content
!git clone https://github.com/HSinger04/VOGUE-Reimplementation

/content
Cloning into 'VOGUE-Reimplementation'...
remote: Enumerating objects: 191, done.[K
remote: Counting objects: 100% (191/191), done.[K
remote: Compressing objects: 100% (190/190), done.[K
remote: Total 494 (delta 121), reused 0 (delta 0), pack-reused 303[K
Receiving objects: 100% (494/494), 84.58 MiB | 31.05 MiB/s, done.
Resolving deltas: 100% (287/287), done.


In [None]:
# Work from the cryu854 subproject; presumably required so that the later
# `from modules.generator import generator` resolves — verify if the repo layout changes.
%cd /content/VOGUE-Reimplementation/cryu854/

/content/VOGUE-Reimplementation/cryu854


## Mount drive for dataset and weights

In [None]:
# Mount Google Drive: the FFHQ images, the StyleGAN2 checkpoint and the TensorBoard
# logs used further down are all read from / written to paths under /content/drive.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

## Imports

In [None]:
import os
import numpy as np

import tensorflow as tf
from tensorflow.keras.layers import \
BatchNormalization, ELU, Dense
from tensorflow.keras import Model

## Global setting whether to use tf.function or not

In [None]:
# When True, get_img, perc_loss and encoder_train_step are wrapped in tf.function
# (graph mode, better performance). Set to False to run them eagerly for debugging.
TF_FUNCTION = True

## Load data

In [None]:
# Build the dataset lazily from PNG file paths. list_files expands the glob across
# every FFHQ subdirectory whose name is two digits followed by "000".
# Pro: saves space in Google Drive.
# Contra: the images have to be read from Drive every time the dataset is iterated.
ffhq_pattern = ("/content/drive/MyDrive/Lernen/Coxi/IANNwTF/ffhq-dataset/images1024x1024/"
                + "[0-9][0-9]"
                + "000/*.png")
data = tf.data.Dataset.list_files(ffhq_pattern)

## Data pipeline

In [None]:
# --- Data pipeline constants ---
BATCH_SIZE = 4
# Let tf.data choose the prefetch buffer size dynamically.
PREFETCH_SIZE = tf.data.experimental.AUTOTUNE
# Keep the shuffle buffer at roughly 100 elements or fewer; larger buffers exhaust
# Google Colab's RAM.
shuffle_size = 100

In [None]:
def decode_ffhq(image_path):
    """Read the PNG file at `image_path` and decode it into a 3-channel uint8 tensor."""
    return tf.image.decode_png(tf.io.read_file(image_path), channels=3)


# Turn the dataset of file paths into a dataset of decoded images.
data = data.map(decode_ffhq)

In [None]:
# TODO: Show some images

In [None]:
# Shuffle the decoded uint8 images *before* converting them to float32. The shuffle
# buffer then stores 1 byte per channel value instead of 4, which matters for
# 1024x1024 images on Colab. shuffle() defaults to reshuffle_each_iteration=True, so
# every epoch already sees a fresh order without re-wrapping the dataset.
data = data.shuffle(buffer_size=shuffle_size)
# drop_remainder keeps the batch dimension static and equal to BATCH_SIZE. This matches
# the fixed-size `labels` tensor fed to the generator and avoids tf.function retracing
# on a smaller final batch.
data = data.batch(BATCH_SIZE, drop_remainder=True)
# Normalize to [-1, 1] to match the range of StyleGAN2's generator output.
# This is done per batch, so the op is vectorized.
data = data.map(lambda x: tf.cast(x, tf.float32) / 127.5 - 1.0)
data = data.prefetch(PREFETCH_SIZE)

# Load the trained StyleGAN2 generator

## Actually load the model

In [None]:
!git pull
import tensorflow as tf
from modules.generator import generator

resolution = 1024  
config = "f"
num_labels = 0
checkpoint_path = "/content/drive/MyDrive/Lernen/Coxi/IANNwTF/official_1024x1024/"

Gs = generator(resolution, num_labels, config, randomize_noise=False)
ckpt = tf.train.Checkpoint(generator_clone=Gs)
print(f'Loading network from {checkpoint_path}...')
ckpt.restore(tf.train.latest_checkpoint(checkpoint_path)).expect_partial()
# Freeze Generator since we don't want to train it
Gs.trainable = False

## Generate and show images

In [None]:
def get_img(x, truncation_psi, training=False):
    """Generate images with the StyleGAN2 generator `Gs`.

    Args:
        x: generator inputs, passed straight to `Gs`. The calls in this notebook
            pass `[latents, labels]`.
        truncation_psi: truncation factor forwarded to `Gs`.
        training: forwarded to `Gs`.

    Returns:
        The generated images with every value clipped to [-1, 1] (NOT 0-255).
        Map them with `(img + 1) * 127.5` before casting to uint8 for display.
    """
    images = Gs(x, truncation_psi=truncation_psi, training=training)
    # The raw generator output should be clipped to [-1, 1], which is also the range
    # the real training images are normalized to.
    return tf.clip_by_value(images, clip_value_min=-1.0, clip_value_max=1.0)


if TF_FUNCTION:
    get_img = tf.function(get_img)

In [None]:
# Sample a batch of random latent codes and render them with the generator.
truncation_psi = 0.5
latent_size = 512
latents = tf.random.normal(shape=(BATCH_SIZE, latent_size))

# No class conditioning: one empty label row per sample.
# TODO: from _get_labels
labels_indice = [0 for _ in range(BATCH_SIZE)]
labels = tf.zeros(shape=(BATCH_SIZE, 0), dtype=tf.float32)

images = get_img([latents, labels], truncation_psi)

In [None]:
from matplotlib import pyplot as plt

# Show every generated sample side by side. get_img returns values in [-1, 1], so map
# them to [0, 255] before casting to uint8.
fig, axes = plt.subplots(1, BATCH_SIZE, figsize=(4 * BATCH_SIZE, 4), squeeze=False)
for i, ax in enumerate(axes[0]):
    sample = (images[i] + 1) * 127.5
    ax.imshow(sample.numpy().astype(np.uint8))
    ax.set_title(f"Sample {i}")
    ax.axis("off")
fig.suptitle(f"StyleGAN2 samples (truncation_psi={truncation_psi})")
plt.show()

## Define code2code model

In [None]:
# TODO: Maybe leave out. 'w' itself might already be disentangled enough
# latent_size = 512

# class Code2Code(Model):
#     def __init__(hidden_dim, initializer=tf.keras.initializers.GlorotUniform):
#         self.layers = [Dense(hidden_dim), BatchNormalization(), ELU(),
#                        Dense(512), BatchNormalization(), ELU()]          
    
    
#   @tf.function 
#   def call(self, x, training=True):
#     for layer in self.layers:
#       x = layer(x, training = training)

#     return x




## Training preparation

## Load the EfficientNetB0 perceptual network and freeze it

In [None]:
# EfficientNetB0 with ImageNet weights is the backbone of the perceptual loss.
# include_top=False drops the classification head, which is never used: only
# intermediate *_project_conv activations are read (see the next cell).
# TODO: summarize this in external code instead of creating this cell in both notebooks
perc_base_net = tf.keras.applications.EfficientNetB0(include_top=False, weights="imagenet")
# Freeze perc_base_net since we don't want to train it
perc_base_net.trainable = False

In [None]:
# EfficientNetB0 layers whose activations are compared by the perceptual loss:
# one layer from each of the seven blockN stages, ordered shallow -> deep.
layer_names = ("block1a_project_conv", "block2b_project_conv", "block3b_project_conv",
               "block4c_project_conv", "block5c_project_conv", "block6d_project_conv",
               "block7a_project_conv")

used_layers = [perc_base_net.get_layer(name).output for name in layer_names]

# Multi-output feature extractor: perc_net(x) returns one activation tensor per entry
# of layer_names, in the same order.
perc_net = tf.keras.Model(inputs=perc_base_net.inputs, outputs=used_layers)

# perc_net shares its layers with perc_base_net, so deleting the name can only release
# layers that perc_net does not use. tf.keras.backend.clear_session() is intentionally
# not called: it resets global Keras state and cannot free anything perc_net still
# references.
del perc_base_net

NUM_LAYERS = len(layer_names)
assert len(perc_net.outputs) == NUM_LAYERS

## Define perceptual loss

In [None]:
def perc_loss(real, fake, loss_tracker):
    # TODO: Better description. Also maybe more detail? Point to equation
    """Returns perceptual loss according to VGG16 activations. See 
    """
    
    real = tf.image.resize(real, [224, 224])
    real = tf.keras.applications.resnet_v2.preprocess_input(real)

    fake = tf.image.resize(fake, [224, 224])
    fake = tf.keras.applications.resnet_v2.preprocess_input(fake)
    
    for i, layer in enumerate(vgg.layers):
            real = layer(real)
            fake = layer(fake)
            # TODO: say which corresponds to what
            if i == 2 or i == 5 or i == 9 or i == 13 or i == 17:
                # normalize in channel dimension
                layer_loss = tf.math.l2_normalize(real, axis=-1)
                layer_loss -= tf.math.l2_normalize(fake, axis=-1)

                # TODO: should be alright, since shape is right, but not fully confirmed
                layer_loss = tf.norm(layer_loss, axis=-1)
                layer_loss = tf.square(layer_loss)
                layer_loss = tf.reduce_mean(layer_loss)

                loss_tracker.update_state(layer_loss)

if TF_FUNCTION:
    perc_loss = tf.function(perc_loss)

In [None]:
def encoder_train_step(model, train_data, optimizer, global_loss_tracker, local_loss_tracker, train_writer):

    for inputs in train_data:
        
        # reset local loss
        local_loss_tracker.reset_states()    

        with tf.GradientTape() as tape:

            # perc
            fakes = model(inputs)
            fakes = get_img([fakes, labels], truncation_psi)
            perc_loss(inputs, fakes, local_loss_tracker)
            loss = local_loss_tracker.result()
            # average over the batch manually
            # TODO: remove
            #loss = tf.math.reduce_mean(loss)
            # TODO: remove
            print(model.__class_)
            gradients = tape.gradient(loss, model.trainable_variables)

        # update weights  
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))

        # record global loss
        global_loss_tracker.update_state(loss)

if TF_FUNCTION:
    encoder_train_step = tf.function(encoder_train_step)   

## Instantiate loss trackers

In [None]:
# Average of the per-step losses; the training loop resets it at the start of each epoch.
GLOBAL_LOSS_TRACKER = tf.keras.metrics.Mean(name="mean")
# Sum of one step's per-layer perceptual losses; reset at the start of every step.
LOCAL_LOSS_TRACKER = tf.keras.metrics.Sum(name="sum")

## Instantiate train writer

In [None]:
import datetime

# Every run logs to its own timestamped subdirectory of logs/encoder/.
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
LOG_DIR = os.path.join("/content/drive/MyDrive/Lernen/Coxi/IANNwTF/logs/encoder/", current_time)

TRAIN_WRITER = tf.summary.create_file_writer(LOG_DIR)

## Instantiate optimizer

In [None]:
# Adam with a constant learning rate of 1e-3.
LEARNING_RATE = 1e-3
OPTIMIZER = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)

## Instantiate encoder

In [None]:
# TODO

In [None]:
#for datum in data:
#    print(datum.shape)

In [None]:
# TODO: remove
#for datum in data:
#    ENCODER(datum)
#    break

In [None]:
## Instantiate q vectors

# Train

In [None]:
NUM_EPOCHS = 10000

# NOTE: ENCODER must be defined before this cell runs; in this notebook its
# "Instantiate encoder" cell is still a TODO.
for epoch in range(NUM_EPOCHS):
    print("Start epoch: " + str(epoch))

    # Epoch-level statistics start fresh every epoch.
    GLOBAL_LOSS_TRACKER.reset_states()

    # `data` is NOT re-shuffled here. Its shuffle() stage already reshuffles every
    # epoch (reshuffle_each_iteration defaults to True). Re-assigning
    # data = data.shuffle(...) each epoch would stack one more shuffle buffer per
    # epoch, growing memory and pipeline depth without bound.
    encoder_train_step(ENCODER, data, OPTIMIZER, GLOBAL_LOSS_TRACKER, LOCAL_LOSS_TRACKER, TRAIN_WRITER)

    # Write the average per-step loss of this epoch.
    with TRAIN_WRITER.as_default():
        tf.summary.scalar('loss', GLOBAL_LOSS_TRACKER.result(), step=epoch)

Start epoch: 0
