In [1]:
import numpy as np
import tensorflow as tf

from tensorflow_vgg import vgg19
from tensorflow_vgg import utils

import os
from skimage import io, transform
import matplotlib.pyplot as plt

# Ignore warnings: with no category given, this filter silences every Python
# warning raised after this point.
import warnings
warnings.filterwarnings('ignore')

plt.ion() # switch matplotlib to interactive mode

In [2]:
# parameters: tiny sizes used by the smoke test further down
batch = 2      # number of training / generated images
height = 3     # image height in pixels
width = 3      # image width in pixels
channel = 3    # colour channels per pixel

test_landmarks = 2 * 2       # length of one flattened landmark vector
test_style_images_num = 5    # size of the style image set
test_style_n_best = 3        # style images kept per training image

In [3]:
# Create a dedicated graph and, with it as the default, open a session and
# instantiate the VGG19 wrapper.
graph = tf.Graph()
with graph.as_default():
    # NOTE(review): this session is created while `graph` is the default graph
    # and is never closed in the visible cells; the smoke-test cell opens its
    # own tf.Session() outside graph.as_default() and rebinds `sess`.
    # Confirm whether this session (and `graph`) are still needed.
    sess = tf.Session()
    
    # Instantiate VGG19; the recorded output shows it loading vgg19.npy.
    # Open question from the original author: will the input be normalized?
    vgg = vgg19.Vgg19()

/Users/user/Desktop/FastFaceSwapStyleLoss/tensorflow_vgg/vgg19.npy
npy file loaded


In [4]:
def filter_by_landmarks(x_landmarks, style_landmarks, style_images, style_images_num=60, style_n_best=16, batch=batch, landmarks_num=68*2):
    """
    For every training image, pick the style images whose landmarks are closest.

    Closeness is the Euclidean distance between the flattened landmark vectors.
    The pairwise differences are formed by broadcasting, so the batch size, the
    number of style images and the landmark length are all taken from the input
    tensors themselves.

    param x_landmarks: landmarks of the training images, shape: (batch, landmarks_num)
    param style_landmarks: landmarks of the style images (Y), shape: (style_images_num, landmarks_num)
    param style_images: the style images (Y), shape: (style_images_num, height, width, 3)
    param style_images_num: number of style images in Y (60 in the paper). Unused; kept for backward compatibility.
    param style_n_best: how many nearest style images to keep per training image (the paper reportedly recommends 16)
    param batch: unused; kept for backward compatibility.
    param landmarks_num: unused; kept for backward compatibility.
    return style_best: customized set of style images per batch entry, shape: (batch, style_n_best, height, width, 3)
    """
    # (batch, 1, L) - (1, N, L) broadcasts to (batch, N, L): every training
    # image is compared against every style image.
    landmark_diff = tf.subtract(tf.expand_dims(x_landmarks, 1), tf.expand_dims(style_landmarks, 0))

    # Euclidean distance over the landmark axis -> (batch, N).
    distance = tf.sqrt(tf.reduce_sum(tf.square(landmark_diff), 2))

    # top_k returns the largest entries, so negate to get the smallest distances.
    nearest = tf.nn.top_k(tf.negative(distance), style_n_best)

    # Indices are (batch, style_n_best); gathering along axis 0 of style_images
    # yields (batch, style_n_best, height, width, 3).
    style_best = tf.gather(style_images, nearest.indices)
    return style_best

In [5]:
def compute_style_loss(vgg_generated_images, vgg_style_best, style_n_best=16, batch=batch):
    """
    Cosine-distance style loss between generated-image features and the
    features of the selected style images.

    At every spatial location the feature vector of the generated image is
    compared with the feature vector at the same location in each selected
    style image; the smallest cosine distance over the style images is kept,
    and the result is averaged over all locations and batch entries.

    param vgg_generated_images: VGG features of the generated images,
        shape: (batch, height, width, features), e.g. (2, 1, 1, 256) in the smoke test
    param vgg_style_best: VGG features of the best style images chosen for each input image,
        shape: (batch, style_n_best, height, width, features), e.g. (2, 3, 1, 1, 256)
    param style_n_best: unused; kept for backward compatibility (the size is taken from vgg_style_best).
    param batch: unused; kept for backward compatibility (the size is taken from the inputs).
    return style_loss: scalar tensor
    """
    # Unit-normalise along the feature axis so a dot product is a cosine similarity.
    normalize_g = tf.nn.l2_normalize(vgg_generated_images, 3)  # (B, H, W, F)
    normalize_s = tf.nn.l2_normalize(vgg_style_best, 4)        # (B, K, H, W, F)

    # Insert a style axis so the generated features broadcast against all K
    # style images, replacing the explicit tile + reshape.
    normalize_g = tf.expand_dims(normalize_g, 1)               # (B, 1, H, W, F)

    # Cosine distance per (batch, style image, row, column). No tf.squeeze
    # here: squeezing every size-1 axis would drop the batch or style axis
    # when batch == 1 or style_n_best == 1 and make the reduce_min below
    # operate on a spatial axis instead.
    cos_distance = 1 - tf.reduce_sum(tf.multiply(normalize_g, normalize_s), 4)  # (B, K, H, W)

    # Best-matching style image per location, then the mean over everything.
    cos_distance = tf.reduce_min(cos_distance, 1)              # (B, H, W)
    style_loss = tf.reduce_mean(cos_distance)
    return style_loss

