In [1]:
import torch
import torch.nn as nn
import torchvision
from torchvision import models,transforms
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

# Prefer CUDA when PyTorch reports it as available; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [15]:
def get_image(path, img_transform, size = (300,300)):
    """Load an image file and return it as a batched tensor on ``device``.

    Args:
        path: Location of the image file, passed to ``Image.open``.
        img_transform: Callable applied to the resized PIL image; its result
            must support ``.unsqueeze(0)`` (i.e. be a tensor).
        size: ``(width, height)`` the image is resized to with Lanczos
            resampling before the transform is applied.

    Returns:
        The transformed image with a leading batch dimension of size 1,
        moved to the module-level ``device``.
    """
    # Open inside a context manager so the file handle is released promptly.
    with Image.open(path) as src:
        # Force three colour channels: grayscale, palette, RGBA or CMYK files
        # would otherwise yield a channel count that does not match the
        # three-channel mean/std used for normalisation in this notebook.
        # convert() returns an independent image, so it outlives the handle.
        image = src.convert('RGB')
    image = image.resize(size, Image.LANCZOS)
    image = img_transform(image).unsqueeze(0)
    return image.to(device)

def get_gram(m):
    """Return the Gram matrix of a single feature map.

    ``m`` must be a 4-D tensor whose first (batch) dimension is 1: the
    remaining ``(c, h, w)`` dimensions are flattened to ``(c, h * w)`` and
    multiplied by their own transpose, giving a ``(c, c)`` matrix of
    channel-to-channel inner products.
    """
    _, channels, height, width = m.size()
    # One row per channel, one column per spatial position.
    flat = m.view(channels, height * width)
    return flat.mm(flat.t())

def denormalize_img(inp):
    """Undo the per-channel normalisation and return a displayable array.

    Args:
        inp: A 3-D, channels-first tensor with three channels. It may live
            on any device and may require grad (as the optimised image in
            this notebook does).

    Returns:
        A NumPy array laid out channels-last, with ``std * x + mean`` applied
        per channel and every value clipped to ``[0, 1]``.
    """
    # .numpy() refuses tensors that require grad or sit on a non-CPU device,
    # so drop the autograd graph and copy to host memory first.
    inp = inp.detach().cpu().numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    # Inverse of (x - mean) / std, broadcast over the trailing channel axis.
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    return inp

In [16]:
class FeatureExtractor(nn.Module):
    """Runs an image through VGG-16's ``features`` stack and taps chosen layers.

    ``selected_layers`` holds the integer indices (within ``self.vgg``) of the
    layers whose outputs are returned by ``forward``.
    """

    def __init__(self):
        super().__init__()
        self.selected_layers = [3, 8, 15, 22]
        self.vgg = models.vgg16(pretrained=True).features

    def forward(self, x):
        """Return the outputs of the selected layers, in layer order.

        Args:
            x: Input tensor fed to the first layer of ``self.vgg``.

        Returns:
            A list with one tensor per index in ``selected_layers`` that
            exists in ``self.vgg``; empty if no layers are selected.
        """
        layer_features = []
        if not self.selected_layers:
            return layer_features

        # Nothing after the deepest tapped layer contributes to the result,
        # so evaluation stops there instead of running the remaining layers.
        last_needed = max(self.selected_layers)
        for layer_name, layer in self.vgg.named_children():
            x = layer(x)
            layer_index = int(layer_name)
            if layer_index in self.selected_layers:
                layer_features.append(x)
            if layer_index >= last_needed:
                break
        return layer_features

In [18]:
# PIL image -> tensor, then normalise each channel with a fixed mean/std.
img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

content_img = get_image('content.jpg', img_transform)
style_img = get_image('style.jpg', img_transform)

# The image being optimised starts as a copy of the content image and is the
# only tensor that receives gradients.
generated_img = content_img.clone().requires_grad_(True)

# The encoder is used purely as a fixed feature extractor: freeze its weights.
encoder = FeatureExtractor().to(device)
for p in encoder.parameters():
    p.requires_grad_(False)

# Adam updates the pixels of generated_img directly.
optimizer = torch.optim.Adam([generated_img], lr=0.003, betas=[0.5, 0.999])



In [19]:
# Relative weighting of the two loss terms and the number of optimisation steps.
content_weight=1
style_weight=100
epochs=500

# The content and style images never change, so their features (and the style
# Gram matrices) are computed once up front rather than on every step.
# no_grad() keeps autograd from recording these fixed targets.
with torch.no_grad():
    content_features = encoder(content_img)
    style_features = encoder(style_img)
    style_grams = [get_gram(sf) for sf in style_features]

# `epochs` is an int, so it must be wrapped in range() to be iterated.
for epoch in range(epochs):
    generated_features = encoder(generated_img)

    # Content loss: mean squared error on the deepest tapped layer only.
    content_loss = torch.mean((content_features[-1] - generated_features[-1])**2)

    # Style loss: mean squared Gram-matrix difference at every tapped layer,
    # each scaled by the size of that layer's feature map.
    style_loss = 0
    for gf, gram_sf in zip(generated_features, style_grams):
        _, c, h, w = gf.size()
        gram_gf = get_gram(gf)
        style_loss += torch.mean((gram_gf - gram_sf)**2) / (c*h*w)

    loss = content_weight * content_loss + style_weight * style_loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch % 10 == 0:
        print('Epoch [{}/{}]\tContent Loss: {:.4f}\tStyle Loss: {:.4f}'.format(epoch,epochs,content_loss.item(),style_loss.item()))


Epoch [0]	Content Loss: 0.0000	Style Loss: 6196.9077
Epoch [10]	Content Loss: 0.4555	Style Loss: 5835.3252
Epoch [20]	Content Loss: 0.8173	Style Loss: 5372.0059
Epoch [30]	Content Loss: 1.0521	Style Loss: 4937.8672
Epoch [40]	Content Loss: 1.1999	Style Loss: 4553.6685
Epoch [50]	Content Loss: 1.3045	Style Loss: 4184.9644
Epoch [60]	Content Loss: 1.3787	Style Loss: 3818.4541
Epoch [70]	Content Loss: 1.4327	Style Loss: 3457.6013
Epoch [80]	Content Loss: 1.4767	Style Loss: 3109.7593
Epoch [90]	Content Loss: 1.5154	Style Loss: 2784.1450
Epoch [100]	Content Loss: 1.5485	Style Loss: 2489.3018
Epoch [110]	Content Loss: 1.5781	Style Loss: 2231.2903
Epoch [120]	Content Loss: 1.6065	Style Loss: 2011.9043
Epoch [130]	Content Loss: 1.6374	Style Loss: 1828.3248
Epoch [140]	Content Loss: 1.6700	Style Loss: 1674.6689
Epoch [150]	Content Loss: 1.7030	Style Loss: 1544.1252
Epoch [160]	Content Loss: 1.7339	Style Loss: 1430.4858
Epoch [170]	Content Loss: 1.7657	Style Loss: 1328.8231
Epoch [180]	Content L