# Neural Style Transfer

We'll describe an implementation of neural style transfer.

## Load Data

We'll load the example image from the Drosophila ssTEM dataset, available at https://figshare.com/articles/dataset/Segmented_anisotropic_ssTEM_dataset_of_neural_tissue/856713. Alternatively, you can download an image from the corresponding GitHub repository: http://github.com/unidesigner/groundtruth-drosophila-vnc.

In [None]:
from PIL import Image

# Load the content image, force three RGB channels, and resize it to the
# 256x256 working resolution used in this notebook.
image = (
    Image.open("neuraltissue_with_colorlabels.png")
    .convert('RGB')
    .resize((256, 256))
)

In [None]:
from matplotlib import pyplot as plt

# Draw the content image on the current axes and hide ticks and frame.
content_ax = plt.gca()
content_ax.imshow(image)
content_ax.set_axis_off()
plt.show()

We'll use the style of the _trencadís_ lizard by Antoni Gaudí in Parc Güell (Barcelona).

In [None]:
# Load the style reference, force three RGB channels, and resize it to the
# same 256x256 resolution as the content image.
style = (
    Image.open("style.png")
    .convert('RGB')
    .resize((256, 256))
)

In [None]:
# Draw the style image on the current axes and hide ticks and frame.
style_ax = plt.gca()
style_ax.imshow(style)
style_ax.set_axis_off()
plt.show()

We'll define a function for calculating the Gram matrix between all the features of a specific layer.

The Gram matrix represents the correlations between the different feature maps of a layer.

In [None]:
from torch import bmm

def gram_matrix(tensor):
    """Return the normalized Gram matrix of a batch of feature maps.

    Args:
        tensor: 4-D tensor laid out as (batch, channels, height, width).

    Returns:
        Tensor of shape (batch, channels, channels). Entry [b, i, j] is the
        dot product between the flattened feature maps i and j of sample b,
        divided by height * width.
    """
    batch_size, num_channels, height, width = tensor.shape

    # Flatten each feature map into a row vector. reshape (rather than view)
    # also accepts non-contiguous inputs such as transposed or sliced
    # activations, for which view raises a RuntimeError.
    features = tensor.reshape(batch_size, num_channels, height * width)

    # Batched product of the feature matrix with its transpose, normalized
    # by the number of spatial positions in each feature map.
    gram = bmm(features, features.transpose(1, 2)) / (height * width)

    return gram

We'll download a pretrained model (VGG16) and freeze all the weights.

In [None]:
import torchvision.models as models
from torchvision.models import VGG16_Weights

# Build VGG16 with its ImageNet-pretrained weights. eval() and
# requires_grad_() both return the module itself, so the calls are chained:
# the network is put in evaluation mode and all of its parameters are frozen.
model = models.vgg16(weights=VGG16_Weights.IMAGENET1K_V1).eval().requires_grad_(False)

# Print the architecture so layer indices can be looked up.
print(model)

VGG16 is pretrained on the ImageNet dataset, and its inputs are normalized with respect to the per-channel mean and standard deviation of this dataset.

We'll define a function that minimizes the style and content losses with respect to the reference images.

```python
def image_to_tensor(im, mean, std):
    """Convert an image into a normalized, batched tensor that tracks gradients.

    Args:
        im: Image accepted by ``torchvision.transforms.ToTensor``.
        mean: Per-channel means subtracted during normalization.
        std: Per-channel standard deviations used as divisors.

    Returns:
        The normalized tensor with a leading batch dimension of size one and
        ``requires_grad`` switched on.
    """
    import torchvision.transforms as tt

    pipeline = tt.Compose([tt.ToTensor(), tt.Normalize(mean, std)])

    # Add the batch axis first, then flag the tensor as differentiable so it
    # can be optimized directly.
    batched = pipeline(im).unsqueeze(0)
    batched.requires_grad_(True)
    return batched
```

