In [1]:
import tensorflow as tf
from tensorflow.keras import layers,Model,applications
import cv2
from matplotlib import pyplot as plt
%matplotlib inline
import os

In [2]:
from IPython.display import Latex
from IPython.display import Math

In [3]:
from tensorflow.keras.applications.vgg19 import VGG19,preprocess_input
import numpy as np

We start by building the WCT (whitening and coloring transform) function, which is described in https://papers.nips.cc/paper/6642-universal-style-transfer-via-feature-transforms.pdf:
We need a function that takes the feature maps produced at each layer we choose to compute features for, and performs PCA via singular value decomposition on the style and content feature maps in order to match the covariance of the two feature maps. This is the key to fitting the style of the content image to that of the style image.

In [4]:
# Show the notation for the content feature map: C channels by H_c*W_c flattened spatial positions.
print('Content feature map:')
Math(r'f_c  \in R^{C\times H_cW_c}')

Content feature map:


<IPython.core.display.Math object>

In [5]:
# Show the notation for the style feature map: C channels by H_s*W_s flattened spatial positions.
print('Style feature map:')
Math(r'f_s  \in R^{C\times H_sW_s}')

Style feature map:


<IPython.core.display.Math object>

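Below are the expressions for the linear transforms applied to the feature maps: the first whitens the centered content features, and the second colors the whitened features using the style statistics.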

In [6]:
%%latex
\begin{align}
\hat{f}_c = E_cD^{-\frac{1}{2}}_cE^\top_cf_c\\
\hat{f}_{cs} = E_sD^{\frac{1}{2}}_sE^\top_s\hat{f}_c\\
\end{align}

<IPython.core.display.Latex object>

E_c is the orthogonal matrix of eigenvectors of the content auto-covariance matrix, and D_c is the diagonal matrix of its eigenvalues, satisfying:

In [7]:
# Render the eigendecomposition of the content auto-covariance matrix as LaTeX.
Math(r'f_c f^\top_c = E_cD_cE^\top_c')

<IPython.core.display.Math object>

The auto-covariance matrix can be factored using SVD into a matrix of eigenvectors, a diagonal matrix of eigenvalues, and
the transpose of the eigenvector matrix.
tf.linalg.svd returns the singular values (which equal the eigenvalues here, because the auto-covariance matrix is symmetric and positive semi-definite) in descending order of magnitude, so we can use SVD to perform PCA and remove the eigenvalues which are too small. This filters out noise within the feature space and performs the "whitening", where the low-level features in the pixel space are removed leaving only the higher-level features, allowing the feature space to be "colored in" later.

In [8]:
def WCT(content_features, style_features, alpha, eps = 1e-8):
    """Whitening-and-coloring transform of content features toward style statistics.

    The content features are centered and whitened (their channel covariance is
    mapped to the identity), then colored with the channel covariance of the
    style features and shifted by the style mean. The result is blended with the
    original content features.

    Args:
        content_features: feature map of the content image, laid out as
            (1, H, W, C). All size-1 axes are squeezed away, so H and W are
            assumed to be greater than 1.
        style_features: feature map of the style image, laid out as (1, H, W, C)
            with the same channel count C as `content_features`.
        alpha: blend weight; 1 keeps only the transformed features, 0 keeps only
            the original content features.
        eps: small value added to the diagonal of both covariance matrices to
            keep them well conditioned.

    Returns:
        A (1, Hc, Wc, C) tensor of blended features, ready to be decoded.
    """
    # (1, H, W, C) -> (C, H, W)
    content_chw = tf.transpose(tf.squeeze(content_features), (2, 0, 1))
    style_chw = tf.transpose(tf.squeeze(style_features), (2, 0, 1))

    # Dynamic dimensions C, H, W of each feature map.
    Cc, Hc, Wc = tf.unstack(tf.shape(content_chw))
    Cs, Hs, Ws = tf.unstack(tf.shape(style_chw))

    # Flatten the spatial axes: (C, H, W) -> (C, H*W)
    content_feature_map = tf.reshape(content_chw, (Cc, Hc * Wc))
    style_feature_map = tf.reshape(style_chw, (Cs, Hs * Ws))

    # Center every channel; only the second-order statistics are matched.
    mc = tf.reduce_mean(content_feature_map, axis=1, keepdims=True)
    ms = tf.reduce_mean(style_feature_map, axis=1, keepdims=True)
    fc = content_feature_map - mc
    fs = style_feature_map - ms

    # Channel auto-covariance matrices (C x C), regularised on the diagonal.
    fcfcT = tf.matmul(fc, fc, transpose_b=True) / (tf.cast(Hc * Wc, tf.float32) - 1.) + tf.eye(Cc) * eps
    fsfsT = tf.matmul(fs, fs, transpose_b=True) / (tf.cast(Hs * Ws, tf.float32) - 1.) + tf.eye(Cs) * eps

    # tf.linalg.svd returns (singular values, U, V). The singular values come
    # back in descending order, so the leading k columns of U are the kept
    # principal directions. V is not needed.
    Dc, Ec, _ = tf.linalg.svd(fcfcT)
    Ds, Es, _ = tf.linalg.svd(fsfsT)

    # PCA truncation: keep only directions whose eigenvalue exceeds 1e-5.
    k_c = tf.reduce_sum(tf.cast(tf.greater(Dc, 1e-5), tf.int32))
    k_s = tf.reduce_sum(tf.cast(tf.greater(Ds, 1e-5), tf.int32))
    Ec_k = Ec[:, :k_c]
    Es_k = Es[:, :k_s]

    inv_sqrt_Dc = tf.linalg.diag(tf.pow(Dc[:k_c], -0.5))
    sqrt_Ds = tf.linalg.diag(tf.pow(Ds[:k_s], 0.5))

    # Whitening: fc_hat = Ec Dc^(-1/2) Ec^T fc
    whitening = tf.matmul(tf.matmul(Ec_k, inv_sqrt_Dc), Ec_k, transpose_b=True)
    fc_hat = tf.matmul(whitening, fc)

    # Coloring: fcs_hat = Es Ds^(1/2) Es^T fc_hat, then re-center on the style mean.
    coloring = tf.matmul(tf.matmul(Es_k, sqrt_Ds), Es_k, transpose_b=True)
    fcs_hat = tf.matmul(coloring, fc_hat) + ms

    # Blend with the un-centered content features and restore the (1, H, W, C) layout.
    blended = alpha * fcs_hat + (1 - alpha) * (fc + mc)
    blended = tf.reshape(blended, (Cc, Hc, Wc))
    blended = tf.expand_dims(tf.transpose(blended, (1, 2, 0)), 0)
    return blended

In [9]:
class Encoder(Model):
    """VGG19 feature extractor exposing the first convolution of each of the five blocks.

    NOTE(review): inputs are passed to VGG19 exactly as given; `preprocess_input`
    is imported in this notebook but is not applied here - confirm the intended
    input scaling.
    """
    def __init__(self):
        super(Encoder,self).__init__()
        # Layers whose activations serve as style/content features, shallow to deep.
        self.target_layers = ['block1_conv1','block2_conv1','block3_conv1','block4_conv1','block5_conv1']
        self.vgg = VGG19(include_top = False,weights = 'imagenet')
        # Multi-output model returning one feature map per target layer.
        self.style_model = Model([self.vgg.input],[self.vgg.get_layer(name).output for name in self.target_layers])

    def call(self,inputs,target=None):
        """Encode `inputs`.

        Args:
            inputs: image batch fed to the VGG19 input.
            target: optional index into `target_layers`. When given, only that
                layer's feature map is returned.

        Returns:
            The list of all five feature maps when `target` is None, otherwise
            the single feature map at index `target`.
        """
        outputs = self.style_model(inputs)
        # Identity test: `!= None` would be evaluated as a comparison on the value itself.
        if target is None:
            return outputs
        return outputs[target]

In [10]:
# Shared feature extractor; constructing it builds VGG19 with weights='imagenet'.
encoder= Encoder()

In [11]:
def Conv2DReflect(*args, **kwargs):
    """Build a layer that reflect-pads its input by one pixel and then applies a Conv2D.

    The convolution is constructed once, here, and chained after a padding-only
    Lambda. Creating the Conv2D inside the Lambda body would instantiate a new,
    untracked layer with fresh weights on every call, leaving the enclosing
    model with nothing to train.

    Args:
        *args, **kwargs: forwarded unchanged to `layers.Conv2D`.

    Returns:
        A Keras layer mapping (N, H, W, C) inputs to the convolution of the
        reflect-padded input.
    """
    reflect_pad = layers.Lambda(
        lambda x: tf.pad(x, [[0, 0], [1, 1], [1, 1], [0, 0]], mode='REFLECT'))
    conv = layers.Conv2D(*args, **kwargs)
    return tf.keras.Sequential([reflect_pad, conv])

In [12]:
# Print the VGG19 layer table to look up the layer names used as encoder targets.
encoder.vgg.summary()

Model: "vgg19"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, None, None, 3)]   0         
_________________________________________________________________
block1_conv1 (Conv2D)        (None, None, None, 64)    1792      
_________________________________________________________________
block1_conv2 (Conv2D)        (None, None, None, 64)    36928     
_________________________________________________________________
block1_pool (MaxPooling2D)   (None, None, None, 64)    0         
_________________________________________________________________
block2_conv1 (Conv2D)        (None, None, None, 128)   73856     
_________________________________________________________________
block2_conv2 (Conv2D)        (None, None, None, 128)   147584    
_________________________________________________________________
block2_pool (MaxPooling2D)   (None, None, None, 128)   0     

Here we create 5 separate decoders to perform the multi-level style transfer. Each decoder mirrors (inverts) the
VGG19 encoder layers up to its target block.

In [13]:
class Decoder5(Model):
    """Decoder for the deepest encoder target (used as `decoders[4]`).

    Walks back from block 5 to block 1 with reflect-padded 3x3 ReLU
    convolutions and four `UpSampling2D` stages, ending in a 3-channel
    convolution without activation.
    """
    def __init__(self):
        super().__init__()

        def relu_conv(filters):
            # Every hidden convolution shares the 3x3 / stride 1 / ReLU setup.
            return Conv2DReflect(filters, kernel_size = (3,3), strides= 1, padding = 'valid',activation = 'relu')

        # Block 5
        self.block5_conv1 = relu_conv(512)
        self.block5_upsample = layers.UpSampling2D()
        self.block5_layers = [relu_conv(512) for _ in range(3)]
        # Block 4
        self.block4_conv1 = relu_conv(256)
        self.block4_upsample = layers.UpSampling2D()
        self.block4_layers = [relu_conv(256) for _ in range(3)]
        # Block 3
        self.block3_conv1 = relu_conv(128)
        self.block3_upsample = layers.UpSampling2D()
        self.block3_conv2 = relu_conv(128)
        # Block 2
        self.block2_conv1 = relu_conv(64)
        self.block2_upsample = layers.UpSampling2D()
        # Block 1 and image output (no activation on the last layer)
        self.block1_conv1 = relu_conv(64)
        self.decoder_output = Conv2DReflect(3,kernel_size = (3,3),strides = 1,padding='valid')

    def call(self,inputs):
        """Decode a feature map into a 3-channel image tensor."""
        x = self.block5_upsample(self.block5_conv1(inputs))
        for conv in self.block5_layers:
            x = conv(x)

        x = self.block4_upsample(self.block4_conv1(x))
        for conv in self.block4_layers:
            x = conv(x)

        x = self.block3_upsample(self.block3_conv1(x))
        x = self.block3_conv2(x)

        x = self.block2_upsample(self.block2_conv1(x))

        return self.decoder_output(self.block1_conv1(x))

In [14]:
class Decoder4(Model):
    """Decoder for the fourth encoder target (used as `decoders[3]`).

    Walks back from block 4 to block 1 with reflect-padded 3x3 ReLU
    convolutions and three `UpSampling2D` stages, ending in a 3-channel
    convolution without activation.
    """
    def __init__(self):
        super(Decoder4,self).__init__()
        self.block4_conv1 = Conv2DReflect(256, kernel_size = (3,3), strides= 1, padding = 'valid',activation = 'relu')
        self.block4_upsample = layers.UpSampling2D()
        self.block4_layers = [Conv2DReflect(256, kernel_size = (3,3), strides= 1, padding = 'valid',activation = 'relu') for i in range(3)]
        self.block3_conv1 = Conv2DReflect(128, kernel_size = (3,3), strides= 1, padding = 'valid',activation = 'relu')
        self.block3_upsample = layers.UpSampling2D()
        self.block3_conv2 = Conv2DReflect(128, kernel_size = (3,3), strides= 1, padding = 'valid',activation = 'relu')
        self.block2_conv1 = Conv2DReflect(64, kernel_size = (3,3), strides= 1, padding = 'valid',activation = 'relu')
        self.block2_upsample = layers.UpSampling2D()
        self.block1_conv1 = Conv2DReflect(64, kernel_size = (3,3), strides= 1, padding = 'valid',activation = 'relu')
        self.decoder_output = Conv2DReflect(3,kernel_size = (3,3),strides = 1,padding='valid')

    def call(self,inputs):
        """Decode a feature map into a 3-channel image tensor."""
        # The first convolution consumes the input and the upsample consumes its
        # result, the same conv-then-upsample order the other decoders use.
        x = self.block4_conv1(inputs)
        x = self.block4_upsample(x)
        for conv in self.block4_layers:
            x = conv(x)
        x = self.block3_conv1(x)
        x = self.block3_upsample(x)
        x = self.block3_conv2(x)
        x = self.block2_conv1(x)
        x = self.block2_upsample(x)
        x = self.block1_conv1(x)
        x = self.decoder_output(x)

        return x

In [15]:
class Decoder3(Model):
    """Decoder for the third encoder target (used as `decoders[2]`).

    Walks back from block 3 to block 1 with reflect-padded 3x3 ReLU
    convolutions and two `UpSampling2D` stages, ending in a 3-channel
    convolution without activation.
    """
    def __init__(self):
        super().__init__()

        def relu_conv(filters):
            # Every hidden convolution shares the 3x3 / stride 1 / ReLU setup.
            return Conv2DReflect(filters, kernel_size = (3,3), strides= 1, padding = 'valid',activation = 'relu')

        # Block 3
        self.block3_conv1 = relu_conv(128)
        self.block3_upsample = layers.UpSampling2D()
        self.block3_conv2 = relu_conv(128)
        # Block 2
        self.block2_conv1 = relu_conv(64)
        self.block2_upsample = layers.UpSampling2D()
        # Block 1 and image output (no activation on the last layer)
        self.block1_conv1 = relu_conv(64)
        self.decoder_output = Conv2DReflect(3,kernel_size = (3,3),strides = 1,padding='valid')

    def call(self,inputs):
        """Decode a feature map into a 3-channel image tensor."""
        x = self.block3_upsample(self.block3_conv1(inputs))
        x = self.block3_conv2(x)
        x = self.block2_upsample(self.block2_conv1(x))
        return self.decoder_output(self.block1_conv1(x))

In [16]:
class Decoder2(Model):
    """Decoder for the second encoder target (used as `decoders[1]`).

    One reflect-padded 3x3 ReLU convolution, one `UpSampling2D`, another ReLU
    convolution, then a 3-channel convolution without activation.
    """
    def __init__(self):
        super().__init__()

        def relu_conv(filters):
            # Every hidden convolution shares the 3x3 / stride 1 / ReLU setup.
            return Conv2DReflect(filters, kernel_size = (3,3), strides= 1, padding = 'valid',activation = 'relu')

        self.block2_conv1 = relu_conv(64)
        self.block2_upsample = layers.UpSampling2D()
        self.block1_conv1 = relu_conv(64)
        self.decoder_output = Conv2DReflect(3,kernel_size = (3,3),strides = 1,padding='valid')

    def call(self,inputs):
        """Decode a feature map into a 3-channel image tensor."""
        upsampled = self.block2_upsample(self.block2_conv1(inputs))
        return self.decoder_output(self.block1_conv1(upsampled))

In [17]:
class Decoder1(Model):
    """Decoder for the shallowest encoder target (used as `decoders[0]`).

    A single reflect-padded 3x3 ReLU convolution followed by a 3-channel
    convolution without activation; there is no upsampling stage.
    """
    def __init__(self):
        super().__init__()
        self.block1_conv1 = Conv2DReflect(64, kernel_size = (3,3), strides= 1, padding = 'valid',activation = 'relu')
        self.decoder_output = Conv2DReflect(3,kernel_size = (3,3),strides = 1,padding='valid')

    def call(self,inputs):
        """Decode a feature map into a 3-channel image tensor."""
        hidden = self.block1_conv1(inputs)
        return self.decoder_output(hidden)

In [18]:
# One decoder per encoder target layer.
decoder1=Decoder1()
decoder2=Decoder2()
decoder3=Decoder3()
decoder4=Decoder4()
decoder5=Decoder5()

# Ordered shallow to deep so that decoders[k] pairs with encoder target index k.
decoders = [decoder1,decoder2,decoder3,decoder4,decoder5]

In [20]:
# Adam optimizer (learning rate 1e-4) that train_decoders uses for every decoder.
optimizer = tf.keras.optimizers.Adam(1e-4)

In [21]:
@tf.function
def compute_loss(input_features,training_image,decoded_image,decoded_features):
    """Decoder training loss: pixel MSE + feature MSE + 0.01 * total variation.

    Args:
        input_features: encoder features of the training image.
        training_image: the original image batch.
        decoded_image: the decoder's reconstruction of the image batch.
        decoded_features: encoder features of `decoded_image`.

    Returns:
        Scalar tensor holding the combined loss.
    """
    # How far the reconstruction is from the original, in pixel space.
    pixel_term = tf.reduce_mean(tf.keras.losses.MSE(training_image, decoded_image))
    # How far the reconstruction is from the original, in feature space.
    feature_term = tf.reduce_mean(tf.keras.losses.MSE(input_features, decoded_features))
    # Batch-averaged total variation of the reconstruction.
    smoothness_term = tf.reduce_mean(tf.image.total_variation(decoded_image))

    return pixel_term + feature_term + 0.01 * smoothness_term

In [None]:
import pathlib

# Fetch the COCO train2014 archive via Keras' download helper, asking it to
# extract the zip, and keep the returned location as a Path.
data_dir = pathlib.Path(
    tf.keras.utils.get_file(
        fname='coco_dataset',
        origin='http://images.cocodataset.org/zips/train2014.zip',
        extract=True,
    )
)

# Input pipeline settings: batch size and the size every image is resized to.
batchsize = 32
image_height = 224
image_width = 224

# Pixel values are rescaled by 1/255 as they are loaded.
image_generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255)

In [None]:
# Shuffled batches of resized images read from the folder that contains the download.
train_images = image_generator.flow_from_directory(
    directory=os.path.dirname(data_dir),
    target_size=(image_height, image_width),
    batch_size=batchsize,
    shuffle=True,
)
@tf.function
def train_decoders(training_image,target_layer):
    """Run one optimisation step for the decoder paired with `target_layer`.

    The image is encoded up to the target layer, the decoder reconstructs the
    image from those features, and the reconstruction is re-encoded so that
    both pixel-space and feature-space errors enter the loss.

    Args:
        training_image: batch of training images.
        target_layer: 1-based decoder/encoder level. It must be a Python int
            because it indexes the Python list `decoders`.

    Returns:
        The scalar loss for this step.

    NOTE(review): a new `target_layer` value makes tf.function retrace, and the
    decoder weights and optimizer slots for that level are created lazily during
    that trace - confirm this works on the TensorFlow version in use, or build
    the decoders and use one optimizer per decoder beforehand.
    """
    layer_index = target_layer - 1
    decoder = decoders[layer_index]

    with tf.GradientTape() as tape:
        # The decoder inverts encoder features, so it is fed the features, not the raw image.
        input_features = encoder(training_image, target=layer_index)
        decoded_image = decoder(input_features)
        decoded_features = encoder(decoded_image, target=layer_index)
        loss = compute_loss(input_features, training_image, decoded_image, decoded_features)

    # Only the decoder is updated.
    gradients = tape.gradient(loss, decoder.trainable_variables)
    optimizer.apply_gradients(zip(gradients, decoder.trainable_variables))
    return loss

In [None]:
# Train the five decoders one after another. Each run stops when the batch
# counter reaches 300, logs on every 100th batch, and exports the decoder at the end.
for target_layer in range(1,6):
    for i, batch in enumerate(train_images, start=1):
        if i==300:
            break
        if i%100==0:
            print("decoder {} iteration {}".format(target_layer,i))
        # batch[0] holds the images; the rest of the batch tuple is not used.
        train_decoders(batch[0],target_layer)
    tf.saved_model.save(decoders[target_layer-1], r"DecoderModels/decoder{}".format(target_layer))

In [0]:
def MultiStageStyleTransfer(encoder,decoders,content_image,style_image,alpha=0.6):
    """Stylise `content_image` by applying WCT at every encoder level in turn.

    At each level the current image is encoded, its features are transformed
    toward the style features of the same level with `WCT`, and the matching
    decoder turns the result back into an image that feeds the next level.

    Args:
        encoder: callable returning the list of per-level feature maps when
            called without `target`, and a single level's map with `target=k`.
        decoders: sequence of decoders, where `decoders[k]` inverts encoder level k.
        content_image: image batch to stylise.
        style_image: image batch providing the style.
        alpha: WCT blend weight passed to every level (default 0.6).

    Returns:
        The stylised image, clipped to [0, 1].

    NOTE(review): levels are processed shallow to deep (index 0 first) - confirm
    this ordering is intended.
    """
    # The style image never changes, so all of its feature maps are computed once.
    style_feature_maps = encoder(style_image)

    stylized = content_image
    for level, decoder in enumerate(decoders):
        content_features = encoder(stylized, target=level)
        transformed_features = WCT(content_features, style_feature_maps[level], alpha)
        stylized = tf.clip_by_value(decoder(transformed_features),
                                    clip_value_min=0.0, clip_value_max=1.0)
    return stylized

# New Section