# PyTorch implementation of artistic style transfer

In [1]:
import os
import random
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable
from PIL import Image
import numpy as np
import torchvision.models as models
from vgg import VGG
import util

In [2]:
# Global switch: later cells move tensors and modules to the GPU only when
# this is True.
cuda = torch.cuda.is_available()

In [3]:
# Draw one random seed and apply it to Python's and PyTorch's RNGs, so a run
# can be reproduced by pinning `manualSeed` to a fixed value.
manualSeed = random.randint(1, 10000)
random.seed(manualSeed)
torch.manual_seed(manualSeed)
if cuda:
    # Fixed: the original passed `opt.manualSeed`, but no `opt` object is
    # defined in this notebook, so this raised NameError on CUDA machines.
    torch.cuda.manual_seed_all(manualSeed)

In [4]:
# Turn on cuDNN benchmark mode (PyTorch documents this as letting cuDNN
# auto-tune its convolution algorithms for the input sizes it sees).
cudnn.benchmark = True

In [5]:
# Size argument handed to the resize step of the `transform` pipeline.
imageSize = 128

In [6]:
def _reverse_channels(x):
    """Reorder dim 0 of `x` with the index order 2, 1, 0."""
    return x[torch.LongTensor([2,1,0])]


def _scale_to_255(x):
    """Multiply `x` by 255 in place and return it."""
    return x.mul_(255)


# Preprocessing applied to every input image, in order:
#   1. resize using `imageSize`
#   2. convert the PIL image to a tensor
#   3. reverse the channel order (RGB -> BGR)
#   4. subtract the per-channel means (listed in BGR order; std is 1, so no
#      rescaling happens here) -- presumably the means the VGG weights were
#      trained with; confirm against the checkpoint's origin
#   5. multiply by 255
_preprocess_steps = [
    transforms.Scale(imageSize),
    transforms.ToTensor(),
    transforms.Lambda(_reverse_channels), #turn to BGR
    transforms.Normalize(mean=[0.40760392, 0.45795686, 0.48501961],std=[1,1,1]),
    transforms.Lambda(_scale_to_255),
]
transform = transforms.Compose(_preprocess_steps)
def load_image(path,style=False):
    """Load an image file and preprocess it into a batched Variable.

    Args:
        path: Path of the image file to open.
        style: Unused; kept only so existing callers that pass it still work.

    Returns:
        The image run through the module-level `transform`, wrapped in a
        Variable, with a leading batch dimension of size 1 added.
    """
    # The context manager releases the file handle promptly. `convert('RGB')`
    # guarantees three channels: grayscale, palette or RGBA files would
    # otherwise break the three-channel Normalize step inside `transform`.
    with Image.open(path) as source:
        img = source.convert('RGB')
    img = Variable(transform(img))
    return img.unsqueeze(0)

def save_image(img, path='images/transfer.png'):
    """Undo the preprocessing done by `transform` and write the image to disk.

    Args:
        img: Image tensor in the preprocessed space (channels on dim 0, BGR
            order, mean-subtracted, scaled by 255).
        path: Destination file. Defaults to the location the original
            implementation always wrote to.
    """
    post = transforms.Compose([transforms.Lambda(lambda x: x.mul_(1./255)),
         # Normalising with the negated means adds the means back.
         transforms.Normalize(mean=[-0.40760392, -0.45795686, -0.48501961], std=[1,1,1]),
         transforms.Lambda(lambda x: x[torch.LongTensor([2,1,0])]), #turn to RGB
         ])
    # Work on a copy: the steps here run in place (mul_, clamp_), and the
    # argument may share storage with the caller's tensor (e.g. `.cpu()` on a
    # tensor that already lives on the CPU), which the original silently
    # overwrote.
    img = post(img.clone())
    img = img.clamp_(0,1)
    vutils.save_image(img,
                path,
                normalize=True)

In [7]:
# Input files: the image whose style is borrowed and the image whose content
# is kept.
style_image = 'images/starry_night.jpg'
content_image = 'images/blue-moon-lake.jpg'

