In [None]:
# %env CUDA_DEVICE_ORDER=PCI_BUS_ID
# %env CUDA_VISIBLE_DEVICES=0,1

In [None]:
# from tensorflow.python.client import device_lib

# def get_available_gpus():
#     local_device_protos = device_lib.list_local_devices()
#     return [x.name for x in local_device_protos if x.device_type == 'GPU']

# print(get_available_gpus())

In [None]:
## The following two cells are the only ones you need to modify compared to the
## "StyleGAN linear walk - Color in W latent space" notebook (in the same directory).

In [None]:
!pip install typeguard
!git clone https://github.com/NVlabs/stylegan2.git
%cd stylegan2
!git clone https://github.com/kylemcdonald/python-utils.git utils

In [None]:
import tensorflow as tf
import os
import pickle
import numpy as np
import PIL.Image
import dnnlib
import dnnlib.tflib as tflib
# import config
import pretrained_networks

from utils.imutil import imshow, imresize
from utils.mosaic import make_mosaic

import cv2  
import time
tf_lpips_pkg = __import__("lpips-tensorflow.lpips_tf") 

tflib.init_tf()

# Load pre-trained network.
# NOTE(review): load_networks presumably unpickles the downloaded file, which can
# execute arbitrary code -- only point `url` at a trusted source.
url = 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/stylegan2-horse-config-f.pkl' # 'gdrive:networks/stylegan2-ffhq-config-f.pkl'
_G, _D, Gs = pretrained_networks.load_networks(url)
# Output side length read from the generator (its output is NCHW and square), so
# switching `url` to another model (horse/cat: 256, car: 512, face: 1024) needs
# no second manual edit here.
img_size = Gs.output_shape[-1]

In [None]:
Nsliders = 3
dim_z = Gs.input_shape[1]
# Image channel count (3 for RGB). The generator output is NCHW, so this is
# output_shape[1]. The target/mask placeholders must use this, not Nsliders;
# the two only coincided because both happen to be 3.
n_channels = Gs.output_shape[1]

z = tf.placeholder(tf.float32, shape=(None, dim_z))
# Shifted target images (NHWC), built in numpy by get_target_np.
target = tf.placeholder(tf.float32, shape=(None, img_size, img_size, n_channels))

# 0/1 validity mask for `target` (pixels shifted in from outside the frame are 0).
mask = tf.placeholder(tf.float32, shape=(None, img_size, img_size, n_channels))

# Walk step size (alpha / alpha_max in train()).
alpha2 = tf.placeholder(tf.float32, shape=None)

# Learned walk direction in W+ space: one offset per synthesis-layer dlatent.
# The shape comes from the synthesis network ([None, num_ws, dlatent_size]; 14 ws
# for the 256px model) instead of a hard-coded 14, which only fits 256px generators.
num_ws, dlatent_size = Gs.components.synthesis.input_shape[1:]
w2 = tf.Variable(np.random.normal(0.0, 0.1, [1, num_ws, dlatent_size]), name='walk_intermed', dtype=np.float32)

# Unedited images, NCHW -> NHWC.
# NOTE(review): with randomize_noise=True here and in transformed_output, the
# target (built from this output) and the edited output presumably see different
# per-pixel noise realizations -- confirm this is intended.
outputs_orig = tf.transpose(Gs.get_output_for(z, None, is_validation=True,
                                                    randomize_noise=True), [0, 2, 3, 1])

# z -> W+ dlatents, then a step of size alpha2 along the walk direction.
out_dlatents = Gs.components.mapping.get_output_for(z, None)

out_dlatents_new = out_dlatents + alpha2 * w2

# Edited images, NCHW -> NHWC.
transformed_output = tf.transpose(Gs.components.synthesis.get_output_for(out_dlatents_new, is_validation=True,
                                                    randomize_noise=True), [0, 2, 3, 1])

## L_2 loss (if you want to try); computed in the generator's [-1, 1] range.
loss = tf.losses.compute_weighted_loss(tf.square(transformed_output - target), weights=mask)

# LPIPS loss. lpips_tf expects images in [0, 1], while StyleGAN renders (and
# get_target_np shifts) images in [-1, 1], so rescale both sides first.
# Masked-out pixels stay 0 in both inputs, as before.
def _to_unit_range(images):
    """Map images from the generator's [-1, 1] range to [0, 1]."""
    return (images + 1.0) / 2.0

