# **Neural Style Transfer**

## Preparation

### Imports

In [78]:
import tensorflow as tf
import numpy as np
import os
import matplotlib.pyplot as plt
from tensorflow.keras.layers import Dense, Flatten, Activation, Conv2D, MaxPooling2D, Dropout, Rescaling
from tensorflow.keras.models import Sequential, load_model, Model
import kagglehub
import glob

### Global Variables and Policy

In [79]:
# VGG19 layer tapped by the content extractor.
content_layers = [
    'block5_conv2',
]
# VGG19 layers tapped by the style extractor: the first conv of blocks 1-4.
style_layers = [f'block{block}_conv1' for block in range(1, 5)]

In [80]:
# Switch Keras to the "mixed_float16" dtype policy: per the Keras mixed-precision
# docs, layers created after this point compute in float16 but keep float32 variables.
tf.keras.mixed_precision.set_global_policy(
    tf.keras.mixed_precision.Policy("mixed_float16")
)

In [81]:
# Side length in pixels that every content and style image is resized to on load.
IMAGE_SIZE: int = 256
# Not referenced in this part of the notebook; presumably the training epoch count.
epochs: int = 10
# Images per batch, shared by the content and the style datasets.
batch_size: int = 2

In [82]:
# Let TensorFlow grow GPU memory on demand instead of reserving it all up front.
# set_memory_growth raises RuntimeError when the GPUs have already been
# initialized (so this must run before any GPU work); the error is reported, not raised.
# An empty device list (CPU-only machine) simply makes the loop a no-op.
gpus = tf.config.list_physical_devices('GPU')
try:
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
except RuntimeError as e:
    print("GPU memory growth configuration error:", e)

In [83]:
import gc

# Reset Keras' global state, then run a garbage-collection pass to free the
# released objects (gc.collect() stays last so the cell still displays its count).
# NOTE(review): depending on the Keras version, clear_session() can also reset
# the global dtype policy. Capture and re-apply it so the mixed_float16 policy
# set earlier in the notebook survives the reset.
_dtype_policy = tf.keras.mixed_precision.global_policy()
tf.keras.backend.clear_session()
tf.keras.mixed_precision.set_global_policy(_dtype_policy)
del _dtype_policy
gc.collect()

0

### Dataset

In [84]:
# Fetch the "coco-wikiart-nst" Kaggle dataset (reusing the local kagglehub cache
# when present) and display the local directory it was placed in. That directory
# is expected to contain the "content/" and "style/" image folders loaded below.
path_to_dataset = kagglehub.dataset_download("shaorrran/coco-wikiart-nst-dataset-512-100000")
path_to_dataset

'/home/oslyris/.cache/kagglehub/datasets/shaorrran/coco-wikiart-nst-dataset-512-100000/versions/1'

In [85]:
# Load the "content" and "style" sub-directories as two independent, shuffled,
# unlabeled datasets of RGB image batches resized to IMAGE_SIZE x IMAGE_SIZE.
# Both use identical loader settings, so they are built from one call site;
# the unpacking order (content first, then style) matches the original cell.
content_data, style_data = (
    tf.keras.utils.image_dataset_from_directory(
        directory=os.path.join(path_to_dataset, subdir),
        labels=None,
        label_mode=None,
        color_mode="rgb",
        image_size=(IMAGE_SIZE, IMAGE_SIZE),
        shuffle=True,
        batch_size=batch_size,
    )
    for subdir in ("content", "style")
)

Found 49981 files.
Found 49981 files.


In [86]:
# Rescaling layer that multiplies pixel values by 1/255.
# NOTE(review): VGG19's ImageNet weights were trained on inputs from
# tf.keras.applications.vgg19.preprocess_input (0-255, mean-subtracted BGR);
# confirm the downstream code accounts for this 1/255 scaling.
normalization_layer = Rescaling(1/255.)

