# StyleGAN linear walk - Color in W latent space

## 1. Setup

In [None]:
# Select the GPU through environment variables before TensorFlow is imported
# further down: devices are ordered by PCI bus id and only device 0 is visible.
%env CUDA_DEVICE_ORDER=PCI_BUS_ID
%env CUDA_VISIBLE_DEVICES=0

In [None]:
# Move up one directory (presumably to the repository root): later cells use
# relative paths such as 'resources/stylegan' and 'notebooks/models/...'.
%cd ..

Pick a pretrained model to use here, and set the output directory for the trained weights.

In [None]:
# Exactly one (output_dir, model_url) pair should be active at a time; to switch
# models, comment out the active pair and uncomment another one.
#   output_dir: 'images/' and 'output/' sub-directories are created under it;
#               checkpoints go to 'output/' and the training log to 'train.log'.
#   model_url:  location of the pickled pretrained networks loaded below.

# cars
output_dir = 'notebooks/models/stylegan_W_color_car'
model_url = 'https://drive.google.com/uc?id=1MJ6iCfNtMIRicihwRorsM3b7mmtmK9c3'

# # ffhq faces
# output_dir = 'notebooks/models/stylegan_W_color_face'
# model_url = 'https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ'

# # cats
# output_dir = 'notebooks/models/stylegan_W_color_cat'
# model_url = 'https://drive.google.com/uc?id=1MQywl0FNt6lHu8E_EUqnRbviagS7fbiJ'

Pick the learning rate and the number of training samples.

In [None]:
# learning rate handed to the Adam optimizer
lr = 0.001
# number of latent vectors z drawn (once, with a fixed seed) for training
num_samples = 2000

## 2. Create Graph and initialize

In [None]:
# create the 'images' and 'output' sub-directories of output_dir
# (parents are created as needed; existing directories are left alone)
import os
for _subdir in ('images', 'output'):
    os.makedirs(os.path.join(output_dir, _subdir), exist_ok=True)

In [None]:
import tensorflow as tf
import os
import pickle
import numpy as np
import PIL.Image
import time
from resources import tf_lpips_pkg as lpips_tf

# this is mostly to solve pickle issues
import sys
sys.path.append('resources/stylegan')
import dnnlib
import dnnlib.tflib as tflib
import config


# Set up TensorFlow through the StyleGAN helper library before loading any
# network (presumably this also creates the default session that is fetched
# later with tf.get_default_session() -- confirm in dnnlib.tflib).
tflib.init_tf()

# Fetch the pretrained networks from model_url (cached under config.cache_dir)
# and unpickle them.
# NOTE(review): pickle.load can execute arbitrary code contained in the file it
# reads; only point model_url at a source you trust.
with dnnlib.util.open_url(model_url, cache_dir=config.cache_dir) as f:
    _G, _D, Gs = pickle.load(f)
    # _G = Instantaneous snapshot of the generator. Mainly useful for resuming a previous training run.
    # _D = Instantaneous snapshot of the discriminator. Mainly useful for resuming a previous training run.
    # Gs = Long-term average of the generator. Yields higher-quality results than the instantaneous snapshot.

In [None]:
Nsliders = 3 # we use 3 slider dimensions for RGB color

# size of the input latent z, read from the generator's declared input shape
dim_z = Gs.input_shape[1]

# get original generated output
# (second argument None: presumably "no labels" -- confirm in dnnlib)
# The transpose moves axis 1 to the end, so axis 3 is treated as the channel
# axis and axes 1/2 as the spatial axes from here on.
z = tf.placeholder(tf.float32, shape=(None, dim_z))
outputs_orig = tf.transpose(Gs.get_output_for(z, None, is_validation=True, 
                                              randomize_noise=True), [0, 2, 3, 1])

img_size = outputs_orig.shape[1]
Nchannels = outputs_orig.shape[3]

# set target placeholders
# Both spatial dimensions use img_size, i.e. square images are assumed.
# target: image the walked output is regressed onto; mask: per-pixel loss weights.
target = tf.placeholder(tf.float32, shape=(None, img_size, img_size, Nchannels))
mask = tf.placeholder(tf.float32, shape=(None, img_size, img_size, Nchannels))

# forward to W latent space
out_dlatents = Gs.components.mapping.get_output_for(z, None) # rank-3 output; original note gives shape [?, 16, 512]

