# **Neural Style Transfer**

## Preparation

### Imports

In [1]:
import tensorflow as tf
import numpy as np
import os
import matplotlib.pyplot as plt
from tensorflow.keras.layers import Dense, Flatten, Activation, Conv2D, MaxPooling2D, Dropout, Rescaling
from tensorflow.keras.models import Sequential, load_model, Model
import kagglehub
import glob

2025-06-07 07:10:33.430441: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1749280233.454665   24525 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1749280233.462141   24525 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-06-07 07:10:33.497795: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


### Global Variables and Policy

In [2]:
# Let TensorFlow allocate GPU memory on demand instead of reserving the whole
# card up front. This must run before anything initializes the GPUs; otherwise
# set_memory_growth raises RuntimeError, which is reported below.
gpus = tf.config.list_physical_devices('GPU')
try:
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
except RuntimeError as e:
    print("GPU memory growth configuration error:", e)

In [3]:
# VGG19 layer whose activations represent image content.
content_layers = ['block5_conv2']
# First convolution of VGG19 blocks 1-4; their activations represent style at
# successively smaller spatial resolutions.
style_layers = [f'block{block}_conv1' for block in range(1, 5)]

In [4]:
# Global Keras dtype policy applied to layers created after this call.
MIXED_PRECISION_POLICY = "mixed_float16"
tf.keras.mixed_precision.set_global_policy(MIXED_PRECISION_POLICY)

In [5]:
# Side length (pixels) that every content/style image is resized to on load.
IMAGE_SIZE = 256
# Training length and number of (content, style) image pairs per batch.
epochs, batch_size = 10, 2

In [6]:
import gc
# Reset Keras global state and run a garbage-collection pass so objects left
# over from earlier executions of the notebook can be released.
# NOTE(review): this runs *after* set_global_policy("mixed_float16") in the
# previous cell. Depending on the Keras version, clear_session() may reset the
# global dtype policy — check tf.keras.mixed_precision.global_policy() after
# this cell, and consider running this cell before the policy is set.
tf.keras.backend.clear_session()
gc.collect()

0

### Dataset

In [7]:
# Kaggle dataset with paired "content/" (COCO) and "style/" (WikiArt) folders.
DATASET_HANDLE = "shaorrran/coco-wikiart-nst-dataset-512-100000"
path_to_dataset = kagglehub.dataset_download(DATASET_HANDLE)
path_to_dataset

'/home/oslyris/.cache/kagglehub/datasets/shaorrran/coco-wikiart-nst-dataset-512-100000/versions/1'

In [8]:
def load_image_dataset(subdirectory):
    """Load the images under ``path_to_dataset/<subdirectory>`` as an unlabeled dataset.

    Images are decoded as RGB, resized to ``IMAGE_SIZE x IMAGE_SIZE``, shuffled
    and grouped into batches of ``batch_size``. No labels are produced, so each
    dataset element is a batch of images only.

    Args:
        subdirectory: folder name inside the downloaded dataset root.

    Returns:
        A ``tf.data.Dataset`` of image batches.
    """
    return tf.keras.utils.image_dataset_from_directory(
        directory=os.path.join(path_to_dataset, subdirectory),
        labels=None,
        label_mode=None,
        color_mode="rgb",
        image_size=(IMAGE_SIZE, IMAGE_SIZE),
        shuffle=True,
        batch_size=batch_size,
    )


# The two folders are shuffled independently, so content/style pairings are random.
content_data = load_image_dataset("content")
style_data = load_image_dataset("style")

Found 49981 files.


I0000 00:00:1749280239.614107   24525 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 3539 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 4050 Laptop GPU, pci bus id: 0000:01:00.0, compute capability: 8.9


Found 49981 files.


In [9]:
# Shared layer that multiplies pixel values by 1/255 (image batches loaded by
# image_dataset_from_directory are in [0, 255], so this maps them to [0, 1]).
normalization_layer = Rescaling(1/255.)


def rescale_pair(content_batch, style_batch):
    """Rescale one (content, style) batch pair with the shared normalization layer."""
    return normalization_layer(content_batch), normalization_layer(style_batch)


# Combine content and style batches element-wise, rescale them with parallel
# map calls and prefetch so input preparation overlaps with consumption.
dataset = (
    tf.data.Dataset.zip((content_data, style_data))
    .map(rescale_pair, num_parallel_calls=tf.data.AUTOTUNE)
    .prefetch(buffer_size=tf.data.AUTOTUNE)
)

## Custom Layers

### Adaptive Instance Normalization

