In [None]:
import tensorflow as tf
from tensorflow.keras.applications import vgg19, VGG19
from tensorflow.keras.layers import Conv2D, UpSampling2D
from PIL import Image

Funzioni di supporto per manipolare le immagini:


1.   process_path: converte l'immagine in un tensore
2.   load_img: aggiunge la dimensione per il batch



In [None]:
def process_path(file_path):
    """Read an image file and return it as a float32 RGB tensor.

    The format is detected from the file contents, so the same helper can be
    used for both the JPEG and the PNG inputs this notebook loads.

    Args:
        file_path: Path of the image file to read.

    Returns:
        A float32 tensor with 3 channels; pixel values are left unscaled
        (the uint8 values are only cast, not normalized).
    """
    raw_bytes = tf.io.read_file(file_path)
    # expand_animations=False guarantees a single 3-D image even for GIF input.
    img = tf.io.decode_image(raw_bytes, channels=3, expand_animations=False)
    return tf.cast(img, tf.float32)


def load_img(file_path):
    """Load an image through `process_path` and prepend a batch dimension.

    Args:
        file_path: Path of the image file to read.

    Returns:
        The tensor produced by `process_path` with a new leading axis of size 1.
    """
    # A new leading axis turns the single image into a batch of one.
    return process_path(file_path)[tf.newaxis, ...]

In [None]:
# --- Inference settings ---
# Content/style trade-off: the stylized features are blended with the content
# features as alpha * stylized + (1 - alpha) * content.
alpha = 1.0
# Directory the model checkpoint is read from.
logDir = 'model'
# File the stylized result is written to.
outputImage = "risultato.jpg"

# --- Input images (each loaded as a batch of one) ---
contentImage = load_img("/content/sailboat_cropped.jpg")
styleImage = load_img("/content/sketch_cropped.png")

# Architettura del modello
Nelle celle seguenti viene creata la struttura **Encode -> AdaIN -> Decode** impiegata per effettuare il trasferimento di stile.

In [None]:
class Encoder(tf.keras.models.Model):
    """Frozen VGG19 feature extractor truncated at a chosen layer.

    Args:
        content_layer: Name of the VGG19 layer whose output is returned as
            the feature map.
    """

    def __init__(self, content_layer):
        super().__init__()
        backbone = VGG19(include_top=False, weights="imagenet")
        feature_map = backbone.get_layer(content_layer).output

        # Sub-model from the VGG19 input up to the requested layer; it is
        # marked non-trainable so the ImageNet weights stay fixed.
        self.vgg = tf.keras.Model([backbone.input], [feature_map])
        self.vgg.trainable = False

    def call(self, inputs):
        """Apply VGG19 preprocessing to `inputs` and return the truncated network's output."""
        return self.vgg(vgg19.preprocess_input(inputs))

# As suggested by the paper, the decoder uses reflection padding so that the
# borders of the generated image mirror the content image.

class ReflectionPadding2D(tf.keras.layers.Layer):
    """Pads the two middle axes of a 4-D tensor by reflection.

    Args:
        padding: Number of elements added on each side of axes 1 and 2.
        **kwargs: Forwarded to `tf.keras.layers.Layer`.
    """

    def __init__(self, padding=1, **kwargs):
        super().__init__(**kwargs)
        self.padding = padding

    def compute_output_shape(self, s):
        """Return the padded shape for a 4-D input shape `s`.

        Unknown (`None`) dimensions stay unknown instead of raising a
        `TypeError` when the padding is added to them.
        """

        def _grow(dim):
            return None if dim is None else dim + 2 * self.padding

        return s[0], _grow(s[1]), _grow(s[2]), s[3]

    def call(self, x):
        """Reflect-pad axes 1 and 2 of `x`; axes 0 and 3 are left untouched."""
        pad = [self.padding, self.padding]
        return tf.pad(x, [[0, 0], pad, pad, [0, 0]], "REFLECT")

    def get_config(self):
        """Include `padding` in the config so the layer can be serialized and re-created."""
        config = super().get_config()
        config.update({"padding": self.padding})
        return config


def decoder():
    """Build the decoder as a `tf.keras.Sequential` stack.

    The stack is a series of (reflection padding -> 3x3 ReLU convolution)
    stages, three of which are followed by a 2x upsampling, and ends with a
    padded 3x3 convolution with 3 filters and no activation.

    Returns:
        A new, unbuilt `tf.keras.Sequential` model.
    """
    # (number of filters, whether a 2x upsampling follows the convolution)
    stage_plan = [
        (256, True),
        (256, False),
        (256, False),
        (256, False),
        (128, True),
        (128, False),
        (64, True),
        (64, False),
    ]

    stack = []
    for filters, upsample_after in stage_plan:
        stack.append(ReflectionPadding2D())
        stack.append(Conv2D(filters, (3, 3), activation="relu"))
        if upsample_after:
            stack.append(UpSampling2D(size=2))

    # Output stage: 3 filters, no activation.
    stack.append(ReflectionPadding2D())
    stack.append(Conv2D(3, (3, 3)))

    return tf.keras.Sequential(stack)

