# Neural Style Transfer

We'll perform style transfer between images.

## Select Input Image

We begin by loading the example image `neuraltissue_with_colorlabels.png` and cropping a 256 by 256 pixel patch from it.

This image is sourced from the Drosophila ssTEM dataset, which is publicly available on Figshare: [Segmented anisotropic ssTEM dataset of neural tissue](https://figshare.com/articles/dataset/Segmented_anisotropic_ssTEM_dataset_of_neural_tissue/856713). This dataset provides a detailed view of neural tissue, aiding in the study of neural structures and patterns. The image can also be downloaded from the corresponding GitHub repository at [this link](http://github.com/unidesigner/groundtruth-drosophila-vnc), which offers additional resources and information related to the Drosophila ssTEM dataset.

In [None]:
from PIL import Image

content = (
    Image.open("neuraltissue_with_colorlabels.png")
    .convert("RGB")
    .crop((100, 170, 100 + 256, 170 + 256))  # box is (left, upper, right, lower)
)
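
PIL's `crop` takes a box `(left, upper, right, lower)` in pixel coordinates, so the call above keeps the 256 by 256 patch whose top-left corner is at `(100, 170)`. A minimal sketch on a blank stand-in image (the 512 by 512 size is an assumption for illustration, not the size of the real file):

```python
from PIL import Image

# Stand-in for the real micrograph; the size is an assumption.
dummy = Image.new("RGB", (512, 512))

left, upper, size = 100, 170, 256
patch = dummy.crop((left, upper, left + size, upper + size))

print(patch.size)  # (256, 256)
```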

In [None]:
from matplotlib import pyplot as plt

plt.imshow(content)
plt.axis('off')
plt.show()


## Select Style Image

We use the style of the _trencadís_ lizard by Antoni Gaudí in Parc Güell (Barcelona) in the image `lizard.png`.

In [None]:
style = Image.open("style.png").convert('RGB').resize((256, 256))

In [None]:
plt.imshow(style)
plt.axis('off')
plt.show()

## Load Pretrained Neural Network

We load VGG16, a convolutional network for image classification, with weights pretrained on the ImageNet dataset. We then set the model to evaluation mode and freeze all weights, since only the input image is optimized during style transfer.

In [None]:
import torchvision.models as models
from torchvision.models import VGG16_Weights

model = models.vgg16(weights=VGG16_Weights.IMAGENET1K_V1)

model.eval()
model.requires_grad_(False)

print(model)
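
Freezing the weights does not block gradients with respect to the input, which is the quantity style transfer optimizes. The sketch below illustrates this with a single convolution as a hypothetical stand-in for the network (chosen only to avoid downloading the VGG16 weights):

```python
import torch

# Hypothetical stand-in for the pretrained network.
net = torch.nn.Conv2d(3, 4, kernel_size=3, padding=1)
net.eval()
net.requires_grad_(False)

# The input image is the only tensor that requires gradients.
x = torch.rand(1, 3, 8, 8, requires_grad=True)
net(x).sum().backward()

# No gradient is accumulated on the frozen parameters.
weights_frozen = all(p.grad is None for p in net.parameters())
print(weights_frozen, tuple(x.grad.shape))
```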

## Implement Style Transfer

We'll define a function that minimizes the style and content losses with respect to the reference images.
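
Written out, the objective minimized by the `style_transfer()` implementation later in this section has roughly the following form for a single image. The symbols are introduced here for illustration only: $F^l$ and $P^l$ are the layer-$l$ activations of the optimized image and the content image, $G^l$ and $A^l$ are the Gram matrices (introduced just below) of the optimized image and the style image, $N_l$ is the number of channels in layer $l$, $\mathcal{C}$ and $\mathcal{S}$ are the sets of content and style layers, and $\beta$ is the style weight `beta`.

$$
\begin{aligned}
\mathcal{L}_{\text{content}} &= \frac{1}{|\mathcal{C}|} \sum_{l \in \mathcal{C}} \frac{1}{N_l^2} \sum_{c,h,w} \left( F^l_{chw} - P^l_{chw} \right)^2 \\
\mathcal{L}_{\text{style}} &= \frac{1}{|\mathcal{S}|} \sum_{l \in \mathcal{S}} \frac{1}{N_l^2} \sum_{i,j} \left( G^l_{ij} - A^l_{ij} \right)^2 \\
\mathcal{L} &= \mathcal{L}_{\text{content}} + \beta \, \mathcal{L}_{\text{style}}
\end{aligned}
$$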

We define a function for calculating the Gram matrix between all the features of a specific layer.
The Gram matrix represents the correlations between different feature maps (or channels) of the output of a convolutional layer.

We write the `gram()` function and add it to `fnc_style_transfer.py`:

```python
def gram(tensor):
    from torch import bmm
    
    batch_size, num_channels, height, width = tensor.size()
    features = tensor.view(batch_size, num_channels, height * width)
    gram = bmm(features, features.transpose(1, 2)) / (height * width)

    return gram
```

This function:
- Unpacks the dimensions of the input tensor
- Reshapes the tensor so that each image in the batch becomes a 2D matrix, with channels as features and the `height * width` spatial positions as observations
- Computes the Gram matrix as the product of the matrix by its transpose, normalizing by the number of elements in each feature map (`height * width`)
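
To make the steps above concrete, the following sketch applies the same computation to a hand-written tensor with two channels and 2 by 2 feature maps, and compares it with an explicit `einsum` over the spatial positions. The input values are arbitrary and chosen only so the result is easy to check by hand.

```python
import torch

def gram(tensor):
    # Same computation as the gram() function above.
    batch_size, num_channels, height, width = tensor.size()
    features = tensor.view(batch_size, num_channels, height * width)
    return torch.bmm(features, features.transpose(1, 2)) / (height * width)

# One image, two channels, 2x2 feature maps (values chosen by hand).
x = torch.tensor([[[[1.0, 0.0], [0.0, 1.0]],
                   [[1.0, 1.0], [1.0, 1.0]]]])

g = gram(x)

# Reference: explicit sum over spatial positions.
g_ref = torch.einsum("bchw,bdhw->bcd", x, x) / (2 * 2)

print(g[0].tolist())  # [[0.5, 0.5], [0.5, 1.0]]
```

The result has one row and one column per channel, and it is symmetric because the correlation of channel `i` with channel `j` equals that of `j` with `i`.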

We also add the `image_to_tensor()` function to `fnc_style_transfer.py` (the same function as for the DeepDream project).

```python
def image_to_tensor(im, mean, std):
    import torchvision.transforms as tt

    normalize = tt.Compose([tt.ToTensor(), tt.Normalize(mean, std)])

    return normalize(im).unsqueeze(0).requires_grad_(True)
```

We also add the `tensor_to_image()` function to `fnc_style_transfer.py` (the same function as for the DeepDream project).

```python
def tensor_to_image(image, mean, std):
    import torchvision.transforms as tt
    import numpy as np
    from PIL import Image

    denormalize = tt.Normalize(mean=-mean / std, std=1 / std)

    im_array = denormalize(image.data.clone().detach().squeeze()).numpy()
    im_array = np.clip(im_array.transpose(1, 2, 0) * 255, 0, 255).astype(np.uint8)
    return Image.fromarray(im_array, 'RGB')
```
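
`tensor_to_image()` undoes the normalization by reusing `Normalize` with transformed parameters: since `Normalize` computes `(x - m) / s`, passing `m = -mean / std` and `s = 1 / std` yields `x * std + mean`. A NumPy sketch of this identity, using the ImageNet statistics that appear later in this notebook and a random stand-in image (the helper `normalize` below is written for illustration and is not part of `fnc_style_transfer.py`):

```python
import numpy as np

mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
std = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def normalize(x, m, s):
    # Channel-wise (x - m) / s on a CHW array, mirroring tt.Normalize.
    return (x - m[:, None, None]) / s[:, None, None]

rng = np.random.default_rng(0)
pixels = rng.random((3, 4, 4), dtype=np.float32)  # fake CHW image in [0, 1)

normalized = normalize(pixels, mean, std)
restored = normalize(normalized, -mean / std, 1 / std)

# The round trip recovers the original pixels up to float32 rounding.
print(np.abs(restored - pixels).max())
```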

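We also add the `fwd_hooks` class to `fnc_style_transfer.py` (the same class as for the DeepDream project). It registers a forward hook on each given layer to record that layer's output and, when used as a context manager, removes the hooks on exit.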

```python
class fwd_hooks():
    def __init__(self, layers):
        self.hooks = []
        self.activations_list = []
        for layer in layers:
            self.hooks.append(layer.register_forward_hook(self.hook_func))

    def hook_func(self, layer, input, output):
        self.activations_list.append(output)

    def __enter__(self, *args): 
        return self
    
    def __exit__(self, *args): 
        for hook in self.hooks:
            hook.remove()
```
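
The mechanism this class wraps can be seen in isolation with PyTorch's `register_forward_hook`. The sketch below uses a tiny `Sequential` network as a hypothetical stand-in for VGG16, records the outputs of its two convolutions, and then removes the hooks, which is what `__exit__` does above.

```python
import torch

# Tiny stand-in network (hypothetical, for illustration only).
net = torch.nn.Sequential(
    torch.nn.Conv2d(3, 4, kernel_size=3, padding=1),
    torch.nn.ReLU(),
    torch.nn.Conv2d(4, 8, kernel_size=3, padding=1),
)

activations = []

def hook_func(layer, inputs, output):
    # Called after each hooked layer's forward pass.
    activations.append(output)

# Hook the two convolutions, run one forward pass, then remove the hooks.
handles = [net[i].register_forward_hook(hook_func) for i in (0, 2)]
with torch.no_grad():
    net(torch.rand(1, 3, 16, 16))
for handle in handles:
    handle.remove()

# A second pass records nothing because the hooks are gone.
with torch.no_grad():
    net(torch.rand(1, 3, 16, 16))

shapes = [tuple(a.shape) for a in activations]
print(shapes)  # [(1, 4, 16, 16), (1, 8, 16, 16)]
```

Removing the hooks matters here because the same layers are hooked again on every optimization step; leftover hooks would keep appending activations to old lists.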

In [None]:
import numpy as np
import torch
from fnc_style_transfer import gram, image_to_tensor, tensor_to_image, fwd_hooks

def style_transfer(image, content, style, 
                   layers, ind_c, ind_s, 
                   lr=1, beta=1e3, iter_num=100):
    content_layers = [layers[i] for i in ind_c]
    style_layers = [layers[i] for i in ind_s]

    # ImageNet channel statistics expected by the pretrained model
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)

    # Image being optimized (image_to_tensor already enables gradients on it)
    image_tensor = image_to_tensor(image, mean, std)
    
    # content
    with fwd_hooks(content_layers) as fh:
        _ = model(image_to_tensor(content, mean, std))
    content_features = [activations.detach() for activations in fh.activations_list]

    # style
    with fwd_hooks(style_layers) as fh:
        _ = model(image_to_tensor(style, mean, std))
    style_features = [activations.detach() for activations in fh.activations_list]
    gram_targets = [gram(s) for s in style_features]

    optimizer = torch.optim.LBFGS([image_tensor], lr=lr)
    mse_loss = torch.nn.MSELoss(reduction="sum")

    for i in range(iter_num):
        def closure():
            optimizer.zero_grad()

            # content
            with fwd_hooks(content_layers) as fh:
                _ = model(image_tensor)
            image_content_features = fh.activations_list

            content_loss = 0
            for icf, cf in zip(image_content_features, content_features):
                n_f = icf.shape[1]
                content_loss += mse_loss(icf, cf) / n_f ** 2
            content_loss /= len(image_content_features)

            # style
            with fwd_hooks(style_layers) as fh:
                _ = model(image_tensor)
            image_style_features = fh.activations_list
            gram_image = [gram(f) for f in image_style_features]

            style_loss = 0
            for gi, gt in zip(gram_image, gram_targets):
                n_g = gi.shape[1]
                style_loss += mse_loss(gi, gt) / n_g ** 2
            style_loss /= len(gram_image)

            print(f"i={i} content_loss={content_loss.item():.4g} style_loss={style_loss.item():.4g}")

            total_loss = content_loss + beta * style_loss
            total_loss.backward()
            return total_loss

        optimizer.step(closure)

        image = tensor_to_image(image_tensor, mean, std)
        
        if i <= 5 or i % 10 == 0:
            plt.imshow(image)
            plt.title(f"Iteration {i}")
            plt.axis("off")
            plt.show()

    return image
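
The closure pattern used with `torch.optim.LBFGS` above can be seen in isolation on a toy problem. In this sketch the "image" is a two-element tensor and the loss is a plain squared distance to a made-up target; both are assumptions chosen only to keep the example small.

```python
import torch

target = torch.tensor([1.0, 2.0])
x = torch.tensor([5.0, -3.0], requires_grad=True)

optimizer = torch.optim.LBFGS([x], lr=1)

def closure():
    # LBFGS may call this several times per step to re-evaluate the loss.
    optimizer.zero_grad()
    loss = ((x - target) ** 2).sum()
    loss.backward()
    return loss

for _ in range(5):
    optimizer.step(closure)

print(x.detach())
```

Because the closure can run more than once per `step`, any printing or bookkeeping placed inside it (such as the loss printout in `style_transfer()`) may appear several times for the same outer iteration.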

We'll first apply style transfer starting from the content image itself, so the optimization begins with the content already in place.

In [None]:
def plot_style(content, style, image):
    import matplotlib.pyplot as plt

    plt.figure(figsize=(15, 5)) 
    
    plt.subplot(1, 3, 1)
    plt.imshow(content)
    plt.title('Content image')
    plt.axis('off')
    
    plt.subplot(1, 3, 2)
    plt.imshow(style)
    plt.title('Style image') 
    plt.axis('off')
    
    plt.subplot(1, 3, 3)
    plt.imshow(image)
    plt.title('Output image') 
    plt.axis('off')
    
    plt.show()

In [8]:
ind = [0, 2, 5, 7, 10, 14]
layers = [model.features[i] for i in ind] 
ind_c = [5]
ind_s = [0, 1, 2, 3, 4]

image_out = style_transfer(content, content, style, layers, ind_c, ind_s, lr=1, beta=1e5, iter_num=50)
plot_style(content, style, image_out)

We'll also apply style transfer starting from a random noise image.

In [None]:
imarray = np.random.rand(256, 256, 3) * 255
image_in = Image.fromarray(imarray.astype('uint8')).convert('RGB')

image_out = style_transfer(image_in, content, style, layers, ind_c, ind_s, lr=1, beta=1e4, iter_num=50)
plot_style(image_in, style, image_out)