# Synthèse de textures par réseau convolutif (filtres aléatoires)

## Mise en Place

In [None]:
import os
from urllib.request import urlopen
from io import BytesIO

from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import trange, tqdm

import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms.functional as TF

import utils

In [None]:
# Fetch the example texture images into the working directory.
texture_imgnames = ["bois.png", "briques.png", "mur.png",
                    "tissu.png", "nuages.png", "pebbles.jpg", "wall1003.png"]
texture_base_url = "https://www.idpoisson.fr/galerne/mva/"
for fname in texture_imgnames:
    # Files that are already present are not downloaded again.
    if os.path.exists(fname):
        continue
    try:
        # Read the whole payload before opening the output file so that a
        # failed transfer does not leave a truncated image on disk.
        with urlopen(texture_base_url + fname) as response:
            payload = response.read()
    except OSError as err:  # urllib's URLError/HTTPError derive from OSError
        # Report the failure and keep going with the remaining images.
        print("Could not download", fname, "-", err)
        continue
    with open(fname, "wb") as out_file:
        out_file.write(payload)

In [None]:
# Select the compute device: the GPU when CUDA is usable, the CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Device is", device)

## ranVGG

### VGG

In [None]:
#TODO description de VGG

Dans un premier temps, nous chargeons le modèle VGG pré-entraîné afin d'évaluer notre implémentation sur un réseau dont les poids sont déjà optimisés:

In [None]:
# Pre-trained VGG-19: keep only its `.features` sub-module, move it to
# `device` and switch it to evaluation mode.
# NOTE(review): recent torchvision releases appear to deprecate `pretrained=`
# in favour of `weights=` -- confirm against the installed version.
cnn = models.vgg19(pretrained=True).features.to(device).eval()

### Reconstruction d'images

In [None]:
#TODO description théorique  de la content loss

In [None]:
def content_loss(cnn: nn.Module, layer_output_: "list[torch.Tensor]", target_out: torch.Tensor, synthetic_image: "torch.Tensor", weighing_factor: "float | None" = None):
    """Compute the content loss between a synthetic image and a target activation.

    ``synthetic_image`` is forward-propagated through ``cnn``; the activation
    then found in ``layer_output_[0]`` is compared to ``target_out`` with a
    mean squared error. The function only reads ``layer_output_``: whatever
    fills it during the forward pass (forward hooks, per the surrounding
    code) must be registered by the caller beforehand and removed afterwards.

    The function performs no optimizer or gradient bookkeeping, so it can be
    called from inside an optimizer closure
    (see https://pytorch.org/docs/stable/optim.html).

    Args:
        cnn: callable network; invoked only for its side effect of refreshing
            ``layer_output_``.
        layer_output_: mutable list whose first element holds the activation
            of the target layer after a forward pass.
        target_out: reference activation; must have a shape compatible with
            ``layer_output_[0]`` for ``F.mse_loss``.
        synthetic_image: image being optimised, in whatever layout ``cnn``
            accepts (batched or not).
        weighing_factor: optional multiplier applied to the loss.

    Returns:
        A scalar tensor holding the (optionally weighted) MSE.
    """
    # Step 1: forward-propagate; the result is read back from the mutable
    # list rather than from the return value of the network.
    cnn(synthetic_image)
    synth_out = layer_output_[0]

    # Step 2: compute the loss. No assumption is made on the rank of the
    # activation, so batched and unbatched activations are both accepted.
    loss = F.mse_loss(synth_out, target_out)

    if weighing_factor is not None:
        # Out-of-place product: avoids mutating a tensor tracked by autograd.
        loss = loss * weighing_factor

    return loss


In [None]:
#TODO description de la reconstruction d'images

