# **Neural Style Transfer**

##  Preparation

### Imports

In [1]:
import gc
import glob
import os

import kagglehub
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, Activation, Conv2D, MaxPooling2D, Dropout, Rescaling
from tensorflow.keras.models import Sequential, load_model, Model

2025-06-06 16:47:41.664310: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1749228461.802557    1156 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1749228461.838349    1156 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-06-06 16:47:42.179298: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


### Global Variables and Policy

In [2]:
# VGG19 layer names whose activations are tapped as style features
# (one entry per scale, shallowest first).
style_layers = [
    'block1_conv1',
    'block2_conv1',
    'block3_conv1',
    'block4_conv1',
]
# VGG19 layer name(s) whose activations are tapped as content features.
content_layers = [
    'block5_conv2',
]

In [3]:
# Make "mixed_float16" the global Keras dtype policy, so layers created after
# this point pick it up by default (per the Keras mixed-precision docs:
# float16 computation with float32 variables).
tf.keras.mixed_precision.set_global_policy("mixed_float16")

In [4]:
# Number of images per batch when loading the datasets.
batch_size: int = 2
# Number of training epochs.
epochs: int = 10
# Images are resized to IMAGE_SIZE x IMAGE_SIZE pixels when loaded.
IMAGE_SIZE: int = 256

In [5]:
# Ask TensorFlow to grow GPU memory on demand for every visible GPU.
# With no GPU present the loop body simply never runs.
gpus = tf.config.experimental.list_physical_devices('GPU')
try:
    for device in gpus:
        tf.config.experimental.set_memory_growth(device, True)
except RuntimeError as err:
    # Report the failure instead of aborting the notebook.
    print("GPU memory growth configuration error:", err)

In [6]:
# Start from a clean Keras state and release unreferenced Python objects.
#
# clear_session() resets Keras' global state. Depending on the Keras version
# that can include the global dtype policy selected earlier in the notebook,
# so the active policy is captured first and re-applied afterwards. This is a
# no-op when the policy survives the reset.
active_policy = tf.keras.mixed_precision.global_policy()
tf.keras.backend.clear_session()
tf.keras.mixed_precision.set_global_policy(active_policy)
gc.collect()

0

### Dataset

In [7]:
# Kaggle handle of the dataset; the code below reads its "content" and
# "style" sub-folders.
DATASET_HANDLE = "shaorrran/coco-wikiart-nst-dataset-512-100000"

# Download the dataset via kagglehub and show the local path it returns.
path_to_dataset = kagglehub.dataset_download(DATASET_HANDLE)
path_to_dataset

'/home/oslyris/.cache/kagglehub/datasets/shaorrran/coco-wikiart-nst-dataset-512-100000/versions/1'

In [8]:
def _load_image_folder(subfolder):
    """Load one unlabeled, shuffled RGB image dataset from the downloaded data.

    Args:
        subfolder: Name of the directory below ``path_to_dataset``.

    Returns:
        The dataset produced by ``image_dataset_from_directory``, with images
        resized to ``IMAGE_SIZE`` x ``IMAGE_SIZE`` and batched by ``batch_size``.
    """
    return tf.keras.preprocessing.image_dataset_from_directory(
        directory=path_to_dataset + "/" + subfolder + "/",
        labels=None,
        label_mode=None,
        color_mode="rgb",
        image_size=(IMAGE_SIZE, IMAGE_SIZE),
        shuffle=True,
        batch_size=batch_size,
    )


content_data = _load_image_folder("content")
style_data = _load_image_folder("style")

Found 49981 files.


I0000 00:00:1749228473.119444    1156 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 3539 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 4050 Laptop GPU, pci bus id: 0000:01:00.0, compute capability: 8.9


Found 49981 files.


In [9]:
# Rescaling layer that multiplies pixel values by 1/255.
normalization_layer = Rescaling(1/255.)


