In [13]:
import os

import cv2
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from torchvision import transforms, models

In [14]:
# Keep only the `.features` sub-module of the pretrained VGG-19; the rest of
# the network returned by `models.vgg19` is discarded.
vgg = models.vgg19(pretrained=True).features

# Freeze every parameter so that autograd never tracks gradients for the
# network weights.
for param in vgg.parameters():
    param.requires_grad = False

In [15]:
def show_image(image_tensor):
    """Display a normalized image tensor with ``plt.imshow``.

    The tensor is copied to the CPU, converted to a NumPy array and stripped
    of singleton axes. The first axis is then moved to the end, the
    per-channel scaling (std ``0.229, 0.224, 0.225``; mean
    ``0.485, 0.456, 0.406``) is undone and the values are clipped to [0, 1].

    Args:
        image_tensor: tensor to display. Assumed to be shaped (1, 3, H, W)
            and normalized with the constants above -- TODO confirm against
            callers.
    """
    channel_std = np.array((0.229, 0.224, 0.225))
    channel_mean = np.array((0.485, 0.456, 0.406))

    # Work on a detached CPU copy so the caller's tensor is left untouched.
    pixels = np.squeeze(image_tensor.to("cpu").clone().detach().numpy())
    # Move the leading axis to the end (channel-first -> channel-last).
    pixels = np.transpose(pixels, (1, 2, 0))
    # Undo the normalization, then clamp into the displayable range.
    pixels = np.clip(pixels * channel_std + channel_mean, 0, 1)

    plt.imshow(pixels)
    
def save_image(image_tensor, i, out_dir="/home/abdullah/Documents/ai/Style Transfer/results"):
    """Write a normalized image tensor to ``<out_dir>/<i>.jpg``.

    The tensor is copied to the CPU, converted to a NumPy array and stripped
    of singleton axes. The first axis is then moved to the end, the
    per-channel scaling (std ``0.229, 0.224, 0.225``; mean
    ``0.485, 0.456, 0.406``) is undone and the values are clipped to [0, 1]
    before the array is handed to ``plt.imsave``.

    Args:
        image_tensor: tensor to save. Assumed to be shaped (1, 3, H, W) and
            normalized with the constants above -- TODO confirm against
            callers.
        i: value formatted into the file name (``"{}.jpg".format(i)``).
        out_dir: directory that receives the file. Defaults to the location
            the notebook originally hard-coded, so existing calls behave as
            before; pass another directory to run on a different machine.
            The directory is created when it does not exist.
    """
    image = image_tensor.to("cpu").clone().detach()
    image = image.numpy().squeeze()
    # Move the leading axis to the end (channel-first -> channel-last).
    image = image.transpose(1, 2, 0)
    # Undo the normalization, then clamp into the valid pixel range.
    image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
    image = image.clip(0, 1)

    # Create the target directory up front so the save does not fail merely
    # because the results folder is missing.
    os.makedirs(out_dir, exist_ok=True)
    plt.imsave(os.path.join(out_dir, "{}.jpg".format(i)), image)
    
    

In [16]:
def load_img(path):
    """Load an image file as a normalized, batched tensor.

    The file is opened with PIL and converted to RGB, then passed through
    ``ToTensor`` followed by ``Normalize`` with mean
    ``(0.485, 0.456, 0.406)`` and std ``(0.229, 0.224, 0.225)``.

    Args:
        path: location of the image file, as accepted by ``Image.open``.

    Returns:
        The transformed image restricted to its first three channels, with
        a new leading axis of size 1 added.
    """
    to_normalized_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    rgb_image = Image.open(path).convert('RGB')
    tensor = to_normalized_tensor(rgb_image)

    # Keep at most three channels and prepend the batch axis.
    return tensor[:3, :, :].unsqueeze(0)

     