In [None]:
def reconstruct_image(cnn: nn.Module, target_image: torch.Tensor, target_layer_idx: int, n_steps: int = 20, synth_std: float = 0.5, logging=False, progressbar=True, synth_init: torch.Tensor = None):
    """Reconstruct an image from the activation of one layer of ``cnn``.

    A synthetic image is optimised with L-BFGS so that the activation of
    layer ``target_layer_idx`` matches the one produced by ``target_image``
    (see ``content_loss``).

    Args:
        cnn: network to invert; its weights are not updated by the optimizer.
        target_image: image whose activation serves as the target.
        target_layer_idx: index of the layer to hook, as understood by
            ``utils.register_model_hooks``.
        n_steps: number of calls to ``optimizer.step``; each call runs at
            most 20 L-BFGS iterations (``max_iter=20``).
        synth_std: passed to ``utils.normal_like`` when no initial image is
            given -- presumably the standard deviation of the initial noise.
        logging: when true, record the loss at every closure evaluation.
        progressbar: when true, show a tqdm progress bar over the steps.
        synth_init: optional starting image; it is detached and cloned, so
            the caller's tensor is left untouched.

    Returns:
        ``(synthetic_image, final_loss, loss_history)``: the optimised image,
        the content loss after the last step (a tensor) and the list of
        recorded losses (empty unless ``logging`` is true).
    """
    if synth_init is not None:
        synthetic_image = synth_init.detach().clone()
        synthetic_image.requires_grad_()
    else:
        # Initialise the pre-image with noise.
        # NOTE(review): relies on utils.normal_like returning a tensor that
        # already requires gradients -- confirm in utils.
        synthetic_image = utils.normal_like(target_image, synth_std)

    layer_output_, handles = utils.register_model_hooks(
        cnn,
        [target_layer_idx]
    )

    # Whatever happens below, the hooks must not stay attached to the model.
    try:
        # The network is not modified, so the target is computed once and for
        # all; no gradient is needed for it.
        with torch.no_grad():
            cnn(target_image)
        # The target is treated as ground truth.
        target_out = layer_output_[0].detach().clone()

        optimizer = optim.LBFGS([synthetic_image], max_iter=20)

        loss_history = []

        def closure():
            # Zero out the gradients, else they accumulate in
            # synthetic_image.grad.
            optimizer.zero_grad()
            loss = content_loss(cnn, layer_output_,
                                target_out, synthetic_image)
            loss.backward()  # backpropagate the loss to the input
            if logging:
                loss_history.append(loss.item())
            return loss

        if progressbar:
            iterator = trange(n_steps, desc="Image reconstruction", unit="step")
        else:
            iterator = range(n_steps)
        for _ in iterator:
            optimizer.step(closure)
            if progressbar and logging:
                iterator.set_postfix(content_loss=loss_history[-1])

        final_loss = content_loss(
            cnn, layer_output_, target_out, synthetic_image)
    finally:
        utils.unregister_model_hooks(handles)

    return synthetic_image, final_loss, loss_history


Testons la reconstruction d'image; He *et al.* comparent leurs résultats aux sorties des couches de pooling de VGG, et nous allons donc utiliser la couche pooling3:

In [None]:
# Reconstruct a texture from the pool3 activation of the pre-trained VGG and
# display it next to the original.
input_image_name = "briques.png"
img_size = 256
target_layer = 18  # pool3

target_image = utils.prep_img(input_image_name, img_size).to(device)
synthetic_image, loss, loss_history = reconstruct_image(
    cnn, target_image, target_layer, n_steps=100, logging=True
)

fig, _ = plt.subplots(1, 2, figsize=(15, 10))
panels = (("Original", target_image), ("Reconstruction", synthetic_image))
for axis, (title, picture) in zip(fig.axes, panels):
    axis.imshow(utils.to_pil(picture))
    axis.set_title(title)
fig.tight_layout()
plt.show()

Nous arrivons bien à reconstruire l'image d'origine.

In [None]:
# Plot the recorded content loss on a logarithmic x axis.
fig, _ = plt.subplots(1, 1, figsize=(15, 10))
loss_axis = fig.axes[0]
loss_axis.plot(loss_history)
loss_axis.set_xlabel("L-BFGS iterations")
loss_axis.set_ylabel("Content loss")
loss_axis.set_xscale('log')
plt.show()