def _rescale_pair(content_batch, style_batch):
    """Apply the shared rescaling layer to a (content, style) batch pair."""
    return normalization_layer(content_batch), normalization_layer(style_batch)


# Combine both datasets into (content, style) pairs, rescale them with
# parallel processing, and prefetch upcoming batches.
paired_batches = tf.data.Dataset.zip((content_data, style_data))
dataset = paired_batches.map(_rescale_pair, num_parallel_calls=tf.data.AUTOTUNE)
dataset = dataset.prefetch(buffer_size=tf.data.AUTOTUNE)

## Custom Layers

### Adaptive Instance Normalization

In [10]:
class AdaptiveInstanceNormalization(tf.keras.layers.Layer):
    """Adaptive Instance Normalization (AdaIN) against several style feature maps.

    The layer is called with ``[content_features, style_features]`` where
    ``style_features`` is a list/tuple of feature maps. Tensors are assumed to
    be channels-last 4-D maps (batch, height, width, channels). For every
    style feature map the content features are:

    1. projected with a 1x1 convolution when the channel counts differ,
    2. bilinearly resized to the style map's height and width,
    3. normalised per sample and channel over the spatial axes, and
    4. re-scaled and shifted with the style map's spatial std and mean.

    The 1x1 projections are created once in ``build()`` so that their weights
    are tracked by the layer and reused on every call, rather than being
    re-created with fresh random weights on each call.

    Args:
        epsilon: Small constant added to the variance before the square root.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, epsilon=1e-5, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon
        # Filled in by build(): one entry per style feature map, either a
        # 1x1 Conv2D (channel counts differ) or None (no projection needed).
        self.channel_projections = []

    def build(self, input_shape):
        """Create one channel projection per style map that needs it.

        Args:
            input_shape: ``[content_shape, [style_shape, ...]]``.

        Raises:
            ValueError: If a style feature map has an undefined channel axis.
        """
        content_shape, style_shapes = input_shape
        content_channels = content_shape[-1]

        projections = []
        for index, style_shape in enumerate(style_shapes):
            style_channels = style_shape[-1]
            if style_channels is None:
                raise ValueError(
                    "The channel dimension of every style feature map must be "
                    f"defined, but style feature {index} has shape {style_shape}."
                )
            if style_channels == content_channels:
                projections.append(None)
                continue
            projections.append(
                tf.keras.layers.Conv2D(
                    int(style_channels),
                    1,
                    padding='same',
                    # Keep the projection in float32 so the statistics below
                    # are computed at full precision, as the explicit float32
                    # casts in call() intend.
                    dtype='float32',
                    name=f"content_projection_{index}",
                )
            )
        self.channel_projections = projections
        super().build(input_shape)

    def call(self, inputs):
        """Return one stylized feature map per style feature map.

        Args:
            inputs: ``[content_features, style_features]``.

        Returns:
            A list with one tensor per style feature map, each with that
            map's height, width and channel count, cast to the layer's
            compute dtype.

        Raises:
            ValueError: If the number of style feature maps differs from the
                number the layer was built for.
        """
        content_features, style_features = inputs
        if len(style_features) != len(self.channel_projections):
            raise ValueError(
                f"Layer was built for {len(self.channel_projections)} style "
                f"feature maps but received {len(style_features)}."
            )

        content_features = tf.cast(content_features, tf.float32)
        stylized_features = []

        for style_feature, projection in zip(style_features, self.channel_projections):
            style_feature = tf.cast(style_feature, tf.float32)

            # A 1x1 convolution and a bilinear resize are both linear and the
            # interpolation weights sum to one, so they commute. Projecting at
            # the content resolution first keeps the convolution small.
            matched_content = content_features
            if projection is not None:
                matched_content = tf.cast(projection(matched_content), tf.float32)

            target_size = tf.shape(style_feature)[1:3]  # (height, width)
            matched_content = tf.image.resize(
                matched_content,
                target_size,
                method='bilinear'
            )

            content_mean, content_var = tf.nn.moments(matched_content, axes=[1, 2], keepdims=True)
            content_std = tf.sqrt(content_var + self.epsilon)
            normalized_content = (matched_content - content_mean) / content_std

            style_mean, style_var = tf.nn.moments(style_feature, axes=[1, 2], keepdims=True)
            style_std = tf.sqrt(style_var + self.epsilon)

            # The (batch, 1, 1, channels) statistics broadcast over H and W.
            stylized_feature = normalized_content * style_std + style_mean

            # Follow the layer's dtype policy instead of hard-coding float16.
            stylized_features.append(tf.cast(stylized_feature, self.compute_dtype))

        return stylized_features

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super().get_config()
        config.update({"epsilon": self.epsilon})
        return config

### Style and Content Feature Extractors

In [11]:
# Frozen VGG19 backbone (ImageNet weights, no classification head).
base = tf.keras.applications.VGG19(weights="imagenet", include_top=False)
base.trainable = False


def _layer_outputs(layer_names):
    """Return the symbolic outputs of the named layers of ``base``, in order."""
    return [base.get_layer(name).output for name in layer_names]


# Two models sharing the backbone: one exposes the content layer outputs,
# the other exposes the style layer outputs.
content_output = _layer_outputs(content_layers)
style_output = _layer_outputs(style_layers)
content_extractor = tf.keras.Model(inputs=base.input, outputs=content_output)
style_extractor = tf.keras.Model(inputs=base.input, outputs=style_output)

In [12]:
# tf.keras.utils.plot_model(style_extractor, "style_extractor.png", show_shapes=True)

In [13]:
# tf.keras.utils.plot_model(content_extractor, "content_extractor.png", show_shapes=True)

### Adaptive Convolution Layer

In [None]:
class AdaptiveConv2D(tf.keras.layers.Layer):
    """Convolution whose kernels are predicted per sample from style features.

    The layer is called with ``[content_features, style_features]``. The style
    features are encoded into a vector, a dense layer turns that vector into
    one ``(kernel_size, kernel_size, in_channels, filters)`` kernel per sample,
    and each sample of ``content_features`` is convolved with its own kernel
    (stride 1, "SAME" padding).

    The convolution runs in float32 and the output is float32, whatever dtype
    policy is active. Keras may hand ``call()`` lower-precision inputs under a
    mixed-precision policy while the kernel predictor is pinned to float32;
    both operands are cast to float32 because ``tf.nn.conv2d`` needs matching
    dtypes.

    Args:
        filters: Number of output channels.
        kernel_size: Height and width of the predicted square kernels.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, filters, kernel_size=3, **kwargs):
        super().__init__(**kwargs)
        self.filters = filters
        self.kernel_size = kernel_size

    def build(self, input_shape):
        """Create the style encoder and the kernel-predicting dense layer.

        Args:
            input_shape: ``[content_shape, style_shape]``.

        Raises:
            ValueError: If the content channel dimension is undefined, since
                the size of the predicted kernel depends on it.
        """
        content_shape, _ = input_shape
        in_channels = content_shape[-1]
        if in_channels is None:
            raise ValueError(
                "The channel dimension of the content features must be defined, "
                f"but received content shape {content_shape}."
            )
        self.in_channels = int(in_channels)

        self.style_encoder = tf.keras.Sequential([
            tf.keras.layers.Conv2D(64, 3, padding='same', activation='relu'),
            tf.keras.layers.Conv2D(32, 3, padding='same', activation='relu'),
            tf.keras.layers.GlobalAveragePooling2D(),
            tf.keras.layers.Dense(128, activation='relu')
        ])
        # One flat vector per sample holding every weight of that sample's kernel.
        self.kernel_predictor = tf.keras.layers.Dense(
            self.filters * self.kernel_size * self.kernel_size * self.in_channels,
            kernel_initializer=tf.keras.initializers.RandomNormal(mean=0., stddev=0.02),
            dtype='float32'
        )
        super().build(input_shape)

    def call(self, inputs):
        """Convolve every content sample with the kernel predicted from its style.

        Args:
            inputs: ``[content_features, style_features]``.

        Returns:
            A float32 tensor of shape (batch, height, width, ``filters``).
        """
        content_features, style_features = inputs
        style_vector = self.style_encoder(style_features)

        # Bring both convolution operands to float32 (see class docstring).
        kernels = tf.cast(self.kernel_predictor(style_vector), tf.float32)
        content_features = tf.cast(content_features, tf.float32)

        num_samples = tf.shape(content_features)[0]
        kernels = tf.reshape(
            kernels,
            [num_samples, self.kernel_size, self.kernel_size, self.in_channels, self.filters],
        )

        def single_conv(args):
            # tf.nn.conv2d shares one kernel across the batch, so each sample
            # is convolved on its own as a batch of one.
            sample, kernel = args
            sample = tf.expand_dims(sample, 0)
            convolved = tf.nn.conv2d(sample, kernel, strides=1, padding="SAME")
            return tf.squeeze(convolved, 0)

        # `fn_output_signature` replaces map_fn's deprecated `dtype` argument.
        return tf.map_fn(
            single_conv,
            (content_features, kernels),
            fn_output_signature=tf.float32,
        )

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super().get_config()
        config.update({"filters": self.filters, "kernel_size": self.kernel_size})
        return config

