# Compute the style of an image

Based on the paper by Gatis et al. 2016 ([ref](https://www.cv-foundation.org/openaccess/content_cvpr_2016/papers/Gatys_Image_Style_Transfer_CVPR_2016_paper.pdf)).


In [1]:
# Imports
import torch
import torchvision
from torch import nn
import skimage
from skimage import transform
from im_func import show_image, timer
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import display, clear_output

ModuleNotFoundError: ignored

# Get images

In [None]:
# content_im = skimage.data.coffee()
content_im = skimage.io.imread("./Data/Cecile.jpg")
style_im = skimage.io.imread("./Data/PICASSO_WOMAN.jpeg")
# im = skimage.transform.rescale(im, 0.25, anti_aliasing=False, multichannel=True)

fig, ax = plt.subplots(1,2,figsize=[10,5])
plt.sca(ax[0])
_ = show_image(content_im,'content')
plt.sca(ax[1])
_ = show_image(style_im,'style')
# im.max()
# im = im/255.
# im = im.astype(np.float32)

# Pre/post processing of image

Normalization, resizing etc... to correspond to the input format of images trained by t

In [None]:
# From D2L
im_shape = (150, 225)
rgb_mean = torch.tensor([0.485, 0.456, 0.406])
rgb_std = torch.tensor([0.229, 0.224, 0.225])

def preprocess(img, image_shape):
    transforms = torchvision.transforms.Compose([
        torchvision.transforms.ToPILImage(),
        torchvision.transforms.Resize(image_shape),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=rgb_mean, std=rgb_std)
    ])
    return transforms(img).unsqueeze(0)

def postprocess(img):
    img = img[0].to(rgb_std.device)
    img = torch.clamp(img.permute(1, 2, 0) * rgb_std + rgb_mean, 0, 1)
    return torchvision.transforms.ToPILImage()(img.permute(2, 0, 1))

# Get a pretrained model

In [2]:
pretrained_net = torchvision.models.vgg19(pretrained=True)

Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth


HBox(children=(FloatProgress(value=0.0, max=574673361.0), HTML(value='')))




In [3]:
pretrained_net.features[0]

Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

# Defined the image generator network and a function to make a partial pass through the network

In [None]:
class ImageGenerator(nn.Module):
    def __init__(self,shape):
        super(ImageGenerator,self).__init__()
        self.im = nn.Parameter(torch.rand(1, 3, *shape))
    def forward(self):
        return self.im

In [None]:
# Apply layer by layer mods
def partial_forward(input_im,n_layer):
    X = input_im
    for il in range(n_layer):
        X = pretrained_net.features[il](X)
    return X

def visu_im_rep(X, ncol=5):
    nrow = int(np.ceil(n_channels/ncol))
    fig, ax = plt.subplots(nrow,ncol,figsize=[15,24],tight_layout=True)
    with torch.no_grad():
        for ic in range(n_channels):
            plt.sca(ax[np.unravel_index(ic,(nrow,ncol))])
            show_image(X[0,ic,:,:])

`x` is the representation of our input image by the given layer. This representation is composed of many channels, each the result of a specific convolution that has been optimized through training to extract specific useful features. 
Now, we instantiate the simple ImageGenerator. `im_gen` parameters' (`list(im_gen.parameters())[0]`) contains a tensor initialized with random noise. 

In [None]:
# len(pretrained_net.features)
pretrained_net.named_parameters

<bound method Module.named_parameters of VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): Conv2d(256, 256

# Compute style

In [None]:


def get_feature_maps(image, layer_list):
#      return [pretrained_net.features[il](image) for il in range (layer_list[-1]+1) if (il in layer_list) ]
    a = image 
    feature_maps = []
    for il in range(layer_list[-1]+1):
#         a = pretrained_net.features[il](a)
        a = net[il](a)
        if (il in layer_list):
            feature_maps.append(a)
    return feature_maps

def get_gram(feature_maps):
    gram_matrices = []
    for a in feature_maps:
        fac = 1./(2.*a.shape[1]*a.shape[2]*a.shape[3])
        a = a.reshape(a.shape[1],-1) 
        gram_matrices.append(torch.matmul(a,a.T)*fac)

    return gram_matrices

def get_style_loss(gram_style_image, gram_generated_image):
    loss = 0
    for A, G in zip(gram_style_image,gram_generated_image):
        loss += torch.sum((A-G)**2)
    return loss

def get_content_loss(feature_map_content, feature_map_gen):
    loss = 0
    for i in range(len(feature_map_content)):
         loss += torch.mean((feature_map_content[i]-feature_map_gen[i])**2)
    return loss


In [None]:
def get_TV_loss(im):
    # y: filtered representation of content image on the given layer of the NN
    # y_hat: same as y for generated image
    # im: the generated image
    
    beta = 1. # Can be between 1 and 2
    H = im.shape[2]
    W = im.shape[3]
    C = im.shape[1]
    sigma = torch.sqrt(torch.sum(im**2))/H/W/C
    B = 1. # because images are standardized
    a = 0.01 # 1%
    Lambda_b = sigma**beta / (H*W*(a*B)**beta)# There is a better definition in the paper
    
    # total variation
    d_dx = im[:,:,1:,:]-im[:,:,:-1,:]
    d_dy = im[:,:,:,1:]-im[:,:,:,:-1]
    d_dx = .5*(d_dx[:,:,:,1:]+d_dx[:,:,:,:-1])
    d_dy = .5*(d_dy[:,:,1:,:]+d_dy[:,:,:-1,:])
    TV = torch.sum((d_dx**2+d_dy**2)**(beta/2.))
#     print(f"MSE: {loss:.2e}, TV: {TV:.2e}, $\\lambda$ TV: {Lambda_b*TV:.2e}")
    loss = Lambda_b*TV
    return loss

# Initialize GPU if available

In [None]:
# setting device on GPU if available, else CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
print()

#Additional Info when using cuda
if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    print('Allocated:', round(torch.cuda.memory_allocated(0)/1024**3,1), 'GB')
    print('Cached:   ', round(torch.cuda.memory_reserved(0)/1024**3,1), 'GB')

Using device: cpu



# Train

In [None]:
im_gen = ImageGenerator(im_shape).to(device)
im_style = preprocess(style_im, im_shape).to(device)
im_content = preprocess(content_im, im_shape).to(device)

optim = torch.optim.LBFGS(im_gen.parameters())
# optim = torch.optim.Adam(im_gen.parameters())
# Stopping criterion
abs_loss_limit = 1e-10
rel_loss_limit = 1e-12


# layer_list = [8,17, 26, 35]
# layer_list = [0,1,2,3,4,5]
style_layer_list = [8,17, 26, 35]
content_layer = [8]

net = nn.Sequential(*[
    pretrained_net.features[i]
    for i in range(max(content_layer + style_layer_list) + 1)]).to(device)

with torch.no_grad():
    fm_style = get_feature_maps(im_style, style_layer_list)
    fm_content = get_feature_maps(im_content, content_layer)
    gram_style = get_gram(fm_style)
content_weight = 0.06
style_weight = 1.0
tv_weight = 1.0

In [None]:
# train
fig, ax = plt.subplots(1,1,figsize=[10,10])

def closure_small():

    im_generated = im_gen()
#     fm_style = get_feature_maps(im_style.detach(), style_layer_list)
#     gram_style = get_gram(fm_style)
    fm_gen_style = get_feature_maps(im_generated, style_layer_list)
    fm_gen_content = get_feature_maps(im_generated, content_layer) # having to go through the network again is a bit ugly
    gram_gen = get_gram(fm_gen_style)
    content_loss = get_content_loss(fm_content, fm_gen_content)
    style_loss = get_style_loss(gram_style,gram_gen)
    tv_loss = get_TV_loss(im_generated)
    
    return content_loss, style_loss, tv_loss
def closure():
    optim.zero_grad()
    content_loss, style_loss, tv_loss = closure_small()
    loss = content_weight*content_loss + style_weight*style_loss + tv_weight*tv_loss
    print(f"content, style, tv losses: {content_weight*content_loss:.2e}, {style_weight*style_loss:.2e}, {tv_weight*tv_loss:.2e}")
    loss.backward()
    return loss

last_loss = 1e10
for i in range(100):    
#     im_generated = im_gen()
    optim.step(closure)
    if i%1==0:
        with torch.no_grad():
            content_loss, style_loss, tv_loss = closure_small()
            loss = content_weight*content_loss + style_weight*style_loss + tv_weight*tv_loss
            imnew = postprocess(im_gen()[0])
            ax.cla()
            plt.imshow(imnew)
            plt.title(f"epoch {i:02d}, content, style, tv losses: {content_weight*content_loss:.2e}, {style_weight*style_loss:.2e}, {tv_weight*tv_loss:.2e}")
            clear_output(wait = True)
            display(fig)

            if loss<abs_loss_limit:
                clear_output(wait = True)
                print(f'success: absolute loss limit ({abs_loss_limit:.1e}) reached')
                break
            if torch.abs(last_loss-loss)<rel_loss_limit:
                clear_output(wait = True)
                print(f'stopped because relative loss limit ({rel_loss_limit:.1e})  was reached')
                break
                
            last_loss = loss