In [1]:
import torch
import torch.nn as nn
from torchvision import models

In [2]:
vgg16 = models.vgg16(pretrained=True)

In [3]:
len(list(vgg16.features.children()))

31

In [2]:
class StyleTransferNet():
    def __init__(self):
        vgg16 = models.vgg16(pretrained=True)
        feature_maps = list(vgg16.features.children())
        self.avgpool = nn.AvgPool2d(kernel_size=2,stride=2)
        self.block_1 = nn.Sequential(
            feature_maps[0],
            feature_maps[1],
            feature_maps[2],
            feature_maps[3],
        )
        
        self.block_2 = nn.Sequential(
            feature_maps[5],
            feature_maps[6],
            feature_maps[7],
            feature_maps[8],
        )
        
        self.block_3 = nn.Sequential(
            feature_maps[10],
            feature_maps[11],
            feature_maps[12],
            feature_maps[13],
            feature_maps[14],
            feature_maps[15],
        )
        
        self.block_4 = nn.Sequential(
            feature_maps[17],
            feature_maps[18],
            feature_maps[19],
            feature_maps[20],
            feature_maps[21],
            feature_maps[22],
        )
        
        self.block_5 = nn.Sequential(
            feature_maps[24],
            feature_maps[25],
            feature_maps[26],
            feature_maps[27],
            feature_maps[28],
            feature_maps[29],
        )
        
    def get_features(self,x):
        block_1_features = self.block_1(x)
        x = self.avgpool(block_1_features)
        block_2_features = self.block_2(x)
        x = self.avgpool(block_2_features)
        block_3_features = self.block_3(x)
        x = self.avgpool(block_3_features)
        block_4_features = self.block_4(x)
        x = self.avgpool(block_4_features)
        block_5_features = self.block_5(x)
        return [block_1_features,block_2_features,block_3_features,block_4_features,block_5_features]

In [3]:
net = StyleTransferNet()

In [4]:
import scipy.misc

In [5]:
img = scipy.misc.imread("city.jpg")

In [6]:
w = 500

In [7]:
r = w * 1.0 / img.shape[1] 

In [8]:
img = scipy.misc.imresize(img, (int(img.shape[0] * r),(int(img.shape[1]*r))))

In [9]:
import numpy as np

In [10]:
img = img.astype(np.float32)

In [11]:
def sub_mean(img):
    for i in range(3):
        img[:,:,i] -= mean[i]
        
    return img

In [12]:
mean = np.array([103.939, 116.779, 123.68])

In [13]:
img = sub_mean(img)

In [14]:
img.shape

(288, 500, 3)

In [15]:
from torchvision import transforms

In [16]:
tran = transforms.Compose([transforms.ToTensor()])

In [17]:
img = tran(img)

In [18]:
img = img.resize_(1,img.size()[0],img.size()[1],img.size()[2])

In [19]:
content = torch.autograd.Variable(img,requires_grad=False)

In [20]:
content_features = net.get_features(content)

In [21]:
style = scipy.misc.imread("style.jpg")

In [22]:
r = w * 1.0 / style.shape[1] 

In [23]:
style = scipy.misc.imresize(style, (int(style.shape[0] * r),(int(style.shape[1]*r))))

In [24]:
style = style.astype(np.float32)

In [25]:
style = sub_mean(style)

In [26]:
style = tran(style)
style = style.resize_(1,style.size()[0],style.size()[1],style.size()[2])

In [27]:
style = torch.autograd.Variable(style)

In [28]:
style_features = net.get_features(style)

In [50]:
class ContentLoss(nn.Module):
    def __init__(self, weights):
        super(ContentLoss,self).__init__()
        self.weights = weights
        
    def forward(self, preds, targets):
        self.loss = 0
        for weight, pred, target in zip(self.weights, preds, targets):
            self.loss += torch.mean((pred - target) ** 2) * weight
        return self.loss
    
    def backward(self, retain_graph=True):
        self.loss.backward(retain_graph=retain_graph)
        return self.loss
    