def get_features(image, model, layers=None):
    """Run an image forward through a model and collect selected layer outputs.

    The image is passed through the model's child modules in the order they
    appear in ``model._modules``. Whenever a child's name is a key of
    ``layers``, the output of that child is stored under the mapped label.
    The forward pass stops as soon as every requested layer that exists in
    the model has been captured, so later layers are never evaluated.

    Args:
        image: input passed to the first child module.
        model: module whose ``_modules`` are applied one after another
            (e.g. an ``nn.Sequential``).
        layers: optional mapping ``{child name: label}``. When omitted, the
            default maps PyTorch VGG layer indices to the layer names used
            by Gatys et al. (2016).

    Returns:
        dict mapping each label to the output of the corresponding layer.
    """
    # Mapping layer names of PyTorch's VGGNet to layer names from the paper
    if layers is None:
        layers = {'0': 'conv1_1',
                  '5': 'conv2_1',
                  '10': 'conv3_1',
                  '19': 'conv4_1',
                  '21': 'conv4_2',
                  '28': 'conv5_1'}

    # Requested layers that are actually present in the model; once this set
    # is empty there is nothing left to collect.
    pending = {name for name in model._modules if name in layers}

    features = {}
    x = image
    # model._modules is a dictionary holding each module in the model
    for name, layer in model._modules.items():
        x = layer(x)
        if name in layers:
            features[layers[name]] = x
            pending.discard(name)
            if not pending:
                # Every requested output is captured; skipping the remaining
                # layers saves compute and autograd memory.
                break

    return features

def gram_mat(tensor):
    """Compute the Gram matrix of a 4-D feature tensor.

    The tensor is flattened to a 2-D matrix with ``b * c`` rows and
    ``h * w`` columns and multiplied by its own transpose.

    Args:
        tensor: 4-D tensor whose sizes are unpacked as ``(b, c, h, w)``.
            It may be non-contiguous.

    Returns:
        A ``(b * c, b * c)`` tensor holding the inner products between all
        pairs of flattened rows.
    """
    b, c, h, w = tensor.size()
    # reshape (rather than view) also accepts non-contiguous inputs, copying
    # only when the memory layout requires it.
    flat = tensor.reshape(b * c, h * w)

    return torch.mm(flat, flat.t())

    

In [17]:
# Run on the GPU when CUDA is available and fall back to the CPU otherwise,
# instead of failing on machines without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

vgg.to(device).eval()
content_img = load_img("1.png").to(device)
style_img = load_img("2.jpg").to(device)

# Layer activations of the two fixed images, computed once up front.
content_feature = get_features(content_img, vgg)
style_feature  = get_features(style_img, vgg)

# One Gram matrix per extracted style layer.
style_gram = {label : gram_mat(style_feature[label]) for label in style_feature}

# The image being optimised starts as a copy of the content image.
# content_img already lives on `device`, so the clone does too and stays a
# leaf tensor that an optimizer can update.
target_img = content_img.clone().requires_grad_(True)

loss = nn.MSELoss()
loss1 = nn.MSELoss()

In [None]:
# Per-layer multipliers applied to each layer's style loss term.
style_weights = {'conv1_1': 1.,
                 'conv2_1': 0.75,
                 'conv3_1': 0.2,
                 'conv4_1': 0.2,
                 'conv5_1': 0.2}

# Adam updates the pixels of target_img directly.
optimizer = optim.Adam([target_img],lr=0.003)

steps = 5000
content_weight = 1
style_weight = 1e6
tv_weight = 1e-6
# Save and display the intermediate result every `show_every` iterations.
# The loop previously ignored this constant and hard-coded 50; the value is
# set to 50 so the checkpoint cadence is unchanged.
show_every = 50

for i in range(1,steps + 1):
    style_loss = 0
    target_features = get_features(target_img, vgg)
    print("{} th iteration".format(i))
    # Content term: MSE between the conv4_2 activations of target and content.
    content_loss = loss(target_features['conv4_2'],content_feature['conv4_2'])

    # Style term: weighted mean squared Gram difference per layer, scaled by
    # the size of that layer's target feature map.
    for layer in style_weights:
        target_feature_ = target_features[layer]
        target_gram  = gram_mat(target_feature_)
        b, c, h, w = target_feature_.shape
        style_loss += style_weights[layer] * torch.mean((target_gram - style_gram[layer])**2) / (c * h * w)

    # Total-variation term: summed absolute differences between neighbouring
    # elements along the last axis and along the second-to-last axis.
    diff_i = torch.sum(torch.abs(target_img[:, :, :, 1:] - target_img[:, :, :, :-1]))
    diff_j = torch.sum(torch.abs(target_img[:, :, 1:, :] - target_img[:, :, :-1, :]))
    tv_loss = (diff_i + diff_j)

    total_loss = content_loss*content_weight + style_loss*style_weight + tv_weight*tv_loss

    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    if i % show_every == 0:
        save_image(target_img,i)
        show_image(target_img)
        plt.show()
    
    
    
        

1 th iteration
2 th iteration
3 th iteration
4 th iteration
5 th iteration
6 th iteration
7 th iteration
8 th iteration
