Neural Style Transfer

# Importing All Libraries

In [47]:
import tensorflow as tf
import numpy as np
import PIL.Image
import matplotlib.pyplot as plt
import time
from tensorflow.keras import mixed_precision

# GPU and Mixed Precision Setup

In [None]:
# Enable mixed precision for improved performance and reduced memory usage (optional)
mixed_precision.set_global_policy('mixed_float16')


# Configure GPU memory growth to avoid pre-allocating all VRAM

In [49]:
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        print(e)

# 1. Configuration

In [50]:
CONTENT_WEIGHT = 1e4
STYLE_WEIGHT = 1e-2
TV_WEIGHT = 30
STEPS = 1000
LEARNING_RATE = 0.02
MAX_DIM = 384

# 2. Image Handling Utilities

In [51]:
def load_img(path, max_dim=MAX_DIM):
    img = PIL.Image.open(path).convert('RGB')
    img.thumbnail((max_dim, max_dim))
    img = np.array(img)
    img = tf.keras.applications.vgg19.preprocess_input(img)
    return tf.expand_dims(tf.convert_to_tensor(img, dtype=tf.float16), 0)

def deprocess_img(processed_img):
    img = processed_img.numpy().squeeze()
    img += [103.939, 116.779, 123.68]
    img = img[:, :, ::-1]
    return np.clip(img, 0, 255).astype('uint8')

# 3. Model Setup

In [52]:
def get_model():
    vgg = tf.keras.applications.VGG19(include_top=False, weights='imagenet')
    vgg.trainable = False
    content_layer = 'block5_conv2'
    style_layers = [f'block{i}_conv1' for i in range(1, 6)]
    outputs = [vgg.get_layer(content_layer).output] + [vgg.get_layer(layer).output for layer in style_layers]
    return tf.keras.Model(vgg.input, outputs)

def gram_matrix(tensor):
    """Compute the Gram matrix."""
    channels = int(tensor.shape[-1])
    a = tf.reshape(tensor, [-1, channels])
    return tf.matmul(a, a, transpose_a=True) / tf.cast(tf.shape(a)[0], tf.float16)


# 4. Loss Functions

In [53]:
def content_loss(content, generated):
    return tf.reduce_mean(tf.square(tf.cast(content, tf.float16) - tf.cast(generated, tf.float16)))

def style_loss(style, generated):
    return tf.reduce_mean(tf.square(tf.cast(style, tf.float16) - tf.cast(generated, tf.float16)))

def total_variation_loss(image):
    x_diff = image[:, 1:, :, :] - image[:, :-1, :, :]
    y_diff = image[:, :, 1:, :] - image[:, :, :-1, :]
    return tf.reduce_sum(tf.abs(x_diff)) + tf.reduce_sum(tf.abs(y_diff))

# 5. Load Images and Initialize Generated Image

In [54]:
content_image = load_img('content.jpg')
style_image = load_img('style.jpg')
# Ensure generated image is float16
generated_image = tf.Variable(content_image, dtype=tf.float16)

# Get model
model = get_model()

# 6. Extract Feature Targets from the Model

In [55]:
content_target = tf.cast(model(content_image)[0], tf.float16) 
style_targets = [tf.cast(gram_matrix(style_output), tf.float16) for style_output in model(style_image)[1:]]

# 7. Optimizer Setup

In [56]:
optimizer = tf.optimizers.Adam(learning_rate=LEARNING_RATE)

# 8. Training Step (compiled with tf.function for performance)

In [57]:
@tf.function
def train_step(generated_image, content_target, style_targets):
    with tf.GradientTape() as tape:
        outputs = model(tf.cast(generated_image, tf.float16))
        generated_content = tf.cast(outputs[0], tf.float16)
        generated_styles = [tf.cast(gram_matrix(style_output), tf.float16) for style_output in outputs[1:]]  # Cast

        # Compute losses
        c_loss = content_loss(content_target, generated_content)
        s_loss = tf.add_n([style_loss(style_target, gen_style) for style_target, gen_style in zip(style_targets, generated_styles)]) / len(style_targets)
        tv_loss = total_variation_loss(generated_image)

        total_loss = CONTENT_WEIGHT * c_loss + STYLE_WEIGHT * s_loss + TV_WEIGHT * tv_loss

    gradients = tape.gradient(total_loss, generated_image)
    optimizer.apply_gradients([(gradients, generated_image)])
    return total_loss, c_loss, s_loss, tv_loss

# 9. Training Loop

In [None]:
start_time = time.time()
for step in range(STEPS):
    total_loss, c_loss, s_loss, tv_loss = train_step(generated_image, content_target, style_targets)
    if step % 100 == 0:
        print(f"Step {step}: Total Loss={total_loss:.2e}, Content Loss={c_loss:.2e}, Style Loss={s_loss:.2e}, TV Loss={tv_loss:.2e}")
print(f"Total time: {time.time()-start_time:.2f} seconds")

Step 0: Total Loss=nan, Content Loss=0.00e+00, Style Loss=nan, TV Loss=inf


# 10. Display and Save the Result

In [None]:
result = deprocess_img(generated_image)
plt.imshow(result)
plt.axis('off')
plt.show()