In [1]:
import os 
import tensorflow as tf
from tensorflow.keras.applications import vgg19, VGG19
from tensorflow.keras.layers import Conv2D, UpSampling2D

Archittettura della rete neurale per il trasferimento di stile come descritta nell'articolo

In [2]:
class ReflectionPadding2D(tf.keras.layers.Layer):
  
    def __init__(self, padding=1, **kwargs):
        super(ReflectionPadding2D, self).__init__(**kwargs)
        self.padding = padding

    def compute_output_shape(self, s):
        return s[0], s[1] + 2 * self.padding, s[2] + 2 * self.padding, s[3]

    def call(self, x):
        return tf.pad(
            x,
            [
                [0, 0],
                [self.padding, self.padding],
                [self.padding, self.padding],
                [0, 0],
            ],
            "REFLECT",
        )


class TransferNet(tf.keras.Model):
    def __init__(self, content_layer):
        super(TransferNet, self).__init__()
        self.encoder = Encoder(content_layer)
        self.decoder = decoder()

    def encode(self, content_image, style_image, alpha):
        content_feat = self.encoder(content_image)
        style_feat = self.encoder(style_image)

        t = adaptive_instance_normalization(content_feat, style_feat)
        t = alpha * t + (1 - alpha) * content_feat
        return t

    def decode(self, t):
        return self.decoder(t)

    def call(self, content_image, style_image, alpha=1.0):
        t = self.encode(content_image, style_image, alpha)
        g_t = self.decode(t)
        return g_t


def adaptive_instance_normalization(content_feat, style_feat, epsilon=1e-5):
    content_mean, content_variance = tf.nn.moments(
        content_feat, axes=[1, 2], keepdims=True
    )
    style_mean, style_variance = tf.nn.moments(
        style_feat, axes=[1, 2], keepdims=True
    )
    style_std = tf.math.sqrt(style_variance + epsilon)

    norm_content_feat = tf.nn.batch_normalization(
        content_feat,
        mean=content_mean,
        variance=content_variance,
        offset=style_mean,
        scale=style_std,
        variance_epsilon=epsilon,
    )
    return norm_content_feat


class Encoder(tf.keras.models.Model):
    def __init__(self, content_layer):
        super(Encoder, self).__init__()
        vgg = VGG19(include_top=False, weights="imagenet")

        self.vgg = tf.keras.Model(
            [vgg.input], [vgg.get_layer(content_layer).output]
        )
        self.vgg.trainable = False

    def call(self, inputs):
        preprocessed_input = vgg19.preprocess_input(inputs)
        return self.vgg(preprocessed_input)


def decoder():
    return tf.keras.Sequential(
        [
            ReflectionPadding2D(),
            Conv2D(256, (3, 3), activation="relu"),
            UpSampling2D(size=2),
            ReflectionPadding2D(),
            Conv2D(256, (3, 3), activation="relu"),
            ReflectionPadding2D(),
            Conv2D(256, (3, 3), activation="relu"),
            ReflectionPadding2D(),
            Conv2D(256, (3, 3), activation="relu"),
            ReflectionPadding2D(),
            Conv2D(128, (3, 3), activation="relu"),
            UpSampling2D(size=2),
            ReflectionPadding2D(),
            Conv2D(128, (3, 3), activation="relu"),
            ReflectionPadding2D(),
            Conv2D(64, (3, 3), activation="relu"),
            UpSampling2D(size=2),
            ReflectionPadding2D(),
            Conv2D(64, (3, 3), activation="relu"),
            ReflectionPadding2D(),
            Conv2D(3, (3, 3)),
        ]
    )


class VGG(tf.keras.models.Model):
    def __init__(self, content_layer, style_layers):
        super(VGG, self).__init__()
        vgg = VGG19(include_top=False, weights="imagenet")

        content_output = vgg.get_layer(content_layer).output
        style_outputs = [vgg.get_layer(name).output for name in style_layers]

        self.vgg = tf.keras.Model(
            [vgg.input], [content_output, style_outputs]
        )
        self.vgg.trainable = False

    def call(self, inputs):
        preprocessed_input = vgg19.preprocess_input(inputs)
        content_outputs, style_outputs = self.vgg(preprocessed_input)
        return content_outputs, style_outputs


Variabili di controllo

In [3]:
logDir = '/content/drive/MyDrive/UpdatedModel0'        # cartella in cui si trova il modello da addestrare (da zero o da un modello parzialmente addestrato)
lr =1e-4           #learning-rate
lrDecay = 5e-5     #learning rate Decay
imageSize= 256     # dimensione delle immagini utilizzate per il "random crop"
batchSize = 8
contentWeight = 1
styleWeight = 10
logFreq = 50 #500      # frequenza con cui vengono stampate le informazioni del training