In [8]:
# Both inputs go through the same preprocessing and come back with a leading
# batch dimension of 1; the spatial size is whatever `transform` produces.
styleImg = load_image(style_image)
contentImg = load_image(content_image)
if cuda:
    styleImg, contentImg = styleImg.cuda(), contentImg.cuda()

In [9]:
# Build the feature extractor and load its pretrained weights.
vgg = VGG()
# `map_location` forces the checkpoint onto the CPU first. Without it, a
# checkpoint saved from GPU tensors cannot be loaded on a CPU-only machine,
# even though the rest of the notebook supports `cuda == False`.
vgg.load_state_dict(torch.load('models/vgg_conv.pth',
                               map_location=lambda storage, loc: storage))
# The network is a fixed feature extractor: only the image is optimised, so
# none of its parameters need gradients.
for param in vgg.parameters():
    param.requires_grad = False
if cuda:
    vgg.cuda()

In [10]:
class GramMatrix(nn.Module):
    """Normalised Gram matrix of a batch of feature maps."""

    def forward(self,input):
        """Return the channel-by-channel Gram matrix of `input`.

        `input` is unpacked as (batch, channels, height, width). Each sample
        is flattened to a (channels, height*width) matrix F, and the result
        is F @ F^T divided by batch*channels*height*width, giving a tensor
        of shape (batch, channels, channels).
        """
        batch, channels, height, width = input.size()
        flat = input.view(batch, channels, height * width)
        gram = torch.bmm(flat, flat.transpose(1, 2))
        return gram.div_(batch * channels * height * width)

In [11]:
class styleLoss(nn.Module):
    """Mean-squared error between a Gram matrix and a precomputed target."""

    def forward(self,input,target):
        """Return MSE(Gram(input), target).

        `input` is a raw feature map; `target` is expected to already be a
        Gram matrix, since it is compared against the Gram of `input`.
        """
        gram = GramMatrix()
        criterion = nn.MSELoss()
        return criterion(gram(input), target)

In [12]:
# Names of the VGG outputs requested for the style and content terms.
style_layers = ['r11', 'r21', 'r31', 'r41', 'r51']
content_layers = ['r42']

In [13]:
# Style targets: the Gram matrix of each requested activation of the style
# image, detached so no gradient flows back into the targets.
styleTargets = [
    GramMatrix()(feature.detach())
    for feature in vgg(styleImg,style_layers)
]

In [14]:
# Content targets: the raw requested activations of the content image,
# detached so no gradient flows back into the targets.
contentTargets = [
    feature.detach()
    for feature in vgg(contentImg,content_layers)
]

In [15]:
# One loss module per requested layer: Gram-based loss for the style layers,
# plain MSE for the content layers.
styleLosses = [styleLoss() for _ in style_layers]
contentLosses = [nn.MSELoss() for _ in content_layers]

In [16]:
# Style entries first, then content entries; `losses[i]` is paired with
# `targets[i]` by position.
losses = styleLosses + contentLosses
targets = styleTargets + contentTargets

In [17]:
# All layer names in one list, style layers first and content layers after.
loss_layers = style_layers + content_layers

In [18]:
# Relative importance of the two objectives.
content_weight = 0.05
style_weight = 1
# One weight per layer, in the order style layers then content layers.
weights = [style_weight]*len(style_layers) + [content_weight]*len(content_layers)

In [19]:
# The image being optimised starts as a copy of the content image and is the
# only tensor that requires gradients.
optImg = Variable(contentImg.data.clone(), requires_grad=True)
# L-BFGS updates the pixels of `optImg` directly.
optimizer = optim.LBFGS([optImg])

In [20]:
if cuda:
    # nn.Module.cuda() moves a module in place, so rebinding the loop
    # variable (as the original did) has no effect and is dropped.
    for loss in losses:
        loss.cuda()
    # `optImg` needs no explicit move: it was cloned from `contentImg`, which
    # is already on the GPU when `cuda` is set. The original `optImg.cuda()`
    # discarded its result, so it was a no-op. Rebinding `optImg` instead
    # would leave the optimizer holding the old leaf tensor.

