# Real Time Style Transfer

based on paper by [Justin Johnson, et al](https://cs.stanford.edu/people/jcjohns/eccv16/)

In [1]:
import tensorflow as tf
import tensorflow.keras

from tensorflow.keras.layers import Conv2D, Conv2DTranspose, Activation, add, BatchNormalization

tf.enable_eager_execution()

In [2]:
# Helpers
def reflection_padding():
    def f(inputs):
        return tf.pad(x, [[0, 0], [40, 40], [40, 40], [0, 0]], "REFLECT")
    return f

def conv_layer(n_channels, kernel_size, strides, padding="same"):
    return Conv2D(filters=n_channels, kernel_size=kernel_size, strides=strides, padding=padding)

def conv_transpose_layer(n_channels, kernel_size, strides, padding="same"):
    return Conv2DTranspose(n_channels, kernel_size=kernel_size, strides=strides, padding=padding)

def residual_block(n_channels, kernel_size=3, strides=1, padding='valid'):
    def f(inputs):
        conv_1 = Conv2D(filters=n_channels, kernel_size=kernel_size, 
                      strides=strides, padding=padding)(inputs)
        bn_1 = BatchNormalization()(conv_1)
        relu_1 = Activation("relu")(bn_1)
        conv_2 = Conv2D(filters=n_channels, kernel_size=kernel_size, 
                      strides=strides, padding=padding)(relu_1)
        bn_2 = BatchNormalization()(conv_2)
        return add([bn_2, inputs])

    return f

In [7]:
class StyleTransferModel(tf.keras.Model):
    def __init__(self):
        super(StyleTransferModel, self).__init__(name='style_transfer_model')
        
        # Layers
        self.pad = reflection_padding()
        self.conv_1 = conv_layer(32, 9, 1)
        self.conv_2 = conv_layer(64, 3, 2)
        self.conv_3 = conv_layer(128, 3, 2)
        self.res_1 = residual_block(128, 3, 1)
        self.res_2 = residual_block(128, 3, 1)
        self.res_3 = residual_block(128, 3, 1)
        self.res_4 = residual_block(128, 3, 1)
        self.res_5 = residual_block(128, 3, 1)
        self.conv_4 = conv_transpose_layer(64, 3, 2)
        self.conv_5 = conv_transpose_layer(32, 3, 2)
        self.conv_6 = conv_layer(3, 9, 1)


        
    def call(self, inputs):
        print(tf.shape(inputs))
        inputs = self.pad(inputs)
        print(tf.shape(inputs))
        inputs = self.conv_1(inputs)
        print(tf.shape(inputs))
        inputs = self.conv_2(inputs)
        print(tf.shape(inputs))
        inputs = self.conv_3(inputs)
        print(tf.shape(inputs))
        inputs = self.res_1(inputs)
        print(tf.shape(inputs))
        inputs = self.res_2(inputs)
        print(tf.shape(inputs))
        inputs = self.res_3(inputs)
        print(tf.shape(inputs))
        inputs = self.res_4(inputs)
        print(tf.shape(inputs))
        inputs = self.res_5(inputs)
        print(tf.shape(inputs))
        inputs = self.conv_4(inputs)
        print(tf.shape(inputs))
        inputs = self.conv_5(inputs)
        print(tf.shape(inputs))
        inputs = self.conv_6(inputs)
        print(tf.shape(inputs))
        return inputs

In [8]:
net = StyleTransferModel()

In [9]:
x = tf.random.uniform((1, 256, 256, 3))
x

<tf.Tensor: id=1068, shape=(1, 256, 256, 3), dtype=float32, numpy=
array([[[[0.6179217 , 0.08211792, 0.46838963],
         [0.8542968 , 0.5602442 , 0.68404496],
         [0.9978471 , 0.13433456, 0.10006142],
         ...,
         [0.9634851 , 0.48664296, 0.649788  ],
         [0.7555078 , 0.738752  , 0.9152154 ],
         [0.8878497 , 0.00801635, 0.7048563 ]],

        [[0.5034343 , 0.33736682, 0.5857403 ],
         [0.33241034, 0.07088721, 0.12559164],
         [0.9119524 , 0.24792302, 0.6528914 ],
         ...,
         [0.79800034, 0.37630773, 0.11333871],
         [0.699348  , 0.33340025, 0.04093003],
         [0.79738736, 0.28225923, 0.85163367]],

        [[0.6062561 , 0.97209156, 0.4544047 ],
         [0.9099479 , 0.84586906, 0.90545785],
         [0.8695953 , 0.34119904, 0.6827289 ],
         ...,
         [0.3157376 , 0.6252551 , 0.44137037],
         [0.27747798, 0.3505658 , 0.10303998],
         [0.18329275, 0.79868186, 0.79860103]],

        ...,

        [[0.83798134, 0.0

In [10]:
net(x)

tf.Tensor([  1 256 256   3], shape=(4,), dtype=int32)
tf.Tensor([  1 336 336   3], shape=(4,), dtype=int32)
tf.Tensor([  1 336 336  32], shape=(4,), dtype=int32)
tf.Tensor([  1 168 168  64], shape=(4,), dtype=int32)
tf.Tensor([  1  84  84 128], shape=(4,), dtype=int32)


ValueError: Operands could not be broadcast together with shapes (21, 21, 128) (84, 84, 128)