In [1]:
import tensorflow as tf
import numpy as np

In [2]:
# Parameters: tiny dimensions for the toy tensors used to test the functions below.
batch = 2
height = 3
width = 3
feature = 2

# Length of the flattened landmark vector in the test (toy stand-in for 68*2).
test_landmarks = 2 * 2
# Size of the style-image pool in the test.
test_style_images_num = 5
# Number of nearest style images kept per training image in the test.
test_style_n_best = 3

In [3]:
def filter_by_landmarks(x_landmarks, style_landmarks, style_images, style_images_num=60, style_n_best=16, batch=batch, landmarks_num=68*2):
    """
    For each training image, pick the style images whose landmarks are closest
    (Euclidean distance) to that image's landmarks.

    param x_landmarks: landmarks of the training images, shape: (batch, landmarks_num)
    param style_landmarks: landmarks of the style images (Y), shape: (style_images_num, landmarks_num)
    param style_images: the style images (Y), shape: (style_images_num, height, width, channels)
    param style_n_best: how many nearest style images to keep per training image
    param style_images_num, batch, landmarks_num: kept for backward compatibility only;
        the sizes are now taken from the tensors themselves via broadcasting, so these
        values are ignored (previously they had to match the inputs exactly, and the
        `batch` default was frozen to the notebook's test constant at definition time).
    return style_best: per-batch customized set of style images,
        shape: (batch, style_n_best, height, width, channels)
    ps) In the paper, Y is a set of 60 style images, and using only 16 of them is said to work well.
    """
    # (batch, 1, L) - (1, N, L) broadcasts to (batch, N, L): every training image
    # against every style image, without tile/reshape and without static sizes.
    diff = tf.subtract(tf.expand_dims(x_landmarks, 1), tf.expand_dims(style_landmarks, 0))

    # Euclidean distance over the landmark axis -> (batch, N).
    distance = tf.sqrt(tf.reduce_sum(tf.square(diff), 2))

    # top_k returns the largest entries, so negate to get the smallest distances.
    nearest = tf.nn.top_k(tf.negative(distance), style_n_best)

    # indices has shape (batch, style_n_best); gather along axis 0 of style_images.
    style_best = tf.gather(style_images, nearest.indices)
    return style_best

In [8]:
def compute_style_loss(vgg_generated_images, vgg_style_best, style_n_best=16, batch=batch):
    """
    Cosine-distance style loss between generated features and the best-matching
    style features.

    param vgg_generated_images: features of the generated images,
        shape: (batch, height, width, feature)
    param vgg_style_best: features of the selected style images,
        shape: (batch, style_n_best, height, width, feature)
    param style_n_best, batch: kept for backward compatibility only; sizes are now
        taken from the tensors via broadcasting, so these values are ignored.
    return style_loss: scalar tensor. For every spatial position the cosine distance
        to each style candidate is computed, the minimum over the candidates is
        taken, and the result is averaged over batch and positions.
    """
    print(vgg_generated_images.shape)
    print(vgg_style_best.shape)

    # Normalise along the feature axis, then add a candidate axis so that
    # (batch, 1, H, W, F) broadcasts against (batch, n_best, H, W, F).
    # This no longer depends on the notebook-level height/width/feature constants.
    normalize_g = tf.expand_dims(tf.nn.l2_normalize(vgg_generated_images, 3), 1)
    normalize_s = tf.nn.l2_normalize(vgg_style_best, 4)

    # Cosine distance per (batch, candidate, y, x). No tf.squeeze here: squeezing
    # would drop any size-1 axis (batch == 1, a single candidate, a 1x1 map) and
    # make the reduce_min below run over the wrong axis.
    cos_distance = 1 - tf.reduce_sum(tf.multiply(normalize_g, normalize_s), 4)

    # Closest style candidate at every position -> (batch, H, W).
    cos_distance = tf.reduce_min(cos_distance, 1)
    style_loss = tf.reduce_mean(cos_distance)
    return style_loss