# set slider and learnable walk vector
# alpha: one scalar step per slider and per sample, fed at run time.
# w: the learned walk, one direction of shape (latent_dim[1], latent_dim[2]) per
#    slider, drawn from N(0, 0.1) at creation; the name 'walk_intermed' is what
#    the 'walk' scope filter of the checkpoint saver matches.
latent_dim = out_dlatents.shape
alpha = tf.placeholder(tf.float32, shape=(None, Nsliders))
w = tf.Variable(np.random.normal(0.0, 0.1, [1, latent_dim[1], latent_dim[2], Nsliders]), name='walk_intermed', dtype=np.float32)

# apply walk: for every slider i add alpha[:, i] * w[..., i] to the W latents.
# alpha[:, i] is reshaped to (batch, 1, 1) so that it broadcasts against the
# (1, num_layers, w_dim) slice of w. No sizes are hard-coded: the layer count
# and latent width follow whatever the loaded mapping network produces, so
# pretrained models whose W latents are not [?, 16, dim_z] also work.
out_dlatents_new = out_dlatents
for i in range(Nsliders):
    out_dlatents_new = out_dlatents_new + tf.reshape(alpha[:, i], (-1, 1, 1)) * w[:, :, :, i]

# get output after applying walk
# Same axis transpose as for outputs_orig, so both tensors are directly comparable.
transformed_output = tf.transpose(Gs.components.synthesis.get_output_for(
    out_dlatents_new, is_validation=True, randomize_noise=True), [0, 2, 3, 1])

# L_2 loss
# Squared error between the walked output and the target, weighted element-wise by mask.
loss = tf.losses.compute_weighted_loss(tf.square(transformed_output-target), weights=mask)

# Lpips loss
# Alternative objective: LPIPS distance between the masked output and the masked
# target, averaged over the batch. It is only built here; the optimizer uses
# `loss` unless it is swapped in where train_step is created.
loss_lpips = tf.reduce_mean(lpips_tf.lpips(mask*transformed_output, mask*target, 
                                                  model='net-lin', net='alex'))

In [None]:
# ops to rescale the stylegan output range ([-1, 1]) to uint8 range [0, 255]
# float_im has the same static shape as outputs_orig; feed it a float image
# batch and fetch uint8_im to get the displayable version.
float_im = tf.placeholder(tf.float32, outputs_orig.shape)
uint8_im = tflib.convert_images_to_uint8(tf.convert_to_tensor(float_im, dtype=tf.float32))

In [None]:
def initialize_uninitialized(sess):
    """Initialize every global variable that has not been initialized yet.

    The names of the affected variables are printed.

    Args:
        sess: TensorFlow session used to query the initialization state and to
            run the initializers.

    Returns:
        list: the variables initialized by this call. The list is empty when
        everything was already initialized. (Previously this case returned
        None implicitly; passed on as `var_list` to `Optimizer.minimize`, None
        selects *all* trainable variables instead of none.)
    """
    global_vars = tf.global_variables()
    # one boolean per variable: True when the variable already holds a value
    is_initialized = sess.run([tf.is_variable_initialized(var) for var in global_vars])
    not_initialized_vars = [v for v, ready in zip(global_vars, is_initialized) if not ready]

    print([str(v.name) for v in not_initialized_vars])
    if not_initialized_vars:
        sess.run(tf.variables_initializer(not_initialized_vars))
    return not_initialized_vars

In [None]:
# Reuse the session that is already installed as default.
sess = tf.get_default_session()
# Initialize whatever is still uninitialized and remember those variables.
not_initialized_vars = initialize_uninitialized(sess)

# change to loss_lpips to optimize using lpips loss instead
# Only the variables that were uninitialized just above are optimized --
# presumably just the walk variable `w`, assuming the pretrained networks were
# already initialized when loaded; check the names printed by the call above.
train_step = tf.train.AdamOptimizer(lr).minimize(loss, var_list=not_initialized_vars, 
                                                 name='AdamOpter')

# this time init Adam's vars:
# (creating the optimizer added new variables that are not initialized yet)
not_initialized_vars = initialize_uninitialized(sess)

## 3. Define Target Operation