In [4]:
content_layer = "block4_conv1"  # relu-4-1

style_layers = [
        "block1_conv1",  # relu1-1
        "block2_conv1",  # relu2-1
        "block3_conv1",  # relu3-1
        "block4_conv1",  # relu4-1
    ]

In [5]:
vgg = VGG(content_layer, style_layers)
transformer = TransferNet(content_layer)

Funzioni ausiliarie per modificare le immagini nel dataset.  
Le immagini verranno ridimensionate e per eseguire il training verra' scelta una regione casuale di dimensione casuale specificata dalla variabile imageSize.

In [6]:
def resize(img, min_size=512):
    """ Resize image and keep aspect ratio """
    width, height, _ = tf.unstack(tf.shape(img), num=3)
    if height < width:
        new_height = min_size
        new_width = int(width * new_height / height)
    else:
        new_width = min_size
        new_height = int(height * new_width / width)

    img = tf.image.resize(img, size=(new_width, new_height))
    return img

def resize_and_crop(img, min_size):
    img = resize(img, min_size=min_size)
    img = tf.image.random_crop(img, size=(imageSize, imageSize, 3))
    img = tf.cast(img, tf.float32)
    return img

def process_content(file_path):
    #img = features["image"]
    #img = resize_and_crop(img, min_size=286)
    img = tf.io.read_file(file_path)
    img = tf.image.decode_jpeg(img, channels=3)
    img = resize_and_crop(img, min_size=286)
    return img

def process_style(file_path):
    img = tf.io.read_file(file_path)
    img = tf.image.decode_jpeg(img, channels=3)
    img = resize_and_crop(img, min_size=512)
    return img

E' necessario dividere il dataset del contenuto e quello dello stile in più cartelle in quanto colab non riesce a leggere molti file in una sola volta.  
Successivamente si uniscono insieme tutti i dataset in uno unico.

In [7]:
AUTOTUNE = tf.data.experimental.AUTOTUNE

# 4 dataset per il contenuto
contentDatasetPath1 = '/content/drive/MyDrive/DatesetContenuto/train2014_20000/001'
contentDatasetPath2 = '/content/drive/MyDrive/DatesetContenuto/train2014_20000/002'
contentDatasetPath3 = '/content/drive/MyDrive/DatesetContenuto/train2014_20000/003'
contentDatasetPath4 = '/content/drive/MyDrive/DatesetContenuto/train2014_20000/004'

#4 dataset per lo stile

styleDatasetPath1 = '/content/drive/MyDrive/DatasetStile/train_20000/001'
styleDatasetPath2 = '/content/drive/MyDrive/DatasetStile/train_20000/002'
styleDatasetPath3 = '/content/drive/MyDrive/DatasetStile/train_20000/003'
styleDatasetPath4 = '/content/drive/MyDrive/DatasetStile/train_20000/004'



ds_coco = (
    tf.data.Dataset.list_files(os.path.join(contentDatasetPath1, "*.jpg"))
    .concatenate(tf.data.Dataset.list_files(os.path.join(contentDatasetPath2, "*.jpg")))
    .concatenate(tf.data.Dataset.list_files(os.path.join(contentDatasetPath3, "*.jpg")))
    .concatenate(tf.data.Dataset.list_files(os.path.join(contentDatasetPath4, "*.jpg")))
    .map(process_content, num_parallel_calls=AUTOTUNE)
    # Ignore too large or corrupt image files
    .apply(tf.data.experimental.ignore_errors()) 
    .repeat()
    .batch(batchSize)
    .prefetch(AUTOTUNE)
    
)

ds_pbn = (
    tf.data.Dataset.list_files(os.path.join(styleDatasetPath1, "*.jpg"))
    .concatenate(tf.data.Dataset.list_files(os.path.join(styleDatasetPath2, "*.jpg")))
    .concatenate(tf.data.Dataset.list_files(os.path.join(styleDatasetPath3, "*.jpg")))
    .concatenate(tf.data.Dataset.list_files(os.path.join(styleDatasetPath4, "*.jpg")))
    .map(process_style, num_parallel_calls=AUTOTUNE)
    # Ignore too large or corrupt image files
    .apply(tf.data.experimental.ignore_errors())
    .repeat()
    .batch(batchSize)
    .prefetch(AUTOTUNE)
    
)


ds = tf.data.Dataset.zip((ds_coco, ds_pbn))

In [8]:
optimizer = tf.keras.optimizers.Adam(learning_rate=lr)
print(optimizer.learning_rate)
ckpt = tf.train.Checkpoint(optimizer=optimizer, transformer=transformer)