```python
def tensor_to_image(image, mean, std):
    """Convert a normalized image tensor back into an 8-bit RGB PIL image.

    Args:
        image: Tensor holding one image, either (1, 3, H, W) or (3, H, W),
            normalized with ``mean`` and ``std``.
        mean: Per-channel means used for the normalization (array or sequence).
        std: Per-channel standard deviations used for the normalization
            (array or sequence).

    Returns:
        A ``PIL.Image.Image`` in RGB mode with values clipped to [0, 255].
    """
    import torchvision.transforms as tt
    import numpy as np
    from PIL import Image

    # Accept plain lists/tuples as well as arrays: the arithmetic below
    # needs element-wise semantics.
    mean = np.asarray(mean)
    std = np.asarray(std)

    # Normalize computes (x - m) / s, so m = -mean / std and s = 1 / std
    # invert the forward normalization: x * std + mean.
    denormalize = tt.Normalize(mean=-mean / std, std=1 / std)

    # Work on a detached CPU copy so the caller's tensor is untouched and
    # .numpy() works whatever device the tensor lives on. Only the batch
    # axis is dropped, so a singleton height or width is left in place.
    chw = denormalize(image.detach().clone().cpu().squeeze(0)).numpy()

    # Channels-first -> channels-last, then rescale to 8-bit.
    hwc = np.clip(chw.transpose(1, 2, 0) * 255, 0, 255).astype(np.uint8)
    return Image.fromarray(hwc, 'RGB')
```

```python
class fwd_hooks():
    """Context manager that records the outputs of the given layers.

    A forward hook is registered on every layer at construction time. Each
    time a hooked layer runs, its output is appended to ``activations_list``.
    Leaving the ``with`` block removes all hooks again.
    """

    def __init__(self, layers):
        # Outputs collected so far, in the order the hooks fired.
        self.activations_list = []
        # Handles returned by register_forward_hook; needed to detach later.
        self.hooks = [
            layer.register_forward_hook(self.hook_func) for layer in layers
        ]

    def hook_func(self, layer, input, output):
        """Forward-hook callback: store the layer's output."""
        self.activations_list.append(output)

    def __enter__(self, *args):
        return self

    def __exit__(self, *args):
        # Detach every hook so later forward passes are not recorded.
        for handle in self.hooks:
            handle.remove()
```

In [None]:
import numpy as np
import torch
import torchvision.transforms as tt
from PIL import Image
from fnc_style_transfer import image_to_tensor, tensor_to_image, fwd_hooks