class GramMatrix(nn.Module):
    def forward(self, feature_map):
        a,b,c,d = feature_map.size()
        feature = feature_map.view(a*b, c*d)
        matrix = torch.mm(feature, feature.t())
        
        return matrix.div(a*b*c*d)
    
class StyleLoss(nn.Module):
    def __init__(self, weights):
        super(StyleLoss, self).__init__()
        self.weights = weights
        self.gram = GramMatrix()
        
    def forward(self, preds, targets):
        self.loss = 0
        for weight, pred, target in zip(self.weights, preds, targets):
            g_pred = self.gram(pred)
            g_target = self.gram(target)
            self.loss += torch.mean((g_pred - g_target) ** 2) * weight
            
        return self.loss
    
    def backward(self, retain_graph=True):
        self.loss.backward(retain_graph = retain_graph)
        return self.loss
    
class TotalLoss(nn.Module):
    def __init__(self, content, style):
        super(TotalLoss, self).__init__()
        self.content_loss = content
        self.style_loss = style
        
    def forward(self, gens, contents, styles):
        self.loss = self.content_loss(gens,contents) + self.style_loss(gens,styles)
        return self.loss
    
    def backward(self, retain_graph=True):
        self.loss.backward(retain_graph=retain_graph)
        return self.closs+self.sloss
    

In [51]:
gen = torch.autograd.Variable(img)

In [31]:
gen_features = net.get_features(gen)

In [52]:
content_loss = ContentLoss([0.5,1,1,1,1])

In [33]:
loss = content_loss(gen_features,content_features)

In [34]:
loss.backward()

In [57]:
style_loss = StyleLoss([4000,8000,800,800,800])

In [34]:
loss_style = style_loss(gen_features, style_features)

In [36]:
loss_style

Variable containing:
1.00000e-05 *
  3.5181
[torch.FloatTensor of size 1]

In [58]:
loss_func = TotalLoss(content_loss,style_loss)

In [42]:
loss = loss_func(gen_features,content_features,style_features)

In [38]:
loss.backward()

In [59]:
gen_param = nn.Parameter(gen.data)
opt = torch.optim.LBFGS([gen_param])

In [38]:
moded = nn.Sequential()

In [60]:
for _ in range(0,2):
    def closure():
        opt.zero_grad()
        gen_features = net.get_features(gen_param)
        loss = loss_func(gen_features,content_features,style_features)
        loss.backward(retain_graph=True)
        print(loss)
        return loss
    opt.step(closure)

Variable containing:
1.00000e-02 *
  1.0113
[torch.FloatTensor of size 1]

Variable containing:
1.00000e-02 *
  1.0113
[torch.FloatTensor of size 1]

Variable containing:
1.00000e-02 *
  1.0112
[torch.FloatTensor of size 1]

Variable containing:
1.00000e-02 *
  1.0112
[torch.FloatTensor of size 1]

Variable containing:
1.00000e-02 *
  1.0112
[torch.FloatTensor of size 1]

Variable containing:
inf
[torch.FloatTensor of size 1]



In [61]:
gen

Variable containing:
( 0 , 0 ,.,.) = 
  2.4535e+12  3.6431e+12  4.2900e+12  ...   6.6597e+12  9.2908e+12  1.1243e+12
  5.0289e+12  4.3893e+11 -4.0697e+12  ...   5.4789e+12  1.1203e+13  2.8992e+12
 -4.6099e+12 -1.5890e+13 -2.1428e+13  ...  -1.9816e+12  5.3796e+12  9.0802e+11
                 ...                   ⋱                   ...                
 -4.0767e+12  4.6032e+12  3.1090e+13  ...  -2.9316e+12  8.4038e+12  6.0565e+12
 -9.5824e+12 -9.2737e+12  3.4565e+12  ...   2.7708e+12  6.6660e+12  2.9354e+12
 -5.0812e+12 -7.8063e+12 -5.4116e+12  ...  -4.7237e+12 -1.7901e+12 -2.0716e+12