loss_lpips = tf.reduce_mean(tf_lpips_pkg.lpips_tf.lpips(mask * _to_unit_range(transformed_output),
                                                        mask * _to_unit_range(target),
                                                        model='net-lin', net='alex'))

In [None]:
def initialize_uninitialized(sess):
    """Initialize only the global variables that are not yet initialized in `sess`.

    Already-initialized variables (e.g. the pre-trained network weights) are left
    untouched.

    Args:
        sess: the tf.Session used to query and run the initializers.

    Returns:
        The list of variables initialized by this call. It is empty (never None)
        when everything was already initialized, so a caller that forwards it as
        an optimizer `var_list` cannot silently fall back to "all trainable
        variables".
    """
    global_vars = tf.global_variables()
    # Fetch every "is initialized" flag in a single session run.
    init_flags = sess.run([tf.is_variable_initialized(var) for var in global_vars])
    not_initialized_vars = [var for var, is_init in zip(global_vars, init_flags) if not is_init]

    print([str(var.name) for var in not_initialized_vars])  # only for testing
    if not_initialized_vars:
        sess.run(tf.variables_initializer(not_initialized_vars))
    return not_initialized_vars

In [None]:
style_sess = tf.get_default_session()
# First call initializes the freshly created walk variable (w2) without touching
# the pre-trained weights.
not_initialized_vars = initialize_uninitialized(style_sess)

lr = 1e-04
# Optimize ONLY the walk direction. Deriving var_list from "whatever was
# uninitialized" is fragile: on a re-run of this cell nothing is uninitialized,
# and an empty/None var_list means every trainable variable (the pre-trained
# networks) would be optimized, or minimize() fails outright.
train_step = tf.train.AdamOptimizer(lr).minimize(loss_lpips, var_list=[w2],
                                                 name='AdamOpter')
# this time init Adam's vars (slots/accumulators created by minimize):
not_initialized_vars = initialize_uninitialized(style_sess)

In [None]:
def get_target_np(outputs_zs, alpha, show_img=False, show_mask=False):
    """Build horizontally shifted targets and validity masks for one walk step.

    Every image in the batch is translated along x by `alpha` pixels with
    cv2.warpAffine. Pixels that enter the frame from outside are 0 in the target
    and 0 in the mask, so a masked loss ignores them.

    Args:
        outputs_zs: image batch of shape (batch, height, width, channels), e.g.
            the NHWC output of `outputs_orig`.
        alpha: shift in pixels (may be negative); 0 means "no edit".
        show_img: if True, display each target (uses the global `style_sess`,
            `tflib` and `imshow`).
        show_mask: if True, display each mask.

    Returns:
        (target_fn, mask_out), both with the shape of `outputs_zs`. mask_out is a
        float64 array containing only 0. and 1. For alpha == 0, target_fn is
        `outputs_zs` itself and the mask is all ones; otherwise target_fn is a new
        float64 array.
    """
    mask_out = np.ones(outputs_zs.shape)

    if alpha == 0:
        return outputs_zs, mask_out

    img_shape = outputs_zs.shape[1:]
    height, width = img_shape[:2]
    # cv2 takes dsize as (width, height). Taking it from the batch (instead of the
    # global img_size) keeps this usable at any resolution.
    dsize = (width, height)
    M = np.float32([[1, 0, alpha], [0, 1, 0]])

    target_fn = np.zeros(outputs_zs.shape)
    for i in range(outputs_zs.shape[0]):
        # reshape restores the channel axis that cv2 drops for 1-channel images.
        target_fn[i] = cv2.warpAffine(outputs_zs[i], M, dsize).reshape(img_shape)

    # The shift is the same for every image, so the mask is too: warp a single
    # all-ones image once and broadcast it over the batch.
    mask_out[:] = cv2.warpAffine(np.ones(img_shape), M, dsize).reshape(img_shape)
    # Binarize: any pixel that received image content counts as valid.
    mask_out[np.nonzero(mask_out)] = 1.
    assert np.setdiff1d(mask_out, [0., 1.]).size == 0

    if show_img:
        img_show_size = 128
        print('Target image:')
        # Display only: rescale with tflib.convert_images_to_uint8's default range.
        im = style_sess.run(tflib.convert_images_to_uint8(tf.convert_to_tensor(target_fn, dtype=tf.float32)))

        for b in range(outputs_zs.shape[0]):
            images_resized = np.array(PIL.Image.fromarray(im[b, :, :, :]).resize((img_show_size, img_show_size)))
            imshow(images_resized)

    if show_mask:
        img_show_size = 128
        print('Target mask:')
        for b in range(outputs_zs.shape[0]):
            mask_tmp = (mask_out[b, :, :, :] * 255).astype(np.uint8)
            images_resized = np.array(PIL.Image.fromarray(mask_tmp).resize((img_show_size, img_show_size)))
            imshow(images_resized)

    return target_fn, mask_out