# Main model wrapping the encode -> AdaIN -> decode pipeline.

class TransferNet(tf.keras.Model):
    """Style-transfer network: encoder, adaptive instance normalization, decoder.

    Args:
        content_layer: Name of the VGG19 layer the encoder is truncated at.
    """

    def __init__(self, content_layer):
        super().__init__()
        self.encoder = Encoder(content_layer)
        self.decoder = decoder()

    def encode(self, content_image, style_image, alpha):
        """Return the content features re-normalized with the style statistics.

        `alpha` blends the AdaIN output with the plain content features:
        1 keeps only the AdaIN output, 0 keeps only the content features.
        """
        content_feat = self.encoder(content_image)
        style_feat = self.encoder(style_image)

        stylized_feat = adaptive_instance_normalization(content_feat, style_feat)
        return alpha * stylized_feat + (1 - alpha) * content_feat

    def decode(self, t):
        """Run the decoder on the feature tensor `t`."""
        return self.decoder(t)

    def call(self, content_image, style_image, alpha=1.0):
        """Encode both images, blend them according to `alpha`, and decode the result."""
        return self.decode(self.encode(content_image, style_image, alpha))


def adaptive_instance_normalization(content_feat, style_feat, epsilon=1e-5):
    """Re-normalize the content features to the style features' statistics.

    Mean and variance are computed over axes 1 and 2 of each input (the
    spatial axes, assuming channels-last feature maps). The content features
    are normalized with their own moments, then scaled by the style standard
    deviation and shifted by the style mean.

    Args:
        content_feat: Feature tensor to be normalized.
        style_feat: Feature tensor providing the target mean and std.
        epsilon: Small constant added to the variances for numerical stability.

    Returns:
        A tensor with the shape of `content_feat`.
    """
    reduce_axes = [1, 2]
    s_mean, s_var = tf.nn.moments(style_feat, axes=reduce_axes, keepdims=True)
    c_mean, c_var = tf.nn.moments(content_feat, axes=reduce_axes, keepdims=True)
    s_std = tf.sqrt(s_var + epsilon)

    # batch_normalization performs (x - mean) / sqrt(var + eps) * scale + offset.
    return tf.nn.batch_normalization(
        content_feat, c_mean, c_var, s_mean, s_std, epsilon
    )



L'encoder contiene i livelli della rete VGG19 fino a "block4_conv1", come specificato dall'articolo. Per i pesi dell'encoder si sono utilizzati quelli pre-addestrati su ImageNet già disponibili in Keras. I pesi che vengono caricati dal checkpoint riguardano solo il decoder.

In [None]:
from google.colab import drive
import os

# Mount Google Drive
drive.mount('/gdrive', force_remount=True)

# Location of Zip File
drive_path = '/gdrive/MyDrive/arbitrary-style-transfer-master/arbitrary-style-transfer-master/model'
local_path = '/content'

# Copy the zip file and move it up one level (AKA out of the drive folder)
!cp -r '{drive_path}' .

# Navigate to the copied file and unzip it quietly
os.chdir(local_path)

In [None]:
content_layer = "block4_conv1"
transformer = TransferNet(content_layer)

# Restore the saved weights. Fail loudly when no checkpoint is found:
# otherwise the decoder would silently run with untrained weights.
ckpt = tf.train.Checkpoint(transformer=transformer)
latest_ckpt = tf.train.latest_checkpoint(logDir)
if latest_ckpt is None:
    raise FileNotFoundError(f"No checkpoint found in '{logDir}'")
ckpt.restore(latest_ckpt).expect_partial()

stylized_batch = transformer(contentImage, styleImage, alpha=alpha)

# Drop only the batch axis (a bare squeeze would also drop any other size-1 axis).
stylized_image = tf.squeeze(stylized_batch, axis=0)
# The decoder's last convolution has no activation, so values can fall outside
# [0, 255]; clip before the uint8 cast to avoid wrap-around artifacts.
stylized_image = tf.clip_by_value(stylized_image, 0.0, 255.0)
stylized_image = tf.cast(stylized_image, tf.uint8).numpy()

img = Image.fromarray(stylized_image, mode="RGB")
img.save(outputImage)