In [21]:
epoches = 100


def closure():
    """Evaluate the weighted style+content loss and backpropagate it.

    LBFGS calls this (possibly several times per `step`) to re-evaluate the
    objective, so it must clear gradients, recompute the loss and call
    `backward()` on every invocation.

    Returns:
        The total weighted loss for the current `optImg`.
    """
    optimizer.zero_grad()
    out = vgg(optImg,loss_layers)
    # Pair each layer output with its loss module, target and weight by
    # position.
    totalLoss = sum(
        loss_i(layer_output, target_i) * weight_i
        for layer_output, loss_i, target_i, weight_i
        in zip(out, losses, targets, weights)
    )
    totalLoss.backward()
    # NOTE(review): `.data[0]` matches the pre-0.4 PyTorch API used throughout
    # this notebook; newer releases need `.item()` instead.
    print('loss: %f'%(totalLoss.data[0]))
    return totalLoss


# Fixed off-by-one: `range(1, epoches)` ran only epoches-1 rounds while the
# progress line advertised `epoches` of them.
for iteration in range(1, epoches + 1):
    print('Iteration [%d]/[%d]'%(iteration,epoches))
    optimizer.step(closure)

# Drop the batch dimension, bring the result back to the CPU and save it.
outImg = optImg.data[0].cpu()
save_image(outImg.squeeze())

Iteration [1]/[100]
loss: 893894.062500
loss: 893888.500000
loss: 878339.937500
loss: 1140318.750000
loss: 460211.187500
loss: 282584.531250
loss: 194567.703125
loss: 152253.921875
loss: 133491.500000
loss: 118156.289062
loss: 106765.359375
loss: 98230.078125
loss: 92686.203125
loss: 84696.750000
loss: 81274.234375
loss: 78425.968750
loss: 74872.843750
loss: 71994.671875
loss: 69312.781250
loss: 66778.812500
Iteration [2]/[100]
loss: 64614.464844
loss: 61904.535156
loss: 60294.085938
loss: 58975.414062
loss: 57844.710938
loss: 56914.179688
loss: 55660.039062
loss: 54590.132812
loss: 53647.710938
loss: 52877.617188
loss: 51615.390625
loss: 50809.875000
loss: 49817.976562
loss: 49381.976562
loss: 48872.011719
loss: 48143.714844
loss: 47630.132812
loss: 47073.632812
loss: 46429.968750
loss: 45764.914062
Iteration [3]/[100]
loss: 45315.539062
loss: 44888.753906
loss: 44376.468750
loss: 43901.992188
loss: 43565.062500
loss: 43215.585938
loss: 42991.023438
loss: 42531.421875
loss: 42292.4062

loss: 25600.800781
loss: 25597.500000
loss: 25594.544922
loss: 25592.312500
loss: 25590.212891
loss: 25587.630859
loss: 25584.007812
loss: 25580.539062
loss: 25577.134766
loss: 25574.333984
loss: 25572.460938
loss: 25570.121094
Iteration [22]/[100]
loss: 25567.625000
loss: 25564.585938
loss: 25561.347656
loss: 25558.130859
loss: 25555.160156
loss: 25553.304688
loss: 25550.824219
loss: 25548.421875
loss: 25545.630859
loss: 25542.681641
loss: 25540.695312
loss: 25538.101562
loss: 25536.054688
loss: 25534.025391
loss: 25531.869141
loss: 25529.462891
loss: 25526.904297
loss: 25523.699219
loss: 25521.398438
loss: 25519.152344
Iteration [23]/[100]
loss: 25516.718750
loss: 25514.554688
loss: 25512.195312
loss: 25510.734375
loss: 25507.863281
loss: 25505.089844
loss: 25502.404297
loss: 25500.396484
loss: 25498.128906
loss: 25496.339844
loss: 25494.076172
loss: 25492.337891
loss: 25490.259766
loss: 25488.279297
loss: 25486.138672
loss: 25483.339844
loss: 25481.175781
loss: 25479.681641
loss: 25