Au vu de la courbe ci-dessus, la fonction de perte semble converger pour VGG entre 100 et 1000 itérations de L-BFGS ; nous avons donc choisi `n_steps = 20` par défaut, ce qui correspond à au plus 400 itérations de L-BFGS (`optim.LBFGS` effectue jusqu'à 20 itérations par étape, cf. `max_iter=20`).

### Construction de ranVGG

In [None]:
def build_ranvgg_(cnn: nn.Module, target_image: torch.Tensor, n_samples: int = 20, n_steps: int = 20, synth_std: float = 0.5):
    """Select random convolution weights for ``cnn``, layer by layer, in place.

    For every ``nn.Conv2d`` of ``cnn`` (in order), ``n_samples`` random
    weight sets are drawn with ``utils.randomize_layer_``. Each candidate is
    scored by the final loss of ``reconstruct_image`` on the layer that
    immediately follows the convolution, and the candidate with the lowest
    loss is written back into the layer.

    Args:
        cnn: indexable sequence of layers (e.g. ``nn.Sequential``); modified
            in place, and gradients are disabled on all its parameters.
        target_image: image used to score the candidates.
        n_samples: number of random candidates per convolution (>= 1).
        n_steps: number of optimizer steps per reconstruction.
        synth_std: passed to ``utils.normal_like`` to draw the initial
            synthetic image.

    Raises:
        ValueError: if ``n_samples`` is smaller than 1.
        RuntimeError: if no candidate of a layer yields a finite loss.
    """
    if n_samples < 1:
        raise ValueError("n_samples must be at least 1")

    # The model itself must not learn: only the synthetic image is optimised.
    cnn.requires_grad_(False)
    # Only the convolutional layers are built.
    layers_to_build = [
        idx for idx, layer in enumerate(cnn) if isinstance(layer, nn.Conv2d)
    ]
    for layer_idx in tqdm(layers_to_build, desc="Building ranVGG", unit="layer"):
        # Target the layer right after the convolution for the reconstruction
        # (assumes it is the ReLU activation -- true for VGG-style stacks).
        activation_idx = layer_idx + 1
        conv_layer = cnn[layer_idx]
        best_weight = None
        best_bias = None
        best_loss = float("inf")
        # The same initial synthetic image is used for every candidate so
        # that their losses are comparable.
        synthetic_image = utils.normal_like(target_image, synth_std)
        with trange(n_samples, leave=False, unit="sample") as pbar:
            for _ in pbar:
                utils.randomize_layer_(conv_layer)
                # Score the candidate layer.
                _, loss, _ = reconstruct_image(
                    cnn,
                    target_image,
                    activation_idx,
                    synth_init=synthetic_image,
                    n_steps=n_steps,
                    progressbar=False
                )
                # Keep only the scalar value of the loss.
                loss = loss.item()

                if loss < best_loss:
                    best_loss = loss
                    best_weight = conv_layer.weight.detach().clone()
                    if conv_layer.bias is not None:
                        best_bias = conv_layer.bias.detach().clone()
                    pbar.set_postfix(best_loss=best_loss)

        if best_weight is None:
            # Every loss was NaN or infinite: nothing sensible to restore.
            raise RuntimeError(
                "no finite reconstruction loss for layer %d" % layer_idx
            )
        with torch.no_grad():
            conv_layer.weight.copy_(best_weight)
            if best_bias is not None:
                conv_layer.bias.copy_(best_bias)

In [None]:
# TODO does not work properly :(
# Build ranVGG: start from an untrained VGG-19 feature extractor and select
# its convolution weights with build_ranvgg_.
input_image_name = "briques.png"
img_size = 256

target_image = utils.prep_img(input_image_name, img_size).to(device)
ranvgg = models.vgg19(pretrained=False).features
ranvgg = ranvgg.to(device).eval()
build_ranvgg_(ranvgg, target_image, n_samples=20, n_steps=40)

In [None]:
# Test ranVGG: reconstruct the target image from its pool3 activation and
# display it next to the original.
target_layer = 18  # pool3

synthetic_image, loss, loss_history = reconstruct_image(
    ranvgg, target_image, target_layer, n_steps=100, logging=True
)

fig, _ = plt.subplots(1, 2, figsize=(15, 10))
panels = (("Original", target_image), ("Reconstruction", synthetic_image))
for axis, (title, picture) in zip(fig.axes, panels):
    axis.imshow(utils.to_pil(picture))
    axis.set_title(title)
fig.tight_layout()
plt.show()

In [None]:
# Plot the recorded content loss on a logarithmic x axis.
fig, _ = plt.subplots(1, 1, figsize=(15, 10))
loss_axis = fig.axes[0]
loss_axis.plot(loss_history)
loss_axis.set_xlabel("L-BFGS iterations")
loss_axis.set_ylabel("Content loss")
loss_axis.set_xscale('log')
plt.show()

In [None]:
#TODO stuff below here is for later use

In [None]:
def gramm(tnsr: torch.Tensor) -> torch.Tensor:
    """Compute the Gram matrix of every element of a batch of feature maps.

    Each ``[C, H, W]`` element is flattened to ``[C, H * W]``; the Gram
    matrix is the product of that matrix with its transpose, divided by
    ``H * W``.

    Args:
        tnsr (torch.Tensor): input tensor of size ``[B, C, H, W]``; it does
            not need to be contiguous.

    Returns:
        torch.Tensor: Gram matrices, of size ``[B, C, C]``.
    """
    b, c, h, w = tnsr.size()
    # reshape (unlike view) also accepts non-contiguous inputs.
    features = tnsr.reshape(b, c, h * w)
    gram = torch.bmm(features, features.transpose(1, 2))
    return gram / (h * w)


def gram_loss(input: torch.Tensor, gramm_target: torch.Tensor, weight: float = 1.0):
    """Weighted MSE between the Gram matrix of ``input`` and a target Gram matrix.

    Args:
        input: feature maps handed to ``gramm``.
        gramm_target: precomputed Gram matrix to compare against.
        weight: multiplier applied to the loss.

    Returns:
        The weighted mean squared error, as a tensor.
    """
    gram_mismatch = F.mse_loss(gramm(input), gramm_target)
    return weight * gram_mismatch


#TODO define texture_loss

In [None]:
# Layers to use in the texture synthesis, as name -> index in the network:
# the output of the first Conv2d (after its ReLU) and of four MaxPool2d layers.
_LAYER_INDEX_BY_NAME = {
    "conv1": 1,
    "pool1": 4,
    "pool2": 9,
    "pool3": 18,
    "pool4": 27,
}
TARGET_LAYERS = list(_LAYER_INDEX_BY_NAME.values())
LAYER_NAMES = list(_LAYER_INDEX_BY_NAME.keys())