In [1]:
import tensorflow as tf
import tensorflow.keras.layers as tfl
import tensorflow.keras.backend as K
from tensorflow.keras import Model

import matplotlib.pyplot as plt
from PIL import Image
import glob
import numpy as np

# Autoencoder transfer style
Ce notebook est une implémentation de l'architecture décrite par A. Sanakoyeu, D. Kotovenko, S. Lang et B. Ommer dans *A Style-Aware Content Loss for Real-time HD Style Transfer* (2018).

<p align="center">
    <img src='../latex/images/autoencoder.png'  />
</p>

In [66]:
# Shape of one image fed to the networks, batch dimension excluded
# (presumably height, width, channels -- confirm against the data pipeline).
input_size = (128, 128, 3)

# Symbolic entry points: one for content images, one for style images.
input_content = tfl.Input(shape=input_size)
input_style = tfl.Input(shape=input_size)

## Encodeur

In [67]:
def _encoder_stage(inputs, filters, strides):
    """Apply one Conv2D -> BatchNormalization -> LeakyReLU(0.2) stage.

    Args:
        inputs: Tensor fed to the 3x3 "same"-padded convolution.
        filters: Number of convolution filters.
        strides: Stride of the convolution.

    Returns:
        Tuple ``(conv, norm, act)`` holding the output of each of the three layers.
    """
    conv = tfl.Conv2D(filters, (3, 3), strides=strides, padding="same")(inputs)
    norm = tfl.BatchNormalization()(conv)
    act = tfl.LeakyReLU(alpha=0.2)(norm)
    return conv, norm, act


# Five stages chained on the content input. Only the first keeps the resolution;
# the four stride-2 stages each halve height and width.
c1, n1, a1 = _encoder_stage(input_content, filters=32, strides=1)
c2, n2, a2 = _encoder_stage(a1, filters=64, strides=2)
c3, n3, a3 = _encoder_stage(a2, filters=128, strides=2)
c4, n4, a4 = _encoder_stage(a3, filters=512, strides=2)
c5, n5, a5 = _encoder_stage(a4, filters=64, strides=2)

## Espace latent

In [68]:
# Size of the latent code produced by the encoder.
latent_dim = 16

# Flatten the last encoder activation, then project it linearly (no activation)
# onto the latent space.
flat = tfl.Flatten()(a5)
latent = tfl.Dense(units=latent_dim)(flat)

In [69]:
# Encoder model: content image in, latent code out.
encoder = Model(inputs=input_content, outputs=latent)

In [77]:
# Print the encoder's layer-by-layer summary (output shapes and parameter counts).
encoder.summary()