In [None]:
def get_target_np(outputs_zs, alpha):
    """Build the color-shifted regression target for a batch of images.

    Args:
        outputs_zs: numpy array of images indexed as
            (batch, row, column, channel).
        alpha: numpy array indexed as (batch, slider); alpha[b, i] is the
            offset added to channel i of image b. Only the first three
            columns / channels are used.

    Returns:
        tuple: (target, mask). `target` is a copy of `outputs_zs` with the
        offsets added to its first three channels; `mask` is an all-ones array
        with the shape of `outputs_zs`. When `alpha` is entirely zero,
        `outputs_zs` itself (not a copy) is returned as the target.

    Raises:
        ValueError: if `alpha` is non-zero and its batch size differs from that
            of `outputs_zs`, or it has fewer than three columns.
    """
    n_color = 3  # number of leading channels that receive an offset

    if not np.any(alpha):  # alpha is all zeros: the target is the image itself
        return outputs_zs, np.ones(outputs_zs.shape)

    # Explicit exceptions instead of `assert`, which is stripped under -O.
    if outputs_zs.shape[0] != alpha.shape[0]:
        raise ValueError(
            'batch size mismatch: outputs_zs has {} images but alpha has {} rows'.format(
                outputs_zs.shape[0], alpha.shape[0]))
    if alpha.shape[1] < n_color:
        raise ValueError(
            'alpha needs at least {} columns, got {}'.format(n_color, alpha.shape[1]))

    target_fn = np.copy(outputs_zs)
    # Insert two length-1 axes so every (image, channel) offset broadcasts over
    # the spatial axes; this replaces the explicit loops over batch and channel.
    offsets = alpha[:, np.newaxis, np.newaxis, :n_color]
    target_fn[..., :n_color] = target_fn[..., :n_color] + offsets

    mask_out = np.ones(outputs_zs.shape)
    return target_fn, mask_out

## 4. Train walk

In [None]:
# Checkpoint only the trainable variables selected by the 'walk' scope filter
# (the learned walk variable is named 'walk_intermed'), not the pretrained networks.
saver = tf.train.Saver(tf.trainable_variables(scope='walk'))

In [None]:
import logging
import sys
# Send INFO-level records both to <output_dir>/train.log and to stdout.
_log_handlers = [
    logging.FileHandler("{0}/{1}.log".format(output_dir, 'train')),
    logging.StreamHandler(sys.stdout),
]
logging.basicConfig(
    format="%(asctime)s [%(threadName)-12.12s] [%(levelname)-5.5s]  %(message)s",
    level=logging.INFO,
    handlers=_log_handlers,
)
logger = logging.getLogger()

# windowed training losses; appended to by train() and plotted at the end
loss_vals = []

