## 9.11.8 循环

In [None]:
!pip install icecream
%matplotlib inline
import time
import torch
import torch.nn.functional as F
import torchvision
import numpy as np
from PIL import Image
from icecream import ic

import sys
sys.path.append("..") 
# import d2lzh_pytorch as d2l
from matplotlib import pyplot as plt
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # both CUDA and CPU have been tested

print(device, torch.__version__)
# Load the content image; .convert('RGB') forces a 3-channel image, which the
# 3-element rgb_mean / rgb_std normalisation below relies on.
content_img = Image.open('/content/content_image1.png').convert('RGB')
# Alternative without the RGB conversion (disabled):
# content_img = Image.open('/content/content_image1.png')
plt.imshow(content_img)
# Load the style image the same way.
style_img = Image.open('/content/style_image1.png').convert('RGB')
# Alternative without the RGB conversion (disabled):
# style_img = Image.open('/content/style_image1.png')
# NOTE(review): this draws on the current pyplot axes again, so when both
# imshow calls run in one cell only the style image is likely to stay visible.
plt.imshow(style_img)
# Per-channel mean and standard deviation used by preprocess / postprocess.
# Presumably the ImageNet statistics expected by torchvision's pretrained
# models - TODO confirm.
rgb_mean = np.array([0.485, 0.456, 0.406])
rgb_std = np.array([0.229, 0.224, 0.225])

def preprocess(PIL_img, image_shape):
    """Turn a PIL image into a normalised, batched tensor.

    The image is resized to ``image_shape``, converted to a tensor and
    normalised channel-wise with the module-level ``rgb_mean`` / ``rgb_std``.
    A leading batch dimension of size 1 is then added.

    Args:
        PIL_img: the PIL image to convert.
        image_shape: target size handed to ``torchvision.transforms.Resize``.

    Returns:
        The processed image with a leading batch axis, i.e.
        (batch_size, 3, H, W) for an RGB input.
    """
    transforms = torchvision.transforms
    steps = [
        transforms.Resize(image_shape),
        transforms.ToTensor(),
        transforms.Normalize(mean=rgb_mean, std=rgb_std),
    ]
    img_process = transforms.Compose(steps)(PIL_img)

    # Debug output. icecream echoes the argument expression itself, so the
    # expressions below are intentionally left exactly as written.
    ic(img_process.shape)
    ic(img_process.unsqueeze(dim = 0).shape)

    batched = img_process.unsqueeze(dim=0)
    return batched

def postprocess(img_tensor):
    """Convert the first image of a normalised batch back into a PIL image.

    Normalising with mean ``-rgb_mean / rgb_std`` and std ``1 / rgb_std``
    computes ``x * rgb_std + rgb_mean``, i.e. it undoes the normalisation
    applied in ``preprocess``. The result is clamped to [0, 1] before the
    conversion to PIL.

    Args:
        img_tensor: batched image tensor; only element 0 is used.

    Returns:
        A PIL image built from the de-normalised first batch element.
    """
    undo_mean = -rgb_mean / rgb_std
    undo_std = 1 / rgb_std
    inv_normalize = torchvision.transforms.Normalize(mean=undo_mean, std=undo_std)

    # Debug output; the expression is kept verbatim because icecream prints it.
    ic(img_tensor[0].shape)

    first_image = img_tensor[0].cpu()  # PIL conversion needs a CPU tensor
    restored = inv_normalize(first_image).clamp(0, 1)
    return torchvision.transforms.ToPILImage()(restored)
!echo $TORCH_HOME # the pretrained model is downloaded to this location (if nothing is printed, the default is .cache/torch)
# Load VGG-19 with pretrained weights, showing a download progress bar.
# NOTE(review): newer torchvision releases prefer the `weights=` argument over
# `pretrained=True` - confirm against the installed torchvision version.
pretrained_net = torchvision.models.vgg19(pretrained=True, progress=True)
# Bare expression: when it is the last line of a notebook cell the layer
# listing is displayed.
pretrained_net