# Pair each content batch with a style batch, rescale both in parallel,
# then prefetch so input preparation can overlap with consumption.
dataset = tf.data.Dataset.zip((content_data, style_data))
dataset = dataset.map(
    lambda content_batch, style_batch: (
        normalization_layer(content_batch),
        normalization_layer(style_batch),
    ),
    num_parallel_calls=tf.data.AUTOTUNE,
)
dataset = dataset.prefetch(buffer_size=tf.data.AUTOTUNE)

## Custom Layers

### Adaptive Instance Normalization

In [None]:
class AdaptiveInstanceNormalization(tf.keras.layers.Layer):
    """Adaptive Instance Normalization (AdaIN) of one content map against several style maps.

    For each style feature map, the content features are:
      1. resized (bilinear) to the style map's spatial size,
      2. projected with a learned 1x1 convolution when the channel counts differ,
      3. instance-normalized over axes 1 and 2,
      4. re-scaled/shifted with the style map's per-channel std and mean.

    Call args:
        inputs: a pair ``(content_features, style_features)``. ``content_features``
            is a single 4-D tensor; ``style_features`` is a list of 4-D tensors.
            Axes 1-2 are treated as spatial and the last axis as channels
            (NHWC) — assumed from the resize/moments calls; confirm with callers.

    Returns:
        A list with one stylized tensor per style map. Each has that style map's
        spatial size and channel count, and the layer's compute dtype.
    """

    def __init__(self, epsilon=1e-5, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon
        # Channel-matching 1x1 projections keyed by str(target channel count).
        # They are created once in build() so their weights persist and train;
        # creating them inside call() would re-randomize them on every call.
        self._projections = {}

    def build(self, input_shape):
        content_shape, style_shapes = input_shape
        content_channels = content_shape[-1]
        for style_shape in style_shapes:
            channels = style_shape[-1]
            key = str(channels)
            if channels != content_channels and key not in self._projections:
                self._projections[key] = tf.keras.layers.Conv2D(
                    channels, 1, padding='same', name=f'{self.name}_proj_{channels}')
        super().build(input_shape)

    def call(self, inputs):
        content_features, style_features = inputs
        # Statistics are computed in float32 for numerical stability when the
        # layer runs under a float16 (mixed precision) policy.
        content_features = tf.cast(content_features, tf.float32)

        stylized_features = []
        for style_feature in style_features:
            style_feature = tf.cast(style_feature, tf.float32)
            style_shape = tf.shape(style_feature)

            resized_content = tf.image.resize(
                content_features,
                [style_shape[1], style_shape[2]],
                method='bilinear',
            )

            projection = self._projections.get(str(style_feature.shape[-1]))
            if projection is not None:
                # The conv outputs the policy's compute dtype (float16 under
                # mixed precision); cast back so the math below stays float32
                # and matches the float32 style statistics.
                resized_content = tf.cast(projection(resized_content), tf.float32)

            content_mean, content_var = tf.nn.moments(resized_content, axes=[1, 2], keepdims=True)
            normalized_content = (resized_content - content_mean) / tf.sqrt(content_var + self.epsilon)

            style_mean, style_var = tf.nn.moments(style_feature, axes=[1, 2], keepdims=True)
            style_std = tf.sqrt(style_var + self.epsilon)

            # keepdims=True statistics broadcast over the spatial axes automatically.
            stylized_feature = normalized_content * style_std + style_mean
            # Return in the layer's compute dtype instead of a hard-coded float16,
            # so the layer also works under a float32 policy.
            stylized_features.append(tf.cast(stylized_feature, self.compute_dtype))

        return stylized_features

    def get_config(self):
        config = super().get_config()
        config.update({'epsilon': self.epsilon})
        return config

### Style and Content Feature Extractors

In [88]:
# Frozen VGG19 backbone with ImageNet weights and no classification head (include_top=False).
base = tf.keras.applications.VGG19(weights="imagenet", include_top=False)
base.trainable = False

# Symbolic outputs of the selected VGG19 layers, in the order the names are listed.
content_output = [base.get_layer(name=layer_name).output for layer_name in content_layers]
style_output = [base.get_layer(name=layer_name).output for layer_name in style_layers]

# Two feature-extractor models that share the same frozen backbone and input.
content_extractor = tf.keras.Model(inputs=base.input, outputs=content_output)
style_extractor = tf.keras.Model(inputs=base.input, outputs=style_output)

In [89]:
# tf.keras.utils.plot_model(style_extractor, "style_extractor.png", show_shapes=True)

In [90]:
# tf.keras.utils.plot_model(content_extractor, "content_extractor.png", show_shapes=True)

### Adaptive Convolution Layer

In [91]:
class AdaptiveConv2D(tf.keras.layers.Layer):
    """Per-pixel adaptive (dynamic) depthwise convolution.

    A regular Conv2D (``kernel_predictor``) predicts, at every spatial position,
    one ``kernel_size x kernel_size`` kernel for each of the ``filters`` output
    channels. Each kernel is applied to the ``kernel_size x kernel_size``
    neighbourhood of that position in the matching channel (stride 1, "SAME"
    zero padding), so the spatial size is preserved. If the input channel count
    differs from ``filters``, a learned 1x1 convolution first maps the input to
    ``filters`` channels so every output channel has a matching input channel.

    NOTE(review): the previous version could not run. The predictor Conv2D was
    missing its required ``kernel_size`` argument, and the predicted tensor was
    reshaped into an invalid ``tf.nn.conv2d`` filter. This depthwise
    interpretation matches the predictor's ``filters * kernel_size**2`` output
    width — confirm it matches the intended design.

    Args:
        filters: number of output channels.
        kernel_size: spatial size of each predicted (square) kernel; also used
            as the predictor's own kernel size.
        epsilon: stored and serialized for API compatibility; not used in the
            computation.

    Input: 4-D NHWC tensor ``(batch, height, width, channels)``.
    Output: ``(batch, height, width, filters)``.
    """

    def __init__(self, filters, kernel_size=3, epsilon=1e-5, **kwargs):
        super().__init__(**kwargs)
        self.filters = filters
        self.kernel_size = kernel_size
        self.epsilon = epsilon
        self.kernel_predictor = None
        self.input_projection = None

    def build(self, input_shape):
        in_channels = input_shape[-1]
        self.kernel_predictor = Conv2D(
            self.filters * self.kernel_size ** 2,
            self.kernel_size,  # required kernel_size argument (was missing)
            strides=1,
            padding="same",
            kernel_initializer=tf.keras.initializers.RandomNormal(mean=0., stddev=0.02),
        )
        if in_channels != self.filters:
            # Map the input to `filters` channels so the depthwise step lines up.
            self.input_projection = Conv2D(self.filters, 1, padding="same", use_bias=False)
        super().build(input_shape)

    def call(self, inputs):
        k = self.kernel_size
        features = inputs if self.input_projection is None else self.input_projection(inputs)

        # (B, H, W, k*k*filters): one k x k kernel per output channel per pixel.
        kernels = self.kernel_predictor(inputs)
        # (B, H, W, k*k*filters): each pixel's zero-padded k x k neighbourhood,
        # laid out as (patch_row, patch_col, channel) along the last axis.
        patches = tf.image.extract_patches(
            features,
            sizes=[1, k, k, 1],
            strides=[1, 1, 1, 1],
            rates=[1, 1, 1, 1],
            padding="SAME",
        )

        # Reshape both to (B, H, W, k*k, filters) and contract over the k*k taps.
        target_shape = tf.concat([tf.shape(features)[:3], [k * k, self.filters]], axis=0)
        kernels = tf.reshape(kernels, target_shape)
        patches = tf.reshape(patches, target_shape)
        return tf.reduce_sum(kernels * patches, axis=3)

    def get_config(self):
        config = super().get_config()
        config.update({
            "filters": self.filters,
            "kernel_size": self.kernel_size,
            "epsilon": self.epsilon,
        })
        return config