( 0 , 1 ,.,.) = 
 -5.4878e+12 -6.3358e+12 -4.1210e+12  ...  -2.8581e+12 -4.2451e+12 -8.6305e+12
 -4.9304e+11 -3.0907e+12 -2.6643e+12  ...   2.6028e+12 -3.2325e+12 -9.6822e+12
 -6.9339e+12 -8.5107e+12 -1.6484e+12  ...   1.7408e+13  3.6198e+12 -7.0220e+12
                 ...                   ⋱                   ...                
  1.0108e+13 -4.3965e+12 -2.6214e+13  ...  -3.1285e+13 -2.1247e+13 -3.1646e

In [62]:
content

Variable containing:
( 0 , 0 ,.,.) = 
  2.4535e+12  3.6431e+12  4.2900e+12  ...   6.6597e+12  9.2908e+12  1.1243e+12
  5.0289e+12  4.3893e+11 -4.0697e+12  ...   5.4789e+12  1.1203e+13  2.8992e+12
 -4.6099e+12 -1.5890e+13 -2.1428e+13  ...  -1.9816e+12  5.3796e+12  9.0802e+11
                 ...                   ⋱                   ...                
 -4.0767e+12  4.6032e+12  3.1090e+13  ...  -2.9316e+12  8.4038e+12  6.0565e+12
 -9.5824e+12 -9.2737e+12  3.4565e+12  ...   2.7708e+12  6.6660e+12  2.9354e+12
 -5.0812e+12 -7.8063e+12 -5.4116e+12  ...  -4.7237e+12 -1.7901e+12 -2.0716e+12

( 0 , 1 ,.,.) = 
 -5.4878e+12 -6.3358e+12 -4.1210e+12  ...  -2.8581e+12 -4.2451e+12 -8.6305e+12
 -4.9304e+11 -3.0907e+12 -2.6643e+12  ...   2.6028e+12 -3.2325e+12 -9.6822e+12
 -6.9339e+12 -8.5107e+12 -1.6484e+12  ...   1.7408e+13  3.6198e+12 -7.0220e+12
                 ...                   ⋱                   ...                
  1.0108e+13 -4.3965e+12 -2.6214e+13  ...  -3.1285e+13 -2.1247e+13 -3.1646e

In [48]:
torch.mean((gen - gen_param) ** 2)

Variable containing:
 0
[torch.FloatTensor of size 1]

In [63]:
gen_param

Parameter containing:
( 0 , 0 ,.,.) = 
  2.4535e+12  3.6431e+12  4.2900e+12  ...   6.6597e+12  9.2908e+12  1.1243e+12
  5.0289e+12  4.3893e+11 -4.0697e+12  ...   5.4789e+12  1.1203e+13  2.8992e+12
 -4.6099e+12 -1.5890e+13 -2.1428e+13  ...  -1.9816e+12  5.3796e+12  9.0802e+11
                 ...                   ⋱                   ...                
 -4.0767e+12  4.6032e+12  3.1090e+13  ...  -2.9316e+12  8.4038e+12  6.0565e+12
 -9.5824e+12 -9.2737e+12  3.4565e+12  ...   2.7708e+12  6.6660e+12  2.9354e+12
 -5.0812e+12 -7.8063e+12 -5.4116e+12  ...  -4.7237e+12 -1.7901e+12 -2.0716e+12

( 0 , 1 ,.,.) = 
 -5.4878e+12 -6.3358e+12 -4.1210e+12  ...  -2.8581e+12 -4.2451e+12 -8.6305e+12
 -4.9304e+11 -3.0907e+12 -2.6643e+12  ...   2.6028e+12 -3.2325e+12 -9.6822e+12
 -6.9339e+12 -8.5107e+12 -1.6484e+12  ...   1.7408e+13  3.6198e+12 -7.0220e+12
                 ...                   ⋱                   ...                
  1.0108e+13 -4.3965e+12 -2.6214e+13  ...  -3.1285e+13 -2.1247e+13 -3.1646