manager = tf.train.CheckpointManager(ckpt, logDir, max_to_keep=1)
ckpt.restore(manager.latest_checkpoint).expect_partial()
print(optimizer.learning_rate)
if manager.latest_checkpoint:
    print(f"Restored from {manager.latest_checkpoint}")
else:
    print("Initializing from scratch.")

summary_writer = tf.summary.create_file_writer(logDir)

train_loss = tf.keras.metrics.Mean(name="train_loss")
train_style_loss = tf.keras.metrics.Mean(name="train_style_loss")
train_content_loss = tf.keras.metrics.Mean(name="train_content_loss")

<tf.Variable 'learning_rate:0' shape=() dtype=float32, numpy=1e-04>
<tf.Variable 'learning_rate:0' shape=() dtype=float32, numpy=9.975062e-05>
Restored from /content/drive/MyDrive/UpdatedModel0/ckpt-114


Funzioni di perdita 

In [9]:
def mean_std_loss(feat, feat_stylized, epsilon=1e-5):
    feat_mean, feat_variance = tf.nn.moments(feat, axes=[1, 2])
    feat_stylized_mean, feat_stylized_variance = tf.nn.moments(
        feat_stylized, axes=[1, 2]
    )
    feat_std = tf.math.sqrt(feat_variance + epsilon)
    feat_stylized_std = tf.math.sqrt(feat_stylized_variance + epsilon)

    loss = tf.losses.mse(feat_stylized_mean, feat_mean) + tf.losses.mse(
        feat_stylized_std, feat_std
    )
    return loss

def style_loss(feat, feat_stylized):
    return tf.reduce_sum(
        [
            mean_std_loss(f, f_stylized)
            for f, f_stylized in zip(feat, feat_stylized)
        ]
    )


def content_loss(feat, feat_stylized):
    return tf.reduce_mean(tf.square(feat - feat_stylized), axis=[1, 2, 3])

In [10]:
@tf.function
def train_step(content_img, style_img):
    t = transformer.encode(content_img, style_img, alpha=1.0)

    with tf.GradientTape() as tape:
        stylized_img = transformer.decode(t)

        _, style_feat_style = vgg(style_img)
        content_feat_stylized, style_feat_stylized = vgg(stylized_img)

        tot_style_loss = styleWeight * style_loss(style_feat_style, style_feat_stylized)
        tot_content_loss = contentWeight * content_loss(t, content_feat_stylized)
        loss = tot_style_loss + tot_content_loss

    gradients = tape.gradient(loss, transformer.trainable_variables)
    optimizer.apply_gradients(zip(gradients, transformer.trainable_variables))

    train_loss(loss)
    train_style_loss(tot_style_loss)
    train_content_loss(tot_content_loss)

In [None]:
for step, (content_images, style_images) in enumerate(ds):
    new_lr = lr / (1.0 + lrDecay * step )
    optimizer.learning_rate.assign(new_lr)

    train_step(content_images, style_images)

    if step % logFreq == 0:
        with summary_writer.as_default():
            tf.summary.scalar("loss/total", train_loss.result(), step=step)
            tf.summary.scalar("loss/style", train_style_loss.result(), step=step)
            tf.summary.scalar("loss/content", train_content_loss.result(), step=step)
            print(
                f"Step {step}, "
                f"Loss: {train_loss.result()}, "
                f"Style Loss: {train_style_loss.result()}, "
                f"Content Loss: {train_content_loss.result()}"
            )
            print(f"Saved checkpoint: {manager.save()}")

            train_loss.reset_states()
            train_style_loss.reset_states()
            train_content_loss.reset_states()

Step 0, Loss: 3648751.25, Style Loss: 3191310.25, Content Loss: 457441.03125
Saved checkpoint: /content/drive/MyDrive/UpdatedModel0/ckpt-115
Step 50, Loss: 3552482.0, Style Loss: 3100457.5, Content Loss: 452024.6875
Saved checkpoint: /content/drive/MyDrive/UpdatedModel0/ckpt-116
Step 100, Loss: 3159343.0, Style Loss: 2724058.5, Content Loss: 435284.3125
Saved checkpoint: /content/drive/MyDrive/UpdatedModel0/ckpt-117
Step 150, Loss: 3525057.5, Style Loss: 3066285.0, Content Loss: 458771.96875
Saved checkpoint: /content/drive/MyDrive/UpdatedModel0/ckpt-118
Step 200, Loss: 3633692.25, Style Loss: 3174478.5, Content Loss: 459213.84375
Saved checkpoint: /content/drive/MyDrive/UpdatedModel0/ckpt-119
Step 250, Loss: 3200373.75, Style Loss: 2753378.25, Content Loss: 446995.46875
Saved checkpoint: /content/drive/MyDrive/UpdatedModel0/ckpt-120