In [9]:
with tf.Session() as sess:
    # ---- Toy input values (NumPy) ------------------------------------------
    # Two training images, four landmark values each.
    init_x_landmarks = np.array([[1, 2, 3, 4], [9, 8, 7, 6]])

    # Five style images; row k holds 10*k + 1 .. 10*k + 4 for k = 1..5.
    init_style_landmarks = np.array(
        [[10 * k + j for j in range(1, 5)] for k in range(1, 6)]
    )

    # Five 3x3x2 style "images": pixel (r, c) of image `base` is base + 3*r + c + 1,
    # duplicated across both feature channels (201..209, 301..309, ..., 601..609).
    init_style_images = np.array(
        [
            [[[base + 3 * r + c + 1] * 2 for c in range(3)] for r in range(3)]
            for base in range(200, 700, 100)
        ]
    )

    # Generated-image features; the first pixel is deliberately not uniform.
    init_vgg_generated_images = np.array(
        [
            [
                [[5, 113], [2, 2], [3, 3]],
                [[4, 4], [5, 5], [6, 6]],
                [[7, 7], [8, 8], [9, 9]],
            ],
            [
                [[11, 11], [12, 12], [13, 13]],
                [[14, 14], [15, 15], [16, 16]],
                [[17, 17], [18, 18], [19, 19]],
            ],
        ]
    )

    # ---- Graph construction -------------------------------------------------
    x_landmarks = tf.placeholder(tf.float32, shape=(batch, test_landmarks))
    style_landmarks = tf.placeholder(
        tf.float32, shape=(test_style_images_num, test_landmarks)
    )
    style_images = tf.placeholder(
        tf.float32, shape=(test_style_images_num, height, width, feature)
    )

    style_best = filter_by_landmarks(
        x_landmarks,
        style_landmarks,
        style_images,
        style_images_num=test_style_images_num,
        style_n_best=test_style_n_best,
        batch=batch,
        landmarks_num=test_landmarks,
    )

    # vgg phase
    # vgg_generated_images = vgg(generated_images)
    # vgg_style_best = vgg(style_best)
    # The VGG pass is skipped in this test: the selected style images are used as-is.
    vgg_style_best = style_best

    vgg_generated_images = tf.placeholder(
        tf.float32, shape=(batch, height, width, feature)
    )

    style_loss = compute_style_loss(
        vgg_generated_images, vgg_style_best, test_style_n_best
    )

    # ---- Run ----------------------------------------------------------------
    feed_dict = {
        x_landmarks: init_x_landmarks,
        style_landmarks: init_style_landmarks,
        style_images: init_style_images,
        vgg_generated_images: init_vgg_generated_images,
    }

    # A one-element fetch list, so `out` is a one-element list.
    out = sess.run([style_loss], feed_dict=feed_dict)

(2, 3, 3, 2)
(2, 3, 3, 3, 2)


In [6]:
# Shape of the fetched results; sess.run above was given a one-element fetch list.
np.asarray(out).shape

(1,)

In [7]:
# Show the fetched result list (the style loss fetched by sess.run above).
out

[0.014573809]

In [13]:
# Fetched results as an array.
# NOTE(review): the output recorded below shows image tensors rather than a scalar
# loss, so it appears to come from a run that fetched something other than
# [style_loss] (presumably style_best) - re-run the notebook top to bottom to refresh.
np.asarray(out)

array([[[[[[ 201.,  201.],
           [ 202.,  202.],
           [ 203.,  203.]],

          [[ 204.,  204.],
           [ 205.,  205.],
           [ 206.,  206.]],

          [[ 207.,  207.],
           [ 208.,  208.],
           [ 209.,  209.]]],


         [[[ 301.,  301.],
           [ 302.,  302.],
           [ 303.,  303.]],

          [[ 304.,  304.],
           [ 305.,  305.],
           [ 306.,  306.]],

          [[ 307.,  307.],
           [ 308.,  308.],
           [ 309.,  309.]]],


         [[[ 401.,  401.],
           [ 402.,  402.],
           [ 403.,  403.]],

          [[ 404.,  404.],
           [ 405.,  405.],
           [ 406.,  406.]],

          [[ 407.,  407.],
           [ 408.,  408.],
           [ 409.,  409.]]],


         [[[ 501.,  501.],
           [ 502.,  502.],
           [ 503.,  503.]],

          [[ 504.,  504.],
           [ 505.,  505.],
           [ 506.,  506.]],

          [[ 507.,  507.],
           [ 508.,  508.],
           [ 509.,  509.]]]