In [None]:
import os
# Run directory, tagged with the learning rate. Checkpoints are written to
# <output_dir>/output and images to <output_dir>/images.
output_dir = '../shiftx_intermed_lpips_git_linear_horse_40k_{}'.format(lr)
for subdir in ('images', 'output'):
    os.makedirs(os.path.join(output_dir, subdir), exist_ok=True)

In [None]:
# Checkpoint only the learned walk direction. The pre-trained networks come from
# the pickle on every run, and tf.trainable_variables() would presumably also
# include all of _G/_D/Gs, bloating every checkpoint. A Saver over [w2] can still
# restore 'walk_intermed' from older, full checkpoints (lookup is by name).
saver = tf.train.Saver(var_list=[w2])
# saver.restore(tf.get_default_session(), "./shiftx_intermed_lpips_git_linear_horse_40k_0.0001/model_5000_final.ckpt")

In [None]:
# This can be train.py

import logging
import sys

# logging.basicConfig() silently does nothing when the root logger already has
# handlers. That happens when this cell is re-run and may happen after importing
# TensorFlow (absl can install a root handler). Remove existing root handlers
# first, so the file/stdout handlers below really get installed and are not
# duplicated. basicConfig(force=True) would do the same, but needs Python >= 3.8.
root_logger = logging.getLogger()
for handler in list(root_logger.handlers):
    root_logger.removeHandler(handler)
    if isinstance(handler, logging.FileHandler):
        handler.close()  # release the previous run's log file

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(threadName)-12.12s] [%(levelname)-5.5s]  %(message)s",
    handlers=[
        logging.FileHandler("{0}/{1}.log".format(output_dir, 'train')),
        logging.StreamHandler(sys.stdout)
    ])
logger = logging.getLogger()

# Filled by train(): every step's alpha, and windowed loss averages.
alpha_list = []
loss_vals = []

# train
def train(saver):
    """Learn the walk direction w2 by regressing edited images onto shifted ones.

    Makes `n_epoch` passes over `num_samples` fixed random latents in
    mini-batches. Each step draws a shift alpha from +/-{5, 10, ..., alpha_max}
    pixels, builds the shifted target/mask with get_target_np, and runs one
    `train_step` with the walk scaled by alpha / alpha_max.

    Side effects (module globals): appends every step's alpha to `alpha_list` and
    the mean per-step loss of every `log_every` steps to `loss_vals`. It also logs
    each step via `logger` and writes checkpoints to <output_dir>/output.

    Args:
        saver: tf.train.Saver used for the periodic and final checkpoints. These
            are named by the number of samples seen.
    """
    num_samples = 40000
    random_seed = 1
    rnd = np.random.RandomState(random_seed)
    zs = rnd.randn(num_samples, dim_z)

    n_epoch = 1
    batch_size = 4
    alpha_max = 100
    alpha_range = np.arange(5, alpha_max + 5, 5)
    log_every = 100    # optimizer steps per loss_vals entry
    save_every = 2500  # optimizer steps between intermediate checkpoints
    total_steps = n_epoch * len(range(0, num_samples, batch_size))

    def _save_checkpoint(n_steps):
        # Checkpoint named after the number of samples seen so far.
        saver.save(style_sess, '{}/{}/model_{}.ckpt'.format(output_dir, 'output', n_steps * batch_size),
                   write_meta_graph=False, write_state=False)

    loss_sum = 0.0     # over the whole run
    window_sum = 0.0   # over the current logging window
    window_steps = 0
    optim_iter = 0     # completed optimizer steps

    for epoch in range(n_epoch):
        for batch_start in range(0, num_samples, batch_size):
            start_time = time.time()

            # Shift magnitude, then a fair coin for its sign. Drawn from the seeded
            # `rnd` (not the global numpy RNG), so runs are reproducible.
            alpha_val = rnd.choice(alpha_range)
            if rnd.uniform(0, 1) <= 0.5:
                alpha_val = -alpha_val

            s = slice(batch_start, min(num_samples, batch_start + batch_size))

            out_zs = style_sess.run(outputs_orig, {z: zs[s]})
            target_fn, mask_out = get_target_np(out_zs, alpha_val)

            feed_dict = {z: zs[s], alpha2: alpha_val / alpha_max, target: target_fn, mask: mask_out}
            curr_loss, _ = style_sess.run([loss_lpips, train_step], feed_dict=feed_dict)
            elapsed_time = time.time() - start_time

            optim_iter += 1
            loss_sum += curr_loss
            window_sum += curr_loss
            window_steps += 1

            logger.info('T, epc, bst, lss, a: {}, {}, {}, {}, {}'.format(elapsed_time, epoch, batch_start, curr_loss, alpha_val))
            alpha_list.append(alpha_val)

            # Checks use the number of *completed* steps, so each loss window holds
            # exactly log_every steps and checkpoint names match the samples seen.
            # The final checkpoint is written once, after the loop.
            if optim_iter % save_every == 0 and optim_iter < total_steps:
                _save_checkpoint(optim_iter)

            if window_steps == log_every:
                # curr_loss is already a batch mean (loss_lpips uses reduce_mean),
                # so average over steps only; dividing by batch_size as well would
                # under-report the loss by that factor.
                loss_vals.append(window_sum / window_steps)
                logger.info('Loss: {}'.format(loss_vals[-1]))
                window_sum = 0.0
                window_steps = 0

    if window_steps:
        # Partial final window.
        loss_vals.append(window_sum / window_steps)
    if optim_iter > 0:
        print('average loss with this metric: ', loss_sum / optim_iter)
    _save_checkpoint(optim_iter)
    
    