Model: "functional_19"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_10 (InputLayer)        [(None, 128, 128, 3)]     0         
_________________________________________________________________
conv2d_71 (Conv2D)           (None, 128, 128, 32)      896       
_________________________________________________________________
batch_normalization_30 (Batc (None, 128, 128, 32)      128       
_________________________________________________________________
leaky_re_lu_30 (LeakyReLU)   (None, 128, 128, 32)      0         
_________________________________________________________________
conv2d_72 (Conv2D)           (None, 64, 64, 64)        18496     
_________________________________________________________________
batch_normalization_31 (Batc (None, 64, 64, 64)        256       
_________________________________________________________________
leaky_re_lu_31 (LeakyReLU)   (None, 64, 64, 64)      

## Décodeur

In [78]:
# Decoder entry point: one latent vector per sample. `Input` documents `shape` as a
# tuple, so the latent size is wrapped in a 1-tuple instead of being passed as a bare int.
decod_input = tfl.Input(shape=(latent_dim,))

# Static shapes on either side of the encoder's Flatten layer (batch dimension first);
# they are reused here to undo the flattening.
shape_before_flat = K.int_shape(a5)
shape_after_flat = K.int_shape(flat)

# Expand the latent vector back to the flattened feature size...
d1 = tfl.Dense(units=shape_after_flat[1])(decod_input)

# ...and restore the feature-map layout the encoder had before flattening
# (every dimension except the batch one).
reshape = tfl.Reshape(target_shape=tuple(shape_before_flat[1:]))(d1)

In [79]:
# Residual block, see He et al., CVPR 2016:
# https://www.cv-foundation.org/openaccess/content_cvpr_2016/papers/He_Deep_Residual_Learning_CVPR_2016_paper.pdf
def residual_layer(input_layer, filters=64):
    """Build a two-convolution residual block with an identity shortcut.

    Args:
        input_layer: Symbolic Keras tensor the block is applied to.
        filters: Number of filters of both 3x3 convolutions. It has to be
            compatible with the channel count of ``input_layer`` for the
            shortcut addition to be valid. Defaults to 64, the value that was
            previously hard-coded.

    Returns:
        The tensor ``ReLU(conv2(relu(conv1(input_layer))) + input_layer)``.
    """
    res1 = tfl.Conv2D(filters, (3, 3), strides=1, padding="same", activation="relu")(input_layer)
    res2 = tfl.Conv2D(filters, (3, 3), strides=1, padding="same")(res1)
    # Use the Keras merge layer rather than a raw TensorFlow op: the addition then
    # stays a regular layer of the functional graph instead of relying on TF ops
    # being accepted on symbolic Keras tensors.
    shortcut_sum = tfl.Add()([res2, input_layer])
    return tfl.ReLU()(shortcut_sum)

In [80]:
# Stack nine residual blocks on top of the reshaped features. Each block's output is
# kept under its original name (r1 ... r9); r9 feeds the upsampling stages.
_residual_outputs = []
_features = reshape
for _block_index in range(9):
    _features = residual_layer(_features)
    _residual_outputs.append(_features)

r1, r2, r3, r4, r5, r6, r7, r8, r9 = _residual_outputs

In [81]:
def upsampling_block(input_layer, filters):
    """Double the spatial resolution, then mix channels with a 1x1 convolution.

    Args:
        input_layer: 4-D symbolic Keras tensor to upsample (assumed channels-last,
            the Keras default -- confirm if a different data format is configured).
        filters: Number of output channels of the 1x1 convolution.

    Returns:
        Tensor whose height and width are twice those of ``input_layer`` and whose
        last dimension is ``filters``.
    """
    # UpSampling2D scales height and width independently by 2 with nearest-neighbour
    # interpolation. The previous hand-computed target size reused the height for the
    # width, which is only right for square feature maps, and it could not handle
    # spatial dimensions that are unknown (None) at graph-construction time.
    upsampled = tfl.UpSampling2D(size=(2, 2), interpolation="nearest")(input_layer)
    return tfl.Conv2D(filters, (1, 1), strides=1, padding="same")(upsampled)

In [82]:
# Four successive x2 upsampling stages on the last residual output. The first three
# keep 64 channels; the last one produces 3 channels.
_upsampling_outputs = []
_upsampled = r9
for _stage_filters in (64, 64, 64, 3):
    _upsampled = upsampling_block(_upsampled, _stage_filters)
    _upsampling_outputs.append(_upsampled)

up1, up2, up3, up4 = _upsampling_outputs

In [83]:
# Decoder model: latent code in, output of the last upsampling stage out.
decoder = Model(inputs=decod_input, outputs=up4)

In [84]:
# Print the decoder's layer-by-layer summary (output shapes, parameter counts, connections).
decoder.summary()

Model: "functional_22"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_14 (InputLayer)           [(None, 16)]         0                                            
__________________________________________________________________________________________________
dense_12 (Dense)                (None, 4096)         69632       input_14[0][0]                   
__________________________________________________________________________________________________
reshape_5 (Reshape)             (None, 8, 8, 64)     0           dense_12[0][0]                   
__________________________________________________________________________________________________
conv2d_98 (Conv2D)              (None, 8, 8, 64)     36928       reshape_5[0][0]                  
______________________________________________________________________________________

## Discriminator

In [86]:
def _discriminator_conv(inputs, activation):
    """Apply one 64-filter, 3x3, stride-1, "same"-padded convolution.

    Args:
        inputs: Tensor fed to the convolution.
        activation: Name of the activation applied by the convolution layer.

    Returns:
        The output tensor of the convolution.
    """
    return tfl.Conv2D(64, (3, 3), strides=1, padding="same", activation=activation)(inputs)


# The discriminator takes an image of the same shape as the encoder input.
input_discriminator = tfl.Input(shape=input_size)

# Six ReLU convolutions at full resolution...
conv1 = _discriminator_conv(input_discriminator, "relu")
conv2 = _discriminator_conv(conv1, "relu")
conv3 = _discriminator_conv(conv2, "relu")
conv4 = _discriminator_conv(conv3, "relu")
conv5 = _discriminator_conv(conv4, "relu")
conv6 = _discriminator_conv(conv5, "relu")

# ...followed by a sigmoid convolution.
# NOTE(review): the output keeps 64 sigmoid channels per pixel rather than a single
# real/fake score -- confirm this is intended.
conv7 = _discriminator_conv(conv6, "sigmoid")

In [87]:
# Discriminator model: image in, output of the final sigmoid convolution out.
discriminator = Model(inputs=input_discriminator, outputs=conv7)

In [88]:
# Print the discriminator's layer-by-layer summary (output shapes and parameter counts).
discriminator.summary()

Model: "functional_24"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_15 (InputLayer)        [(None, 128, 128, 3)]     0         
_________________________________________________________________
conv2d_120 (Conv2D)          (None, 128, 128, 64)      1792      
_________________________________________________________________
conv2d_121 (Conv2D)          (None, 128, 128, 64)      36928     
_________________________________________________________________
conv2d_122 (Conv2D)          (None, 128, 128, 64)      36928     
_________________________________________________________________
conv2d_123 (Conv2D)          (None, 128, 128, 64)      36928     
_________________________________________________________________
conv2d_124 (Conv2D)          (None, 128, 128, 64)      36928     
_________________________________________________________________
conv2d_125 (Conv2D)          (None, 128, 128, 64)    

## Les fonctions de coût
### Style-Aware Content Loss
$$
\mathcal{L}_c(E, G)=\underset{x\sim p_{X}(x)}{\mathbb{E}}\left[\frac{1}{d}|| E(x) - E(G(E(x))) ||^2_2\right]
$$
Elle représente la distance euclidienne normalisée entre l'encodage de l'image d'exemple $x$ et l'encodage de l'image reconstruite $G(E(x))$.

In [90]:
def style_aware_content_loss():
    """Placeholder for the style-aware content loss; not implemented yet.

    Returns:
        None. The function takes no input and computes nothing so far.
    """
    # TODO: implement the loss; until then callers receive None.
    return None

### Transformed Image Loss
$$
\mathcal{L}_T(E,G) = \underset{x\sim p_{X}(x)}{\mathbb{E}}\left[\frac{1}{CHW}|| T(x) - T(G(E(x))) ||^2_2\right]
$$
Elle représente la différence entre l'image après une transformation T et l'image reconstruite après la même transformation.

In [92]:
def transformed_image_loss():
    """Placeholder for the transformed image loss; not implemented yet.

    Returns:
        None. The function takes no input and computes nothing so far.
    """
    # TODO: implement the loss; until then callers receive None.
    return None

### Discriminator Loss
Standard Adversarial Discriminator

$$
\mathcal{L}_D(E, G, D) = \underset{y\sim p_{Y}(y)}{\mathbb{E}}\left[\log D(y)\right] + \underset{x\sim p_{X}(x)}{\mathbb{E}}\left[\log (1 - D(G(E(x))))\right]
$$

In [93]:
def discriminator_loss():
    """Placeholder for the adversarial discriminator loss; not implemented yet.

    Returns:
        None. The function takes no input and computes nothing so far.
    """
    # TODO: implement the loss; until then callers receive None.
    return None

In [94]:
def network_loss():
    """Return the total objective: the sum of the three individual losses.

    The terms are evaluated and added in the order discriminator loss,
    transformed image loss, style-aware content loss.

    Returns:
        Whatever ``+`` yields on the three loss values.
    """
    adversarial_term = discriminator_loss()
    transformed_term = transformed_image_loss()
    content_term = style_aware_content_loss()
    return adversarial_term + transformed_term + content_term

## Entraînement

In [96]:
# Single Adam optimizer for the whole training loop; initial learning rate 2e-4.
optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4)

In [None]:
epochs = 300000

# The encoder, decoder and discriminator are updated jointly from one loss, so their
# trainable weights are gathered once. No single `model` object exists in this
# notebook, which is what the loop referred to before.
trainable_weights = (
    encoder.trainable_weights
    + decoder.trainable_weights
    + discriminator.trainable_weights
)

for epoch in range(epochs):
    # NOTE(review): `train` is not defined in the visible part of the notebook. It is
    # presumably an iterable of (style batch, content batch) pairs: the first element
    # was previously bound to an unused name while `style` was read without ever being
    # defined -- confirm the order against the data pipeline.
    for style, content in train:
        with tf.GradientTape() as tape:
            # Content image -> latent code -> reconstructed image -> re-encoded latent code.
            out_encoder = encoder(content, training=True)
            out_decoder = decoder(out_encoder, training=True)
            encode_decoded = encoder(out_decoder, training=True)

            # Discriminator responses on the content image, the reconstruction and
            # the style image.
            discr1 = discriminator(content, training=True)
            discr2 = discriminator(out_decoder, training=True)
            discr3 = discriminator(style, training=True)

            # NOTE(review): the loss functions take no argument, so none of the
            # tensors computed above reaches them yet.
            discr_loss = discriminator_loss()
            transformed_loss = transformed_image_loss()
            style_aware_loss = style_aware_content_loss()

            loss_value = discr_loss + transformed_loss + style_aware_loss

        grads = tape.gradient(loss_value, trainable_weights)
        optimizer.apply_gradients(zip(grads, trainable_weights))

    # Divide the learning rate by ten at the end of epoch 200000.
    if epoch == 200000:
        optimizer.learning_rate.assign(0.00002)