In [10]:
class AdaptiveInstanceNormalization(tf.keras.layers.Layer):
    """Adaptive Instance Normalization (AdaIN) of one content map against several style maps.

    Called on ``[content_features, style_features]``, where ``style_features``
    is a list of feature maps. For every style map, the content features are:

    1. bilinearly resized to that style map's height and width;
    2. projected with a 1x1 convolution when the channel counts differ;
    3. instance-normalized (mean/variance per sample and channel over H and W);
    4. re-scaled and shifted with the style map's per-channel std and mean.

    Args:
        epsilon: small constant added to the variances before the square root
            so that constant feature maps do not cause a division by zero.

    Returns (from ``call``):
        A list with one float32 tensor per style map. Each tensor has that style
        map's spatial size and channel count.
    """

    def __init__(self, epsilon=1e-5, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon

    def build(self, input_shape):
        content_shape, style_shapes = input_shape
        # One optional 1x1 projection per style map, aligning the content
        # channel count to that style map's; None means channels already match.
        self.channel_convs = []
        for shape in style_shapes:
            if shape[-1] != content_shape[-1]:
                self.channel_convs.append(
                    Conv2D(shape[-1], 1, padding='same', dtype='float32'))
            else:
                self.channel_convs.append(None)
        super().build(input_shape)

    def call(self, inputs):
        content_features, style_features = inputs
        # All statistics are computed in float32, whatever dtype arrives.
        content_features = tf.cast(content_features, tf.float32)

        stylized_features = []
        for style_feature, channel_conv in zip(style_features, self.channel_convs):
            style_feature = tf.cast(style_feature, tf.float32)
            target_size = tf.shape(style_feature)[1:3]

            resized_content = tf.image.resize(
                content_features, target_size, method='bilinear')
            if channel_conv is not None:
                resized_content = channel_conv(resized_content)

            content_mean, content_var = tf.nn.moments(
                resized_content, axes=[1, 2], keepdims=True)
            content_std = tf.sqrt(content_var + self.epsilon)
            normalized_content = (resized_content - content_mean) / content_std

            style_mean, style_var = tf.nn.moments(
                style_feature, axes=[1, 2], keepdims=True)
            style_std = tf.sqrt(style_var + self.epsilon)

            stylized_features.append(normalized_content * style_std + style_mean)

        return stylized_features

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created on load."""
        config = super().get_config()
        config.update({"epsilon": self.epsilon})
        return config

### Style and Content Feature Extractors

In [11]:
# Frozen ImageNet VGG19 backbone shared by the content and style extractors.
base = tf.keras.applications.VGG19(weights="imagenet", include_top=False)
base.trainable = False
content_output = [base.get_layer(layer).output for layer in content_layers]
style_output = [base.get_layer(layer).output for layer in style_layers]


def build_feature_extractor(layer_outputs):
    """Expose selected VGG19 activations for images scaled to [0, 1].

    The ImageNet VGG19 weights expect ``vgg19.preprocess_input`` inputs:
    0-255 range, BGR channel order, ImageNet-mean subtracted. The dataset
    pipeline rescales images by 1/255, so they are scaled back up and
    preprocessed here before reaching the backbone.

    Args:
        layer_outputs: list of symbolic outputs of ``base`` to return.

    Returns:
        A ``tf.keras.Model`` mapping a batch of RGB images in [0, 1] to those
        activations.
    """
    raw_features = tf.keras.Model(base.input, layer_outputs)
    images = tf.keras.Input(shape=(None, None, 3))
    vgg_ready = tf.keras.applications.vgg19.preprocess_input(images * 255.0)
    return tf.keras.Model(images, raw_features(vgg_ready))


content_extractor = build_feature_extractor(content_output)
style_extractor = build_feature_extractor(style_output)

In [12]:
# tf.keras.utils.plot_model(style_extractor, "style_extractor.png", show_shapes=True)

In [13]:
# tf.keras.utils.plot_model(content_extractor, "content_extractor.png", show_shapes=True)

### Adaptive Convolution Layer

In [14]:
class AdaptiveConv2D(tf.keras.layers.Layer):
    """Convolution whose kernel is predicted, per sample, from a style feature map.

    Called on ``[stylized_features, style_features]``, which must share their
    spatial dimensions.

    - A small conv/pool/dense encoder turns each sample's ``style_features``
      into a 128-d vector.
    - A dense layer maps that vector to a
      ``kernel_size x kernel_size x C_in x filters`` kernel.
    - Each sample's ``stylized_features`` are convolved with its own kernel
      (stride 1, SAME padding).

    Args:
        filters: number of output channels.
        kernel_size: side length of the predicted square kernel.

    Raises:
        ValueError: from ``build`` if the two inputs' spatial dimensions differ.
    """

    def __init__(self, filters, kernel_size=3, **kwargs):
        super().__init__(**kwargs)
        self.filters = filters
        self.kernel_size = kernel_size

    def build(self, input_shape):
        stylized_shape, style_shape = input_shape
        # Compare as tuples so the check behaves the same for tuple, list or
        # TensorShape shape objects.
        if tuple(stylized_shape[1:3]) != tuple(style_shape[1:3]):
            raise ValueError(
                "Content and style feature spatial dimensions must match")

        self.style_encoder = Sequential([
            Conv2D(64, 3, padding='same', activation='relu', dtype='float32'),
            Conv2D(32, 3, padding='same', activation='relu', dtype='float32'),
            tf.keras.layers.GlobalAveragePooling2D(),
            tf.keras.layers.Dense(128, activation='relu', dtype='float32')
        ])
        self.kernel_predictor = tf.keras.layers.Dense(
            self.filters * self.kernel_size * self.kernel_size * stylized_shape[-1],
            kernel_initializer=tf.keras.initializers.RandomNormal(
                mean=0., stddev=0.02),
            dtype='float32'
        )
        super().build(input_shape)

    def call(self, inputs):
        stylized_features, style_features = inputs
        in_channels = stylized_features.shape[-1]

        # One flattened kernel per sample: (batch, k * k * C_in * filters).
        style_vectors = self.style_encoder(style_features)
        flat_kernels = self.kernel_predictor(style_vectors)
        kernels = tf.reshape(
            flat_kernels,
            (-1, self.kernel_size, self.kernel_size, in_channels, self.filters))
        # tf.nn.conv2d requires input and filter dtypes to match; the predictor
        # is float32 while the features arrive in the layer's compute dtype.
        kernels = tf.cast(kernels, stylized_features.dtype)

        def convolve_sample(args):
            features, kernel = args
            return tf.nn.conv2d(
                features[tf.newaxis], kernel, strides=1, padding="SAME")[0]

        # Each sample is convolved with the kernel predicted from its own
        # paired style features, rather than sharing one kernel batch-wide.
        return tf.map_fn(
            convolve_sample, (stylized_features, kernels),
            fn_output_signature=stylized_features.dtype)

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created on load."""
        config = super().get_config()
        config.update({"filters": self.filters, "kernel_size": self.kernel_size})
        return config

In [15]:
# Smoke-test the feature pipeline on one batch:
# VGG features -> AdaIN -> one AdaptiveConv2D per style layer.
test = dataset.take(1)
content_images, style_images = next(iter(test))
content = content_extractor(content_images)
style = style_extractor(style_images)
normal = AdaptiveInstanceNormalization()([content, style])
# Pair each stylized map with the style map it was normalized against; zip keeps
# this in step with len(style_layers) instead of a hard-coded count.
adaConv = [AdaptiveConv2D(256)([stylized, style_feature])
           for stylized, style_feature in zip(normal, style)]

I0000 00:00:1749280243.179059   24525 cuda_dnn.cc:529] Loaded cuDNN version 90300
2025-06-07 07:10:44.074529: W external/local_xla/xla/tsl/framework/bfc_allocator.cc:306] Allocator (GPU_0_bfc) ran out of memory trying to allocate 4.13GiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if more memory were available.
2025-06-07 07:10:44.333067: W external/local_xla/xla/tsl/framework/bfc_allocator.cc:306] Allocator (GPU_0_bfc) ran out of memory trying to allocate 4.12GiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if more memory were available.
2025-06-07 07:10:44.567685: W external/local_xla/xla/tsl/framework/bfc_allocator.cc:306] Allocator (GPU_0_bfc) ran out of memory trying to allocate 4.17GiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if m