loss: 25098.933594
loss: 25098.460938
Iteration [42]/[100]
loss: 25098.005859
loss: 25097.501953
loss: 25097.050781
loss: 25096.605469
loss: 25096.130859
loss: 25095.777344
loss: 25095.140625
loss: 25094.876953
loss: 25094.402344
loss: 25094.039062
loss: 25093.625000
loss: 25092.996094
loss: 25092.615234
loss: 25092.093750
loss: 25091.757812
loss: 25091.281250
loss: 25090.859375
loss: 25090.320312
loss: 25089.847656
loss: 25089.628906
Iteration [43]/[100]
loss: 25089.324219
loss: 25088.855469
loss: 25088.191406
loss: 25087.851562
loss: 25087.410156
loss: 25087.035156
loss: 25086.460938
loss: 25086.144531
loss: 25085.390625
loss: 25085.027344
loss: 25084.302734
loss: 25083.925781
loss: 25083.431641
loss: 25083.326172
loss: 25082.865234
loss: 25082.277344
loss: 25081.765625
loss: 25081.347656
loss: 25080.617188
loss: 25080.199219
Iteration [44]/[100]
loss: 25079.917969
loss: 25079.310547
loss: 25078.949219
loss: 25078.542969
loss: 25077.886719
loss: 25077.494141
loss: 25076.953125
loss: 

loss: 24968.128906
loss: 24967.880859
loss: 24967.656250
loss: 24967.615234
loss: 24967.550781
loss: 24967.281250
loss: 24967.105469
loss: 24966.832031
loss: 24966.785156
loss: 24966.343750
loss: 24966.226562
loss: 24966.050781
loss: 24965.658203
loss: 24965.689453
Iteration [63]/[100]
loss: 24965.419922
loss: 24965.255859
loss: 24965.021484
loss: 24964.972656
loss: 24964.689453
loss: 24964.460938
loss: 24964.312500
loss: 24964.136719
loss: 24963.726562
loss: 24963.699219
loss: 24963.486328
loss: 24963.328125
loss: 24963.119141
loss: 24962.808594
loss: 24962.460938
loss: 24962.367188
loss: 24962.232422
loss: 24961.927734
loss: 24961.849609
loss: 24961.646484
Iteration [64]/[100]
loss: 24961.464844
loss: 24961.126953
loss: 24960.857422
loss: 24960.755859
loss: 24960.486328
loss: 24960.357422
loss: 24960.089844
loss: 24959.876953
loss: 24959.755859
loss: 24959.521484
loss: 24959.306641
loss: 24959.076172
loss: 24959.083984
loss: 24958.750000
loss: 24958.449219
loss: 24958.312500
loss: 24

loss: 24911.974609
loss: 24911.960938
loss: 24911.960938
Iteration [84]/[100]
loss: 24911.960938
loss: 24911.871094
loss: 24911.777344
loss: 24911.839844
loss: 24911.625000
loss: 24911.464844
loss: 24911.365234
loss: 24911.210938
loss: 24911.093750
loss: 24911.054688
loss: 24910.925781
loss: 24910.785156
loss: 24910.656250
loss: 24910.623047
loss: 24910.322266
loss: 24910.337891
loss: 24910.238281
loss: 24910.195312
loss: 24910.074219
loss: 24910.031250
Iteration [85]/[100]
loss: 24909.960938
loss: 24909.972656
loss: 24909.843750
loss: 24909.728516
loss: 24909.701172
loss: 24909.738281
loss: 24909.531250
loss: 24909.472656
loss: 24909.308594
loss: 24909.148438
loss: 24909.095703
loss: 24909.203125
loss: 24909.246094
loss: 24908.972656
loss: 24908.935547
loss: 24908.873047
loss: 24908.703125
loss: 24908.656250
loss: 24908.628906
loss: 24908.527344
Iteration [86]/[100]
loss: 24908.427734
loss: 24908.365234
loss: 24908.275391
loss: 24908.214844
loss: 24908.103516
loss: 24907.949219
loss: 