In [None]:
# Run the full training pass; checkpoints go to <output_dir>/output.
train(saver=saver)

In [None]:
# Display helper graph: feed a float NHWC batch in the StyleGAN output range
# ([-1, 1]) into `float_im` and fetch `uint8_im` to get uint8 [0, 255] pixels.
# A placeholder is already a float32 tensor, so it is passed in directly.
float_im = tf.placeholder(tf.float32, shape=outputs_orig.shape)
uint8_im = tflib.convert_images_to_uint8(float_im)

In [None]:
# define some showing utils
import io
import IPython.display

def imgrid(imarray, cols=5, pad=1):
    """Tile a batch of uint8 images into one grid image.

    Args:
        imarray: uint8 array of shape (N, H, W, C).
        cols: number of images per grid row (>= 1).
        pad: width in pixels of the white (255) separator between tiles (>= 0).

    Returns:
        uint8 array of shape (rows*(H+pad)-pad, cols*(W+pad)-pad, C), where
        rows = ceil(N / cols). Unused slots in the last row are white.

    Raises:
        ValueError: if imarray is not uint8.
    """
    if imarray.dtype != np.uint8:
        raise ValueError('imgrid input imarray must be uint8')
    pad = int(pad)
    assert pad >= 0
    cols = int(cols)
    assert cols >= 1

    n_imgs, tile_h, tile_w, n_chan = imarray.shape
    n_rows = -(-n_imgs // cols)  # ceiling division
    missing = n_rows * cols - n_imgs
    assert missing >= 0

    # Append blank tiles to fill the last row, and pad every tile on its
    # bottom/right edge with the separator.
    padded = np.pad(imarray,
                    ((0, missing), (0, pad), (0, pad), (0, 0)),
                    'constant', constant_values=255)
    tile_h += pad
    tile_w += pad

    # (rows, cols, h, w, c) -> (rows, h, cols, w, c) -> flat 2-D mosaic.
    mosaic = padded.reshape(n_rows, cols, tile_h, tile_w, n_chan).swapaxes(1, 2)
    mosaic = mosaic.reshape(n_rows * tile_h, cols * tile_w, n_chan)

    # Drop the trailing separator on the bottom/right border.
    return mosaic[:-pad, :-pad] if pad else mosaic

def imshow(a, format='png', jpeg_fallback=True, filename=None):
    """Display an image array inline and optionally save a half-size thumbnail.

    NOTE: this redefines the `imshow` imported from utils.imutil. From this cell
    on, every caller (including get_target_np's show_* paths) uses this version.

    Args:
        a: image array, converted with np.asarray(a, dtype=np.uint8).
        format: PIL format used to encode the image for display; also the
            extension of the saved thumbnail.
        jpeg_fallback: if displaying in `format` raises IOError, retry as JPEG.
        filename: if given, save a thumbnail at half the width/height to
            '<filename>.<format>'.

    Returns:
        (disp, img): the return value of IPython.display.display and a
        full-size PIL image of `a`.
    """
    a = np.asarray(a, dtype=np.uint8)
    img = PIL.Image.fromarray(a)
    str_file = io.BytesIO()
    img.save(str_file, format)
    im_data = str_file.getvalue()
    try:
        disp = IPython.display.display(IPython.display.Image(im_data))
    except IOError:
        if jpeg_fallback and format != 'jpeg':
            print('Warning: image was too large to display in '
                  'format "{}"; trying jpeg instead.'.format(format))
            # Forward `filename` so the thumbnail is still written on fallback.
            return imshow(a, format='jpeg', filename=filename)
        raise

    # Saving happens outside the try, so a save error is not mistaken for a
    # display error (which would trigger the JPEG re-display).
    if filename:
        size = (a.shape[1] // 2, a.shape[0] // 2)
        thumb = img.copy()  # keep the returned `img` full-size
        # LANCZOS is the filter formerly aliased as ANTIALIAS (removed in Pillow 10).
        thumb.thumbnail(size, PIL.Image.LANCZOS)
        thumb.save('{}.{}'.format(filename, format))
    return disp, img

In [None]:
# test
# Qualitative check on held-out latents. For each latent, render the edit at
# every alpha in `a` and show a grid: the first row holds the shifted targets,
# the second row the model's outputs. `ims` is left holding the last latent's
# outputs (used by the video cell).

num_samples = 6
batch_size = 1
alpha_max = 100
a = np.arange(-alpha_max, alpha_max + 20, 20)  # -100, -80, ..., 100

random_seed = 6
rnd = np.random.RandomState(random_seed)
zs = rnd.randn(num_samples, dim_z)

for batch_num, batch_start in enumerate(range(0, num_samples, batch_size)):
    s = slice(batch_start, min(num_samples, batch_start + batch_size))
    input_test = {z: zs[s]}
    out_input_test = style_sess.run(outputs_orig, input_test)

    ims, targets = [], []
    for alpha_val in a:
        target_fn, _ = get_target_np(out_input_test, alpha_val)
        best_inputs = {z: zs[s], alpha2: alpha_val / alpha_max}
        best_im_out = style_sess.run(transformed_output, best_inputs)

        # [-1, 1] floats -> uint8 for display.
        best_im_out = style_sess.run(uint8_im, {float_im: best_im_out})
        target_fn = style_sess.run(uint8_im, {float_im: target_fn})

        ims.append(best_im_out)
        targets.append(target_fn)

    im_stack = np.concatenate(targets + ims).astype(np.uint8)
    imshow(imgrid(im_stack, cols=len(a)))

In [None]:
import matplotlib.pyplot as plt
# train() appends one loss_vals entry per 100 optimizer steps of batch size 4,
# i.e. per 400 samples. Build the x axis from the actual number of entries, so it
# always matches loss_vals. The training batch size is spelled out here because
# the test cell rebinds the global `batch_size` to 1.
samples_per_point = 100 * 4
loss_vals_x = samples_per_point * np.arange(1, len(loss_vals) + 1)
fig, ax = plt.subplots()
ax.plot(loss_vals_x, loss_vals)
ax.set_xlabel('num samples, lr{}'.format(lr))
ax.set_ylabel('Lpips')
plt.show()

In [None]:
## You can make a video
import cv2
# Frames: the uint8 edits of the last test latent, one per alpha. With the test
# cell's batch_size = 1 this has shape (n_alphas, 1, H, W, 3).
ims_slerp = np.asarray(ims)
img_show_size = 256
fps = 10
video_name = 'stylegan2_horse_shiftx.mp4'

height = width = img_show_size
video = cv2.VideoWriter(video_name, cv2.VideoWriter_fourcc(*'MP4V'), fps=fps, frameSize=(width, height))
if not video.isOpened():
    raise RuntimeError('Could not open video writer for {}'.format(video_name))
try:
    for frame_idx in range(ims_slerp.shape[0]):
        frame_rgb = np.squeeze(ims_slerp[frame_idx]).astype(np.uint8)
        frame_rgb = cv2.resize(frame_rgb, (width, height))
        # OpenCV writes BGR frames; the generator output is RGB. Converting
        # directly (no imshow) also avoids clobbering the global alpha array `a`
        # and displaying every frame inline.
        video.write(cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR))
finally:
    video.release()