In [15]:
# Smoke test: pull one (content, style) batch pair, extract VGG features for
# each, and run them through a fresh AdaIN layer.
test = dataset.take(1)
content_batch, style_batch = next(iter(test))

content = content_extractor(content_batch)
style = style_extractor(style_batch)

adain_layer = AdaptiveInstanceNormalization()
adain = adain_layer([content, style])

I0000 00:00:1749228476.995829    1156 cuda_dnn.cc:529] Loaded cuDNN version 90300
2025-06-06 16:47:58.346702: W external/local_xla/xla/tsl/framework/bfc_allocator.cc:306] Allocator (GPU_0_bfc) ran out of memory trying to allocate 4.13GiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if more memory were available.
2025-06-06 16:47:58.627763: W external/local_xla/xla/tsl/framework/bfc_allocator.cc:306] Allocator (GPU_0_bfc) ran out of memory trying to allocate 4.12GiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if more memory were available.
2025-06-06 16:47:58.969894: W external/local_xla/xla/tsl/framework/bfc_allocator.cc:306] Allocator (GPU_0_bfc) ran out of memory trying to allocate 4.17GiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if m

In [16]:
# Show how many stylized feature maps AdaIN returned and the dtype of the first.
len(adain), adain[0].dtype

(4, tf.float16)

In [17]:
# Shape of every stylized feature map, without hard-coding how many there are
# (the count follows the number of style feature maps passed to AdaIN).
tuple(feature.shape for feature in adain)

(TensorShape([2, 256, 256, 64]),
 TensorShape([2, 128, 128, 128]),
 TensorShape([2, 64, 64, 256]),
 TensorShape([2, 32, 32, 512]))

In [18]:
# Cast every style feature map to float16; iterating over the list avoids
# hard-coding the number of style feature maps.
style = [tf.cast(feature, tf.float16) for feature in style]

In [19]:
# Shape of every style feature map, without hard-coding how many there are.
tuple(feature.shape for feature in style)

(TensorShape([2, 256, 256, 64]),
 TensorShape([2, 128, 128, 128]),
 TensorShape([2, 64, 64, 256]),
 TensorShape([2, 32, 32, 512]))

In [20]:
# Smoke test: run the first stylized feature map and the first style feature
# map through a 256-filter adaptive convolution, then show the output shape.
adaptive_conv = AdaptiveConv2D(filters=256)
adaConv = adaptive_conv([adain[0], style[0]])

adaConv.shape

TensorShape([2, 256, 256, 256])