In [1]:
import os

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms # 将图片转化为张量形式
import torchvision.models as models # 用来加载已经训练好的VGG19

from PIL import Image    # 用来读取图片
from torchvision.utils import save_image # 用来存储突破

### 1： 该模型用到了预先训练好的VGG19模型，并且在任务中该卷积模型的参数是固定的

In [2]:
# Load the pretrained VGG19 and keep only its `features` sub-module,
# then print it so the index of every layer can be read off the output.
model = models.vgg19(pretrained=True).features
print(model)

Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU(inplace=True)
  (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (3): ReLU(inplace=True)
  (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (6): ReLU(inplace=True)
  (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (8): ReLU(inplace=True)
  (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (11): ReLU(inplace=True)
  (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (13): ReLU(inplace=True)
  (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (15): ReLU(inplace=True)
  (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (17): ReLU(inplace=True)
  (18): MaxPoo

从上面可以看到 VGG19 的 features 部分一共有 37 层（索引 0–36）。对 content 和 style 来说，只需要取其中五层的输出即可；这里按论文中的做法，选择索引为 [0, 5, 10, 19, 28] 的层。

In [3]:
# Print a single layer (index 28, the last one used below) to inspect its structure.
print(model[28])

Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))


### 下面写一个方法，该方法能够直接得到对应输入的五个输出特征

In [4]:
class VGG(nn.Module):
    """Frozen, truncated VGG19 used as a fixed feature extractor.

    Only the leading part of the pretrained ``features`` stack is kept (up to
    and including the last index in ``chosen_features``); the outputs of the
    chosen layers are returned by ``forward``.
    """

    def __init__(self):
        super().__init__()

        # Indices (within ``features``) of the layers whose outputs are collected.
        self.chosen_features = [0, 5, 10, 19, 28]
        # Layers after the last chosen index are never used, so they are dropped.
        last_layer = max(self.chosen_features)
        self.model = models.vgg19(pretrained=True).features[: last_layer + 1]

        # The network is only a fixed feature extractor: the optimiser updates
        # the generated image, not these weights. Freezing them keeps backward()
        # from computing and accumulating weight gradients nobody reads.
        for param in self.model.parameters():
            param.requires_grad_(False)

    def forward(self, x):
        """Run ``x`` through the truncated network and collect the chosen outputs.

        Args:
            x: input tensor fed to the first layer of ``self.model``.

        Returns:
            A list with one tensor per index in ``self.chosen_features``, in
            network order.
        """
        features = []

        for layer_num, layer in enumerate(self.model):
            x = layer(x)
            if layer_num in self.chosen_features:
                # NOTE(review): the printed model shows ReLU(inplace=True) right
                # after the conv layers; the tensor stored here is the same
                # object the next in-place ReLU writes to, so the collected
                # features presumably hold post-ReLU values -- confirm intended.
                features.append(x)

        return features

### 2：加载数据集，主要是对content image 和 style image的图片进行预处理

In [5]:
# Preprocessing shared by the content and style images.
loader = transforms.Compose([
    transforms.Resize((356, 356)),  # force a fixed 356x356 resolution
    transforms.ToTensor(),          # PIL image -> tensor
])

In [6]:
def load_image(image_name):
    """Load an image file as a 4-D tensor on ``device``.

    Args:
        image_name: path of the image file to open.

    Returns:
        The image after ``loader`` preprocessing, with a leading batch
        dimension added, moved to ``device``.
    """
    # Use a context manager so the file handle is closed deterministically.
    # convert("RGB") guarantees three channels, so grayscale or RGBA files do
    # not break the 3-channel first convolution of VGG.
    with Image.open(image_name) as img:
        image = img.convert("RGB")

    # unsqueeze(0) turns the 3-D image tensor into the 4-D batch the network expects.
    image = loader(image).unsqueeze(0)

    return image.to(device)

In [7]:
# Use the GPU only when CUDA is actually usable, otherwise fall back to the CPU.
# `is_available` must be *called*: the bare function object is always truthy,
# which would select "cuda" even on a machine without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [8]:
# Load the content and style images through the same preprocessing function.
content_image, style_image = map(
    load_image,
    ("content_image/face.jpg", "style_image/scream.jpg"),
)

In [9]:
# Sanity check: print the shape of the content tensor, then of the style tensor.
for image_tensor in (content_image, style_image):
    print(image_tensor.shape)

torch.Size([1, 3, 356, 356])
torch.Size([1, 3, 356, 356])


###  模型初始化

In [10]:
# 超参数的设计
total_steps = 30000
learning_rate = 0.001
alpha = 0.00001
bate = 1

Model = VGG().to(device).eval()
generated = content_image.clone().requires_grad_(True)
optimizer = optim.Adam([generated],lr=learning_rate)

### 模型训练

In [11]:
def _gram_matrix(feature):
    """Return the (channel x channel) Gram matrix of a single-image feature map."""
    # The batch dimension is folded away by the view, as in the original code,
    # so this only works for a batch of one image.
    _, channel, height, width = feature.shape
    flat = feature.view(channel, height * width)
    return flat.mm(flat.t())


# save_image does not create missing directories, so make sure the target exists.
output_dir = "face"
os.makedirs(output_dir, exist_ok=True)

# The content and style images never change during optimisation, so their
# features (and the style Gram matrices) are computed once, outside the loop,
# and without recording an autograd graph.
with torch.no_grad():
    content_features = Model(content_image)
    style_features = Model(style_image)
    style_grams = [_gram_matrix(style_feature) for style_feature in style_features]

for step in range(total_steps):

    content_loss = 0
    style_loss = 0

    # Only the generated image changes between steps, so only it is re-encoded.
    generated_features = Model(generated)

    for gen_feature, content_feature, style_gram in zip(generated_features, content_features, style_grams):
        # Content loss: mean squared difference between the feature maps.
        content_loss += torch.mean((gen_feature - content_feature) ** 2)

        # Style loss: mean squared difference between the Gram matrices.
        style_loss += torch.mean((_gram_matrix(gen_feature) - style_gram) ** 2)

    total_loss = alpha * content_loss + bate * style_loss

    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    # Every 200 steps, log the losses and save a snapshot of the generated image.
    if step % 200 == 0:
        print(content_loss, style_loss, total_loss)
        save_image(generated, output_dir + f"/gen_{step}.png")

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


tensor(0., device='cuda:0', grad_fn=<AddBackward0>) tensor(19342456., device='cuda:0', grad_fn=<AddBackward0>) tensor(19342456., device='cuda:0', grad_fn=<AddBackward0>)
tensor(8.7201, device='cuda:0', grad_fn=<AddBackward0>) tensor(938137.1250, device='cuda:0', grad_fn=<AddBackward0>) tensor(938137.1250, device='cuda:0', grad_fn=<AddBackward0>)
tensor(9.6784, device='cuda:0', grad_fn=<AddBackward0>) tensor(453304.5938, device='cuda:0', grad_fn=<AddBackward0>) tensor(453304.5938, device='cuda:0', grad_fn=<AddBackward0>)
tensor(10.2044, device='cuda:0', grad_fn=<AddBackward0>) tensor(301531.4062, device='cuda:0', grad_fn=<AddBackward0>) tensor(301531.4062, device='cuda:0', grad_fn=<AddBackward0>)
tensor(10.5482, device='cuda:0', grad_fn=<AddBackward0>) tensor(234170.8281, device='cuda:0', grad_fn=<AddBackward0>) tensor(234170.8281, device='cuda:0', grad_fn=<AddBackward0>)
tensor(10.8156, device='cuda:0', grad_fn=<AddBackward0>) tensor(193809.7188, device='cuda:0', grad_fn=<AddBackward0>

tensor(13.3871, device='cuda:0', grad_fn=<AddBackward0>) tensor(20249.5000, device='cuda:0', grad_fn=<AddBackward0>) tensor(20249.5020, device='cuda:0', grad_fn=<AddBackward0>)
tensor(13.3985, device='cuda:0', grad_fn=<AddBackward0>) tensor(20034.4062, device='cuda:0', grad_fn=<AddBackward0>) tensor(20034.4082, device='cuda:0', grad_fn=<AddBackward0>)
tensor(13.4186, device='cuda:0', grad_fn=<AddBackward0>) tensor(19704.7109, device='cuda:0', grad_fn=<AddBackward0>) tensor(19704.7129, device='cuda:0', grad_fn=<AddBackward0>)
