In [1]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import torchvision
from torchvision import transforms, models
from PIL import Image
import argparse
import numpy as np
import os

In [4]:
def load_image(image_path, transforms=None, max_size=None, shape=None):
    image = Image.open(image_path)
    image_size = image.size

    if max_size is not None:
        #获取图像size，为sequence
        image_size = image.size
        #转化为float的array
        size = np.array(image_size).astype(float)
        size = max_size / size * size
        image = image.resize(size.astype(int), Image.ANTIALIAS)

    if shape is not None:
        image = image.resize(shape, Image.LANCZOS) #LANCZOS也是一种插值方法

    #必须提供transform.ToTensor，转化为4D Tensor
    if transforms is not None:
        image = transforms(image).unsqueeze(0)

    #是否拷贝到GPU
    return image.type(dtype)

In [6]:
#定义加载图像函数，并将PIL image转化为Tensor
use_gpu = torch.cuda.is_available()
dtype = torch.cuda.FloatTensor if use_gpu else torch.FloatTensor
print(dtype)

<class 'torch.cuda.FloatTensor'>


<img src="https://i.loli.net/2019/04/10/5cadd5db873bd.png">


In [10]:
class VGGNet(nn.Module):
    def __init__(self):
        super(VGGNet, self).__init__()
        self.select = ['0', '5', '10', '19', '28']
        self.vgg19 = models.vgg19(pretrained = True).features

    def forward(self, x):
        features = []
        #name类型为str，x为Variable
        for name, layer in self.vgg19._modules.items():
            x = layer(x)
            if name in self.select:
                features.append(x)
        return features

In [7]:

transform = transforms.Compose(
        [transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), 
                             (0.229, 0.224, 0.225))
        ])

content = load_image('G:/warehouse/DL2019/hw4/content.jpg', transform, max_size = 400)
style = load_image('G:/warehouse/DL2019/hw4/style.jpg', transform, shape = [content.size(2), content.size(3)])
print(content.size())
print(style.size())

torch.Size([1, 3, 400, 400])
torch.Size([1, 3, 400, 400])


In [8]:
target = Variable(content.clone(), requires_grad = True)
optimizer = torch.optim.Adam([target], lr = 0.002, betas=[0.5, 0.999])


In [11]:
vgg = VGGNet()

In [None]:
torch.version.cuda

In [9]:
torch.cuda.is_available()

True

In [2]:
torch.cuda.current_device()

0

In [11]:
parser = argparse.ArgumentParser()
parser.add_argument('--content', type=str, default='./content.jpg')
parser.add_argument('--style', type=str, default='./style.jpg')
parser.add_argument('--max_size', type=int, default=400)
parser.add_argument('--total_step', type=int, default=5000)
parser.add_argument('--log_step', type=int, default=10)
parser.add_argument('--sample_step', type=int, default=100)
parser.add_argument('--style_weight', type=float, default=100)
parser.add_argument('--lr', type=float, default=0.003)


_StoreAction(option_strings=['--lr'], dest='lr', nargs=None, const=None, default=0.003, type=<class 'float'>, choices=None, help=None, metavar=None)

In [13]:
if use_gpu:
    vgg = vgg.cuda()
#config.log_step=10
#config.sample_step=500
for step in range(1000):  
    target_features = vgg(target)
    content_features = vgg(Variable(content))
    style_features = vgg(Variable(style))
    content_loss = 0.0
    style_loss = 0.0
    for f1,f2,f3 in zip(target_features, content_features, style_features):
        content_loss += torch.mean((f1 - f2)**2)
        n, c, h, w = f1.size()
    #将特征reshape成二维矩阵相乘，求gram矩阵
        f1 = f1.view(c, h * w)
        f3 = f3.view(c, h * w)

        f1 = torch.mm(f1, f1.t())
        f3 = torch.mm(f3, f3.t())
    
    #计算style_loss
        style_loss += torch.mean((f1 - f3)**2) / (c * h * w)
        
    loss = content_loss + style_loss 
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if (step+1) % 10 == 0:
            print ('Step [%d/%d], Content Loss: %.4f, Style Loss: %.4f'
                   %(step+1, 1000, content_loss.data, style_loss.data))

    if (step+1) % 1000 == 0:
            # Save the generated image
        denorm = transforms.Normalize((-2.12, -2.04, -1.80), (4.37, 4.46, 4.44))
        img = target.clone().cpu().squeeze()
        img = denorm(img.data).clamp_(0, 1)
        torchvision.utils.save_image(img, 'output-%d.png' %(step+1))
    



Step [10/1000], Content Loss: 1.8715, Style Loss: 1073.7605
Step [20/1000], Content Loss: 5.8217, Style Loss: 957.3817
Step [30/1000], Content Loss: 9.4461, Style Loss: 868.0582
Step [40/1000], Content Loss: 12.2200, Style Loss: 801.0866
Step [50/1000], Content Loss: 14.3485, Style Loss: 748.8475
Step [60/1000], Content Loss: 16.0742, Style Loss: 706.4912
Step [70/1000], Content Loss: 17.4831, Style Loss: 671.1323
Step [80/1000], Content Loss: 18.6602, Style Loss: 640.8476
Step [90/1000], Content Loss: 19.6746, Style Loss: 614.3547
Step [100/1000], Content Loss: 20.5530, Style Loss: 590.7620
Step [110/1000], Content Loss: 21.3093, Style Loss: 569.4674
Step [120/1000], Content Loss: 21.9826, Style Loss: 550.0508
Step [130/1000], Content Loss: 22.5902, Style Loss: 532.1852
Step [140/1000], Content Loss: 23.1442, Style Loss: 515.5713
Step [150/1000], Content Loss: 23.6543, Style Loss: 500.0414
Step [160/1000], Content Loss: 24.1185, Style Loss: 485.5206
Step [170/1000], Content Loss: 24.5

In [24]:
print(np.array(content_features).size)

5


In [40]:
torch.mean((f1-f2)**2)

tensor(0., grad_fn=<MeanBackward1>)