In [None]:
# Sweep over which single VGG layer supplies the content features: one full
# style-transfer run is performed (and one image saved) per entry in `loop`.
# NOTE(review): in torchvision's VGG-19 `features`, index 1 is presumably the
# ReLU after the first conv (convs sit at 0, 2, 5, ...) - confirm that 1, and
# not 2, is intended here.
# loop = [0, 1]
loop = [0, 1, 5, 7, 10, 12, 14, 16, 19, 21, 23, 25, 28, 30, 32, 34]

# `time` is imported as a module, so the clock must be read with time.time();
# calling the module itself raises TypeError.
start = time.time()

# Layers whose outputs define the style; identical for every run of the sweep.
style_layers = [0, 5, 10, 19, 28]

# Relative weights of the three loss terms.
content_weight, style_weight, tv_weight = 1, 1e3, 10

# Resolution (H, W) used for every run of the sweep.
image_shape = (300, 450)


# The helpers below read the module-level `net`, `content_layers`,
# `style_layers`, `content_img` and `style_img` at call time, so they are
# defined once here instead of being redefined on every loop iteration.

def extract_features(X, content_layers, style_layers):
    """Run ``X`` through the module-level ``net`` and collect layer outputs.

    Args:
        X: input tensor fed to the first layer of ``net``.
        content_layers: indices of layers whose output is a content feature.
        style_layers: indices of layers whose output is a style feature.

    Returns:
        ``(contents, styles)``: two lists with the outputs of the selected
        layers, in network order.
    """
    contents = []
    styles = []
    for i, layer in enumerate(net):
        X = layer(X)
        if i in style_layers:
            styles.append(X)
        if i in content_layers:
            contents.append(X)
    return contents, styles


def get_contents(image_shape, device):
    """Preprocess the content image and extract its content features.

    Returns:
        ``(content_X, contents_Y)``: the preprocessed image on ``device`` and
        the list of its content-layer outputs.
    """
    content_X = preprocess(content_img, image_shape).to(device)
    contents_Y, _ = extract_features(content_X, content_layers, style_layers)
    return content_X, contents_Y


def get_styles(image_shape, device):
    """Preprocess the style image and extract its style features.

    Returns:
        ``(style_X, styles_Y)``: the preprocessed image on ``device`` and the
        list of its style-layer outputs.
    """
    style_X = preprocess(style_img, image_shape).to(device)
    _, styles_Y = extract_features(style_X, content_layers, style_layers)
    return style_X, styles_Y


# --- 9.11.5 Loss functions ---------------------------------------------------

def content_loss(Y_hat, Y):
    """Content loss: mean squared error between two feature maps."""
    return F.mse_loss(Y_hat, Y)


def gram(X):
    """Return the Gram matrix of a feature map, scaled by its element count.

    ``X`` is flattened to (channels, H * W); this view only succeeds when the
    batch dimension is 1.
    """
    num_channels, n = X.shape[1], X.shape[2] * X.shape[3]
    X = X.view(num_channels, n)
    return torch.matmul(X, X.t()) / (num_channels * n)


def style_loss(Y_hat, gram_Y):
    """Style loss: MSE between the Gram matrix of ``Y_hat`` and ``gram_Y``.

    ``gram_Y`` is an already computed Gram matrix of the style target.
    """
    return F.mse_loss(gram(Y_hat), gram_Y)


def tv_loss(Y_hat):
    """Total-variation loss: mean absolute difference of adjacent pixels.

    Averages the L1 differences along the height and width axes.
    """
    return 0.5 * (F.l1_loss(Y_hat[:, :, 1:, :], Y_hat[:, :, :-1, :]) +
                  F.l1_loss(Y_hat[:, :, :, 1:], Y_hat[:, :, :, :-1]))


def compute_loss(X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram):
    """Compute the weighted content, style and total-variation losses.

    Returns:
        ``(contents_l, styles_l, tv_l, l)``: the per-layer content losses, the
        per-layer style losses, the TV loss and the sum of all of them.
    """
    contents_l = [content_loss(Y_hat, Y) * content_weight for Y_hat, Y in zip(
        contents_Y_hat, contents_Y)]
    styles_l = [style_loss(Y_hat, Y) * style_weight for Y_hat, Y in zip(
        styles_Y_hat, styles_Y_gram)]
    tv_l = tv_loss(X) * tv_weight
    # Total objective: sum of every individual term.
    l = sum(styles_l) + sum(contents_l) + tv_l
    return contents_l, styles_l, tv_l, l