# train
def train(saver, n_epoch=1, batch_size=4, random_seed=0):
    """Optimize the walk on a fixed set of `num_samples` random latent vectors.

    For every batch the unmodified generator output is computed, shifted by a
    random per-channel offset in [-0.5, 0.5) to form the target, and one
    optimizer step is run so that walking the latents by the same offsets
    reproduces that target.

    Reads the module-level graph handles (`sess`, `z`, `alpha`, `target`,
    `mask`, `outputs_orig`, `loss`, `train_step`) and settings (`num_samples`,
    `dim_z`, `Nsliders`, `output_dir`). Appends windowed losses to the
    module-level `loss_vals` and logs one line per iteration via `logger`.

    Args:
        saver: `tf.train.Saver` used to write checkpoints to
            `<output_dir>/output/model_<samples seen>.ckpt`.
        n_epoch: number of passes over the sampled latents.
        batch_size: number of latents per optimizer step.
        random_seed: seed for sampling the latents. The per-batch offsets
            come from the global numpy RNG and are not seeded here.
    """
    rnd = np.random.RandomState(random_seed)
    zs = rnd.randn(num_samples, dim_z)

    log_every = 100     # iterations between entries appended to loss_vals
    save_every = 2500   # iterations between intermediate checkpoints

    Loss_sum = 0
    Loss_sum_iter = 0
    optim_iter = 0
    for epoch in range(n_epoch):
        for batch_start in range(0, num_samples, batch_size):
            start_time = time.time()

            # Slicing clamps at the end of zs, so the last batch may be shorter.
            z_batch = zs[batch_start:batch_start + batch_size]

            # Size alpha by the rows actually in this batch; using batch_size
            # here breaks the final batch whenever num_samples is not a
            # multiple of batch_size.
            alpha_val = np.random.random(size=(z_batch.shape[0], Nsliders)) - 0.5

            out_zs = sess.run(outputs_orig, {z: z_batch})
            target_fn, mask_out = get_target_np(out_zs, alpha_val)

            feed_dict = {z: z_batch, alpha: alpha_val, target: target_fn, mask: mask_out}
            curr_loss, _ = sess.run([loss, train_step], feed_dict=feed_dict)
            Loss_sum = Loss_sum + curr_loss
            Loss_sum_iter = Loss_sum_iter + curr_loss

            elapsed_time = time.time() - start_time

            logger.info('T, epc, bst, lss, a: {}, {}, {}, {}, {}'.format(elapsed_time, epoch, batch_start, curr_loss, alpha_val))

            if (optim_iter % save_every == 0) and (optim_iter > 0):
                saver.save(sess, '{}/{}/model_{}.ckpt'.format(output_dir, 'output', optim_iter*batch_size), write_meta_graph=False, write_state=False)

            # NOTE(review): the divisor assumes exactly `log_every` iterations
            # per window, but the first window holds log_every + 1 and the
            # end-of-epoch remainder below may hold fewer. Kept unchanged so
            # reported values stay comparable with earlier runs.
            if (optim_iter % log_every == 0) and (optim_iter > 0):
                loss_vals.append(Loss_sum_iter/(log_every*batch_size))
                Loss_sum_iter = 0
                print('Loss:', loss_vals)

            optim_iter = optim_iter+1

        if optim_iter > 0:
            loss_vals.append(Loss_sum_iter/(log_every*batch_size))
            # start the next epoch's window from zero so nothing is counted twice
            Loss_sum_iter = 0
            print('average loss with this metric: ', Loss_sum/(optim_iter*batch_size))
    # final checkpoint, named after the total number of samples processed
    saver.save(sess, '{}/{}/model_{}.ckpt'.format(output_dir, 'output', optim_iter*batch_size), write_meta_graph=False, write_state=False)
    

In [None]:
train(saver)

## 5. Visualizations

In [None]:
from utils.image import imgrid, imshow

In [None]:
# show learned samples
# For each test latent, one grid is displayed: the color-shifted targets for
# increasing step size followed by the generator outputs after walking by the
# same step (with cols=len(a) this presumably renders as a target row above an
# output row -- depends on imgrid).

num_samples_vis = 6
batch_size = 1
a = np.linspace(0, 1, 6)

random_seed = 0
rnd = np.random.RandomState(random_seed)
zs = rnd.randn(num_samples_vis, dim_z)

for batch_start in range(0, num_samples_vis, batch_size):

    ims = []
    targets = []

    # clamp against the number of visualization samples, not the training num_samples
    s = slice(batch_start, min(num_samples_vis, batch_start + batch_size))
    z_batch = zs[s]

    out_input_test = sess.run(outputs_orig, {z: z_batch})

    for step in a:
        # channel 1 is shifted by +step, the other sliders by -step
        # (channels presumably in R, G, B order, i.e. towards green)
        alpha_val = np.full((z_batch.shape[0], Nsliders), -step)
        alpha_val[:, 1] = step
        target_fn, _ = get_target_np(out_input_test, alpha_val)
        im_out = sess.run(transformed_output, {z: z_batch, alpha: alpha_val})

        # rescale both to uint8 for display
        im_out = sess.run(uint8_im, {float_im: im_out})
        target_fn = sess.run(uint8_im, {float_im: target_fn})

        ims.append(im_out)
        targets.append(target_fn)
    im_stack = np.concatenate(targets + ims).astype(np.uint8)
    imshow(imgrid(im_stack, cols=len(a)))

In [None]:
# plot losses 
import matplotlib.pyplot as plt
%matplotlib inline
# One point per entry of loss_vals, in the order the entries were appended.
# NOTE(review): the x axis is therefore the index into loss_vals (one entry per
# logging window), not a raw sample count as the label suggests.
plt.plot(loss_vals)
plt.xlabel('num samples, lr{}'.format(lr))
plt.ylabel('Loss')
plt.show()