In [16]:
# Shapes at each stage of the pipeline; the style line reports the raw style
# extractor outputs, not the AdaIN results.
print(f"Content after Extractor: {content.shape}\n"
      f"Style after Extractor: {[f.shape for f in style]}\n"
      f"Stylized Features: {[f.shape for f in normal]}\n"
      f"Adaptive Conv2D Output: {[f.shape for f in adaConv]}")

Content after Extractor: (2, 16, 16, 512)
Style after Extractor: [TensorShape([2, 256, 256, 64]), TensorShape([2, 128, 128, 128]), TensorShape([2, 64, 64, 256]), TensorShape([2, 32, 32, 512])]
Stylized Features: [TensorShape([2, 256, 256, 64]), TensorShape([2, 128, 128, 128]), TensorShape([2, 64, 64, 256]), TensorShape([2, 32, 32, 512])]
Adaptive Conv2D Output: [TensorShape([2, 256, 256, 256]), TensorShape([2, 128, 128, 256]), TensorShape([2, 64, 64, 256]), TensorShape([2, 32, 32, 256])]


In [18]:
# Number of adaptive-conv outputs followed by the shape of each one.
(len(adaConv), *(feature.shape for feature in adaConv))

(4,
 TensorShape([2, 256, 256, 256]),
 TensorShape([2, 128, 128, 256]),
 TensorShape([2, 64, 64, 256]),
 TensorShape([2, 32, 32, 256]))

In [19]:
# Number of AdaIN outputs followed by the shape of each one.
(len(normal), *(feature.shape for feature in normal))

(4,
 TensorShape([2, 256, 256, 64]),
 TensorShape([2, 128, 128, 128]),
 TensorShape([2, 64, 64, 256]),
 TensorShape([2, 32, 32, 512]))

### Decoder