In [6]:
# Smoke test: build the whole style-loss graph on tiny hand-written inputs and
# evaluate it once.
with tf.Session() as sess:    
    # Placeholders for the landmark-based style selection; sizes come from the
    # parameters cell.
    x_landmarks = tf.placeholder(tf.float32, shape=(batch, test_landmarks))
    style_landmarks = tf.placeholder(tf.float32, shape=(test_style_images_num, test_landmarks))
    style_images = tf.placeholder(tf.float32, shape=(test_style_images_num, height, width, channel))
    
    # Landmarks of the two training images (fed into x_landmarks).
    init_x_landmarks = np.array(
        [
            [1,2,3,4],
            [9,8,7,6]
        ])
    
    # Landmarks of the five style images, increasingly far from the training ones.
    init_style_landmarks = np.array(
        [
            [11, 12, 13, 14],
            [21, 22, 23, 24],
            [31, 32, 33, 34],
            [41, 42, 43, 44],
            [51, 52, 53, 54]
        ])
    
    # Five 3x3x3 style images with easily recognisable values (2xx, 3xx, ...).
    init_style_images = np.array(
        [
            [[[201, 201, 201], [202, 202, 201], [203, 203, 201]], [[204, 204, 201], [205, 205, 201], [206, 206, 201]], [[207, 207, 201], [208, 208, 201], [209, 209, 201]]],
            [[[301, 301, 201], [302, 302, 201], [303, 303, 201]], [[304, 304, 201], [305, 305, 201], [306, 306, 201]], [[307, 307, 201], [308, 308, 201], [309, 309, 201]]],
            [[[401, 401, 201], [402, 402, 201], [403, 403, 201]], [[404, 404, 201], [405, 405, 201], [406, 406, 201]], [[407, 407, 201], [408, 408, 201], [409, 409, 201]]],
            [[[501, 501, 201], [502, 502, 201], [503, 503, 201]], [[504, 504, 201], [505, 505, 201], [506, 506, 201]], [[507, 507, 201], [508, 508, 201], [509, 509, 201]]],
            [[[601, 601, 201], [602, 602, 201], [603, 603, 201]], [[604, 604, 201], [605, 605, 201], [606, 606, 201]], [[607, 607, 201], [608, 608, 201], [609, 609, 201]]]
        ])
            
    # For each training image, select the test_style_n_best style images with
    # the nearest landmarks.
    style_best = filter_by_landmarks(x_landmarks, style_landmarks, style_images, style_images_num=test_style_images_num, style_n_best=test_style_n_best, batch=batch, landmarks_num=test_landmarks)

    generated_images = tf.placeholder(tf.float32, shape=(batch, height, width, channel))

    # Two 3x3x3 "generated" images fed into generated_images.
    init_generated_images = np.array(
        [
            [[[5, 113, 199], [2, 2, 199], [3, 3, 199]], [[4, 4, 199], [5, 5, 199], [6, 6, 199]], [[7, 7, 199], [8, 8, 199], [9, 9, 199]]],
            [[[11, 11, 199], [12, 12, 199], [13, 13, 199]], [[14, 14, 199], [15, 15, 199], [16, 16, 199]], [[17, 17, 199], [18, 18, 199], [19, 19, 199]]]
        ])
    
    # vgg phase: run the generated images through VGG and read the conv3_1 /
    # conv4_1 attributes right after the build.
    # NOTE(review): the variables are named relu*, but they read vgg.conv3_1 /
    # vgg.conv4_1 - confirm these attributes are post-ReLU activations.
    vgg.build(generated_images)
    relu3_1_generated_image = vgg.conv3_1 # NOTE(review): original note said (1, 40, 40, 256); compute_style_loss documents (2, 1, 1, 256) for this test input - confirm
    relu4_1_generated_images = vgg.conv4_1

    # The same vgg object is rebuilt once per batch entry of style_best, and
    # the feature attributes are captured immediately after each build.
    # presumably each build overwrites vgg.conv3_1 / vgg.conv4_1; verify
    # against tensorflow_vgg.
    relu3_1_vgg_style_best = []
    relu4_1_vgg_style_best = []
    for style_layer_per_batch in tf.unstack(style_best):
        vgg.build(style_layer_per_batch)
        relu3_1_vgg_style_best.append(vgg.conv3_1)
        relu4_1_vgg_style_best.append(vgg.conv4_1)
    
    # Re-assemble the per-batch feature tensors with a leading batch axis.
    relu3_1_vgg_style_best = tf.stack(relu3_1_vgg_style_best)
    relu4_1_vgg_style_best = tf.stack(relu4_1_vgg_style_best)
        
    # Total style loss is the sum of the losses at the two feature levels.
    relu3_1_style_loss = compute_style_loss(relu3_1_generated_image, relu3_1_vgg_style_best, style_n_best=test_style_n_best)
    relu4_1_style_loss = compute_style_loss(relu4_1_generated_images, relu4_1_vgg_style_best, style_n_best=test_style_n_best)
    style_loss = relu3_1_style_loss + relu4_1_style_loss
    
    # Bind every placeholder to its hand-written NumPy value.
    feed_dict = {
        x_landmarks: init_x_landmarks,
        style_landmarks: init_style_landmarks,
        style_images: init_style_images,
        generated_images: init_generated_images
    }

    # The fetches are passed as a list, so `out` is a one-element list.
    out = sess.run([style_loss], feed_dict=feed_dict)

build model started
build model finished: 0s
build model started
build model finished: 0s
build model started
build model finished: 0s


In [7]:
# Display the evaluated style loss; it is a one-element list because sess.run
# was given a list of fetches.
out

[0.017204434]