def style_transfer(im_in, im_c, im_s, layers, ind_c, ind_s, lr=1, beta=1e3, iter_num=100):
    """Optimize an image to keep the content of ``im_c`` and the style of ``im_s``.

    Relies on the module-level ``model`` (the frozen network) and on
    ``gram_matrix``.

    Args:
        im_in: Starting image for the optimization. Anything that is not a
            PIL image (e.g. ``[]`` or ``None``) selects a uniform random
            start with the same height and width as the content image.
        im_c: Content reference image.
        im_s: Style reference image.
        layers: Modules of ``model`` whose outputs are recorded on each
            forward pass.
        ind_c: Indices into the recorded activations used for the content loss.
        ind_s: Indices into the recorded activations used for the style loss.
        lr: Learning rate passed to the L-BFGS optimizer.
        beta: Weight of the style loss relative to the content loss.
        iter_num: Number of optimizer steps.

    Returns:
        The optimized image, converted back with ``tensor_to_image``.
    """
    # Normalization parameters typically used with pretrained models
    mean_ds = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std_ds = np.array([0.229, 0.224, 0.225], dtype=np.float32)

    def record_activations(batch):
        # Run the network once and return the outputs of the hooked layers.
        with fwd_hooks(layers) as fh:
            model(batch)
        return fh.activations_list

    # The targets are constants of the optimization, so they are computed
    # without building autograd graphs.
    image_c = image_to_tensor(im_c, mean_ds, std_ds)
    image_s = image_to_tensor(im_s, mean_ds, std_ds)
    with torch.no_grad():
        activations_c = record_activations(image_c)
        content_features = [activations_c[i].detach() for i in ind_c]

        activations_s = record_activations(image_s)
        gram_targets = [gram_matrix(activations_s[i].detach()) for i in ind_s]

    # Starting point: fall back to uniform noise when no image is given. The
    # noise takes the content image's size so that the activations being
    # compared in the content loss have matching shapes.
    if not isinstance(im_in, Image.Image):
        print('Input image not provided. Using a random input.')
        height, width = image_c.shape[-2:]
        noise = np.random.rand(height, width, 3) * 255
        im_in = Image.fromarray(noise.astype('uint8')).convert('RGB')

    image_in = image_to_tensor(im_in, mean_ds, std_ds)

    optimizer = torch.optim.LBFGS([image_in], lr=lr)
    mse_loss = torch.nn.MSELoss(reduction='sum')

    # Loss histories, stored as plain floats so no autograd graph is kept
    # alive between closure evaluations.
    l_c = []
    l_s = []

    def closure():
        # L-BFGS may evaluate this several times within one step.
        optimizer.zero_grad()

        activations_in = record_activations(image_in)

        # Content loss: squared error against the content activations,
        # scaled by the squared channel count and averaged over layers.
        c_loss = 0
        for i, target in zip(ind_c, content_features):
            current = activations_in[i]
            c_loss += mse_loss(current, target) / current.shape[1] ** 2
        c_loss /= len(content_features)

        # Style loss: same construction on the Gram matrices.
        s_loss = 0
        for i, target in zip(ind_s, gram_targets):
            gram = gram_matrix(activations_in[i])
            s_loss += mse_loss(gram, target) / gram.shape[1] ** 2
        s_loss /= len(gram_targets)

        loss = c_loss + beta * s_loss
        l_c.append(float(c_loss))
        l_s.append(float(s_loss))

        loss.backward()
        return loss

    for it in range(iter_num):
        optimizer.step(closure)
        # Report the losses from the most recent closure evaluation.
        print('Step {}: Content Loss: {:.8f} Style Loss: {:.8f}'.format(it, l_c[-1], l_s[-1]))

    return tensor_to_image(image_in, mean_ds, std_ds)

We'll apply the style transfer using as an input the same image used to get the content.

In [None]:
def plot_style(im_c, im_s, im_out):
    """Show the content, style and output images side by side in one row.

    Args:
        im_c: Content image.
        im_s: Style image.
        im_out: Result of the style transfer.
    """
    import matplotlib.pyplot as plt

    panels = (
        ('Content image', im_c),
        ('Style image', im_s),
        ('Output image', im_out),
    )

    plt.figure(figsize=(15, 5))
    # One subplot per panel, numbered from 1 as plt.subplot expects.
    for position, (caption, picture) in enumerate(panels, start=1):
        plt.subplot(1, len(panels), position)
        plt.imshow(picture)
        plt.title(caption)
        plt.axis('off')
    plt.show()

In [None]:
#from style_transfer import plot_style

# Positions inside model.features of the layers whose outputs are recorded.
ind = [0, 2, 5, 7, 10, 14]
layers = [model.features[layer_idx] for layer_idx in ind]

# Positions within `layers`: the last recorded layer drives the content
# loss, the first five drive the style loss.
ind_c = [5]
ind_s = list(range(5))

# Start the optimization from the content image itself.
im_out = style_transfer(
    image, image, style, layers, ind_c, ind_s,
    lr=1, beta=1e5, iter_num=200,
)
plot_style(image, style, im_out)

We'll also apply the style transfer using as an input a random image.

In [None]:
# An empty list is not an image, so style_transfer falls back to a random
# starting point. The style weight is 1e4 here, versus 1e5 in the run that
# starts from the content image.
im_out = style_transfer(
    [], image, style, layers, ind_c, ind_s,
    lr=1, beta=1e4, iter_num=200,
)
plot_style(image, style, im_out)