# --- 9.11.6 Creating and initialising the synthesised image -------------------

class GeneratedImage(torch.nn.Module):
    """Module whose only parameter is the image being optimised."""

    def __init__(self, img_shape):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.rand(*img_shape))

    def forward(self):
        return self.weight


def get_inits(X, device, lr, styles_Y):
    """Create the trainable image (initialised from ``X``) and its optimiser.

    Returns:
        ``(image, styles_Y_gram, optimizer)``: the trainable image parameter,
        the Gram matrices of the style targets and an Adam optimiser over the
        image.
    """
    gen_img = GeneratedImage(X.shape).to(device)
    gen_img.weight.data = X.data
    optimizer = torch.optim.Adam(gen_img.parameters(), lr=lr)
    styles_Y_gram = [gram(Y) for Y in styles_Y]
    return gen_img(), styles_Y_gram, optimizer


# --- 9.11.7 Training ----------------------------------------------------------

def train(X, contents_Y, styles_Y, device, lr, max_epochs, lr_decay_epoch):
    """Optimise the synthesised image and return it detached from the graph.

    Args:
        X: initial image tensor.
        contents_Y: target content features.
        styles_Y: target style features (their Gram matrices are computed
            once up front).
        device: device the synthesised image lives on.
        lr: initial Adam learning rate.
        max_epochs: number of optimisation steps.
        lr_decay_epoch: the learning rate is multiplied by 0.1 every this
            many steps.

    Returns:
        The optimised image tensor, detached.
    """
    print("training on ", device)
    X, styles_Y_gram, optimizer = get_inits(X, device, lr, styles_Y)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, lr_decay_epoch, gamma=0.1)
    for i in range(max_epochs):
        start = time.time()  # local timer: duration of this single step

        contents_Y_hat, styles_Y_hat = extract_features(
            X, content_layers, style_layers)
        contents_l, styles_l, tv_l, l = compute_loss(
            X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram)

        optimizer.zero_grad()
        # Presumably retain_graph is needed because the targets were computed
        # with autograd enabled, so their graph is reused on every step.
        l.backward(retain_graph=True)
        optimizer.step()
        scheduler.step()

        if i % 50 == 0 and i != 0:
            print('epoch %3d, content loss %.2f, style loss %.2f, '
                  'TV loss %.2f, %.2f sec'
                  % (i, sum(contents_l).item(), sum(styles_l).item(), tv_l.item(),
                     time.time() - start))
    return X.detach()


# `num` is the position of the layer within `loop` (handy for subplot layouts).
for num, loop_i in enumerate(loop):
    content_layers = [loop_i]

    # Keep only the layers up to the deepest one that is actually used.
    net_list = [pretrained_net.features[i]
                for i in range(max(content_layers + style_layers) + 1)]
    # The inputs are moved to `device`, so the network has to live there too;
    # otherwise a CUDA run fails with a device mismatch.
    net = torch.nn.Sequential(*net_list).to(device)

    _, content_Y = get_contents(image_shape, device)
    _, style_Y = get_styles(image_shape, device)
    # NOTE(review): `output` is not defined in this cell; presumably it is the
    # result of an earlier style-transfer run - confirm it exists beforehand.
    X = preprocess(postprocess(output), image_shape).to(device)
    big_output = train(X, content_Y, style_Y, device, 0.01, 500, 200)

    plt.imshow(postprocess(big_output))
    plt.savefig(f'content image layer {loop_i}.png')

# Wall-clock duration of the whole sweep.
sec = time.time() - start
ic(f'{sec} s = {sec//60} m')

In [None]:
# plt.figure(figsize=(50, 50))
# plt.imshow(postprocess(big_output))

# # plt.show()

In [None]:
# !cp 'loop.png' /content/drive/MyDrive/QNN
