In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


from PIL import Image
import matplotlib.pyplot as plt

import torchvision.transforms as transforms
from torchvision.models import vgg19, VGG19_Weights

import copy

import utils.utils as utils

import os
import time

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.set_default_device(device)

In [3]:
def gram_matrix(input):
    a, b, c, d = input.size()  # a=batch size(=1)
    # b=number of feature maps
    # (c,d)=dimensions of a f. map (N=c*d)

    features = input.view(a * b, c * d)  # resize F_XL into \hat F_XL

    G = torch.mm(features, features.t())  # compute the gram product

    # we 'normalize' the values of the gram matrix
    # by dividing by the number of element in each feature maps.
    return G.div(a * b * c * d)


In [4]:
class LossNetwork(torch.nn.Module):
    def __init__(self):
        super().__init__()

        cnn = vgg19(weights=VGG19_Weights.DEFAULT).features.eval()
        

        self.feature_maps = dict()
        self.feature_maps['relu1_2'] = nn.Sequential()
        self.feature_maps['relu2_2'] = nn.Sequential()
        self.feature_maps['relu3_4'] = nn.Sequential()
        self.feature_maps['relu4_4'] = nn.Sequential()
        self.feature_maps['relu5_4'] = nn.Sequential()

        names = ['relu1_2', 'relu2_2', 'relu3_4', 'relu4_4', 'relu5_4']
        
        i = 1
        j = 1

        for layer in cnn.children():

            if isinstance(layer, nn.MaxPool2d):
                name = 'pool_{}'.format(i)
                i += 1
                j = 1
            elif isinstance(layer, nn.Conv2d):
                name = 'conv{}_{}'.format(i,j)
            elif isinstance(layer, nn.ReLU):
                name = 'relu{}_{}'.format(i,j)
                j += 1

            if i == 6:
                break
            
            self.feature_maps[names[i-1]].add_module(name,layer)
        
        for param in self.parameters():
                param.requires_grad = False


    def forward(self,x):
            
        x = self.feature_maps['relu1_2'](x)
        feature_map1 = x
        x = self.feature_maps['relu2_2'](x)
        feature_map2 = x
        x = self.feature_maps['relu3_4'](x)
        feature_map3 = x
        x = self.feature_maps['relu4_4'](x)
        feature_map4 = x
        x = self.feature_maps['relu5_4'](x)
        feature_map5 = x

        return feature_map1, feature_map2, feature_map3, feature_map4, feature_map5

In [20]:
def NST(content_path, style_path, input_path, content_width, num_steps=300, content_weight=1e6, style_weight=1):

    loss_net = LossNetwork().to(device)

    content_image = utils.prepare_img(content_path, content_width, device)
    style_image = utils.prepare_img(style_path, None, device)

    content_features = loss_net(content_image)
    style_features = loss_net(style_image)

    style_grams = []
    for feature_map in style_features:
        style_grams.append(gram_matrix(feature_map))

    if input_path == None:
        input_image = content_image.clone()

    input_image.requires_grad_(True)

    loss_net.eval()
    loss_net.requires_grad_(False)

    optimizer = optim.LBFGS([input_image])

    for step in range(num_steps):

        def closure():   
            
            # clear gradients 
            optimizer.zero_grad()  

            # Feed input image to loss net (VGG19)
            input_features = loss_net(input_image)

            # Calculate content loss
            content_target = content_features[3]
            input_content = input_features[3]
            content_loss = F.mse_loss(content_target, input_content, reduction='mean')

            # Calculate style loss
            style_loss = 0.0

            input_grams = []
            for feature_map in input_features:
                input_grams.append(gram_matrix(feature_map))
        
            for k in range(5):
                style_loss += F.mse_loss(style_grams[k], input_grams[k], reduction='mean')
            
            # Combine losses and do a backprop
            total_loss = content_weight*content_loss + style_weight*style_loss 
            total_loss.backward(retain_graph=True)

            
            print('step {}'.format((step)))
            print('loss: {}'.format(total_loss))

            return total_loss
        
        optimizer.step(closure)

        

    stylized_image = utils.post_process_image(input_image)
    plt.imshow(stylized_image)
    plt.show()



In [21]:
content_path = "./images/dancing.jpg"
style_path = "./images/vg_starry_night.jpg"
input_path = None 
content_width = 128 


NST(content_path, style_path, input_path, content_width)

step 0
loss: 0.003850937355309725
step 0
loss: 0.0039023817516863346
step 0
loss: 0.0038512449245899916
step 0
loss: 0.0038513473700731993
step 0
loss: 0.003851355519145727


KeyboardInterrupt: 