In [None]:
import torchvision
import torch
import torch.nn as nn
from torchsummary import summary
from torchvision import transforms
from torch.utils.data import DataLoader,Dataset
from torchvision.datasets import ImageFolder
from PIL import Image, ImageFile
from tqdm.autonotebook import tqdm
import glob
import os
import torch.backends.cudnn as cudnn
from matplotlib import pyplot as plt
import numpy as np


In [None]:
# Mount Google Drive under /content/gdrive; the Kaggle API token is copied from there further down.
from google.colab import drive

drive.mount('/content/gdrive')

Mounted at /content/gdrive


Изменяю параметры загрузки изображений для избежания проблем при формировании загрузчика данных

In [None]:
# Turn on the cuDNN benchmark flag.
cudnn.benchmark = True
# Remove Pillow's pixel-count limit so very large images are not rejected.
Image.MAX_IMAGE_PIXELS = None 
# Let Pillow load image files whose data is truncated instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True

Инициализация основных переменных

In [None]:
# --- Training configuration -------------------------------------------------
EPOHS = 500            # epoch budget (not referenced by the cells visible here)
BATCH_SIZE = 4         # batch size of both data loaders
lr = 1e-4              # learning-rate constant (the optimizer cell passes its own literal)
IMAGE_SIZE = 512       # value given to transforms.Resize
IMAGE_SIZE_CROP = 256  # value given to transforms.RandomCrop
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Подготовка среды для загрузки датасета с Kaggle

In [None]:
! mkdir ~/.kaggle

In [None]:
# Copy the Kaggle API token from the mounted Drive into ~/.kaggle.
!cp /content/gdrive/MyDrive/Project/Style\ transfer/kaggle.json ~/.kaggle

In [None]:
# Restrict the token file to owner read/write.
!chmod 600 ~/.kaggle/kaggle.json

Датасет стилей взят из соревнования: https://www.kaggle.com/c/painter-by-numbers

In [None]:
# Download only train.zip from the painter-by-numbers competition.
!kaggle competitions download -c painter-by-numbers -f train.zip

In [None]:
!unzip /content/train.zip -d /content/style_dataset

In [None]:
# Delete the archive once it has been extracted.
!rm /content/train.zip

Проверяем размер датасета

In [None]:
# Number of entries in the extracted train/ folder.
sum(1 for _ in os.scandir('/content/style_dataset/train'))

79433

Второй датасет для обучения — COCO-2017

In [None]:
# Install FiftyOne, which is used below to download COCO-2017.
!pip install fiftyone

In [None]:
!pip uninstall  opencv-python-headless==4.5.5.62

In [None]:
# Install the pinned opencv-python-headless 4.1.2.30 in place of the 4.5.5.62 build uninstalled above.
!pip install opencv-python-headless==4.1.2.30

In [None]:
import fiftyone as fo
import fiftyone.zoo as foz

# Fetch the "train" split of COCO-2017 through the FiftyOne dataset zoo.
dataset = foz.load_zoo_dataset("coco-2017", split="train")

Функция для нахождения среднего значения и среднеквадратического отклонения по каждому каналу

In [None]:
def find_mean_std(input, eps = 1e-5):
  """Compute the per-sample, per-channel mean and standard deviation of a feature map.

  Args:
    input: 4-D tensor laid out as (batch, channels, height, width). It does
      not have to be contiguous.
    eps: small constant added to the variance before the square root, so the
      returned std is never exactly zero (AdaIN divides by it).

  Returns:
    Tuple ``(mean, std)``; both tensors have shape (batch, channels, 1, 1) so
    they broadcast against ``input``.

  Raises:
    ValueError: if ``input`` is not 4-dimensional.
  """
  if input.dim() != 4:
    raise ValueError(f'expected a 4-D tensor (batch, channels, height, width), got {input.dim()}-D')

  batch_size, channels = input.shape[:2]
  # reshape instead of view: view raises on inputs whose spatial dims cannot be
  # flattened without a copy (e.g. a transposed feature map).
  flat = input.reshape(batch_size, channels, -1)

  input_std = torch.sqrt(flat.var(dim=2) + eps).view(batch_size, channels, 1, 1)
  input_mean = flat.mean(dim=2).view(batch_size, channels, 1, 1)

  return input_mean, input_std

Функция для расчёта адаптивной поканальной нормализации (AdaIN)

In [None]:
def AdaIN(content, style):
  """Adaptive instance normalization.

  Standardises ``content`` with its own per-channel statistics and then
  rescales and shifts it with the per-channel statistics of ``style``.
  Both arguments are feature maps accepted by ``find_mean_std``.
  """
  c_mean, c_std = find_mean_std(content)
  s_mean, s_std = find_mean_std(style)

  standardised = (content - c_mean) / c_std
  return s_std * standardised + s_mean

In [None]:
class Decoder(nn.Module):
  """Convolutional decoder mapping a 512-channel feature map to a 3-channel image.

  The stack is made of 3x3 reflect-padded convolutions with ReLU activations and
  three nearest-neighbour x2 upsampling steps, so the output is 8 times larger
  than the input in each spatial dimension. The final convolution has no
  activation. All layers live in ``self.model`` (a flat ``nn.Sequential``).
  """

  def __init__(self):
    super().__init__()

    def conv3x3(in_channels, out_channels):
      # Size-preserving 3x3 convolution with reflection padding.
      return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1,
                       padding=1, padding_mode='reflect')

    # Channel pairs describe a conv + ReLU step, 'up' marks an upsampling step.
    plan = [
        (512, 256), 'up',
        (256, 256), (256, 256), (256, 256), (256, 128), 'up',
        (128, 128), (128, 64), 'up',
        (64, 64),
    ]

    layers = []
    for step in plan:
      if step == 'up':
        layers.append(nn.Upsample(scale_factor=2, mode='nearest'))
      else:
        layers.append(conv3x3(*step))
        layers.append(nn.ReLU(inplace=True))
    # Output projection to 3 channels, intentionally without an activation.
    layers.append(conv3x3(64, 3))

    self.model = nn.Sequential(*layers)

  def forward(self,x):
    """Run ``x`` through the whole layer stack."""
    return self.model(x)

In [None]:
# Decoder instance that is wrapped by Net and whose parameters are given to the optimizer.
decoder = Decoder()

In [None]:
class Net(nn.Module):
  """AdaIN style-transfer network: frozen VGG-19 encoder plus a trainable decoder.

  ``forward`` encodes a content batch and a style batch, aligns the deepest
  content features to the style statistics with ``AdaIN``, decodes the result
  into an image, re-encodes that image and returns the content and style
  losses used to train the decoder.

  Args:
    decoder: module that maps the deepest encoder feature map back to an image.
  """

  # Encoder indices whose activations are collected by encode_per_layer.
  # The encoder is features[:21], i.e. indices 0..20, so the deepest tap must be
  # 20; these are the ReLUs after the first convolution of VGG-19 blocks 1-4
  # in torchvision's layer layout.
  FEATURE_LAYERS = (1, 6, 11, 20)

  def __init__(self,decoder):
    super().__init__()
    self.encoder = torchvision.models.vgg19(pretrained=True).features[:21]
    self.decoder = decoder
    self.mse_loss = nn.MSELoss()

    # Switch the encoder convolutions to reflection padding.
    for module in self.encoder.modules():
      if isinstance(module, nn.Conv2d):
        module.padding_mode = 'reflect'
    # The encoder is not trained.
    for parameter in self.encoder.parameters():
      parameter.requires_grad_(False)

  def decode(self,x):
    """Map a feature map to an image with the decoder."""
    return self.decoder(x)

  def encode(self,x):
    """Return the output of the last encoder layer for ``x``."""
    return self.encoder(x)

  def encode_per_layer(self,x):
    """Return the activations at ``FEATURE_LAYERS``, shallowest first."""
    features = []

    for layer_num,layer in enumerate(self.encoder):
      x = layer(x)

      if layer_num in self.FEATURE_LAYERS:
        features.append(x)

    return features

  def content_loss(self,x, content):
    """Mean squared error between two feature maps."""
    return self.mse_loss(x, content)

  def style_loss(self, x, style):
    """MSE between the channel-wise means plus MSE between the channel-wise stds."""
    mean_st, std_st = find_mean_std(style)
    mean_inp, std_inp = find_mean_std(x)

    return self.mse_loss(mean_inp, mean_st) + self.mse_loss(std_inp, std_st)

  def forward(self, content, style, alpha = 1.0):
    """Compute the training losses for one content/style batch pair.

    Args:
      content: batch of content images.
      style: batch of style images with the same batch size as ``content``.
      alpha: stylisation strength; 1.0 uses the AdaIN features as they are,
        0.0 uses the unmodified content features.

    Returns:
      Tuple ``(loss_cont, loss_style)`` of scalar tensors.
    """
    style_f = self.encode_per_layer(style)
    content_f = self.encode_per_layer(content)

    normal = AdaIN(content_f[-1],style_f[-1])
    # The decoder input is also the content target, so both stay consistent
    # when alpha != 1 (for the default alpha = 1.0 this equals `normal`).
    target = (1 - alpha) * content_f[-1] + alpha * normal

    generated = self.decoder(target)
    # Re-encode the whole generated batch; indexing it with [-1] would drop
    # the batch dimension and keep only the last image.
    generated_f = self.encode_per_layer(generated)

    loss_cont = self.content_loss(generated_f[-1], target)

    loss_style = self.style_loss(generated_f[0], style_f[0])
    for gen_feat, style_feat in zip(generated_f[1:], style_f[1:]):
      loss_style += self.style_loss(gen_feat, style_feat)

    return loss_cont, loss_style

In [None]:
# Build the network on DEVICE; only the decoder's parameters are handed to Adam.
net = Net(decoder).to(DEVICE)
# NOTE(review): the learning rate is the literal 2e-4 here, while the config cell defines
# lr = 1e-4, which none of the visible cells use — confirm which value is intended.
optimizer = torch.optim.Adam(net.decoder.parameters(), lr =2e-4)

In [None]:
# Training-time preprocessing applied to both the content and the style images:
# resize, take a random square crop, convert to a tensor and normalise each
# channel with mean 0.5 / std 0.5.
transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    # The crop size must be a single (h, w) tuple: a second positional argument
    # is interpreted by RandomCrop as padding.
    transforms.RandomCrop((IMAGE_SIZE_CROP, IMAGE_SIZE_CROP)),
    transforms.ToTensor(),
    transforms.Normalize(mean = (0.5,0.5,0.5), std = (0.5,0.5,0.5)),
])

In [None]:
# Root folders passed to ImageFolder for the content and style datasets.
# NOTE(review): DIR_content is presumably where FiftyOne stored the COCO-2017 train split — verify.
DIR_content = '/root/fiftyone/coco-2017/train'
# Extraction target of the painter-by-numbers train.zip (images end up in its train/ subfolder).
DIR_style = '/content/style_dataset'

In [None]:
# File check: try to open and fully decode every file of the style dataset and
# report the ones Pillow cannot load. The files are listed straight from disk so
# the cell does not depend on a dataset object created later in the notebook.
style_files = [
    os.path.join(root, name)
    for root, _, names in os.walk(DIR_style)
    for name in names
]
for file in tqdm(style_files):
    try:
        with Image.open(file) as im:
            im.convert('RGB')
    except Exception as error:
        # Any decoding failure makes the file unusable for training, so the
        # scan reports it and continues instead of aborting.
        print("Cannot load : {} ({})".format(file, error))

In [None]:
# Some files in the dataset turned out to be corrupted, so they are excluded from it right away.
corrupted = ['3917','41945','79499','91033','92899','95347','101947']
for file in corrupted:
  path = f'/content/style_dataset/train/{file}.jpg'
  # os.remove instead of shelling out to `rm`; files that are already gone are
  # skipped so the cell can be re-run.
  if os.path.exists(path):
    os.remove(path)

In [None]:
# Content images. The training loop zips this loader with the style loader and
# zip() stops at the shorter one, so without shuffling the same leading part of
# the dataset would be used every epoch and the rest never visited.
# drop_last keeps every batch at exactly BATCH_SIZE samples.
content_dataset = ImageFolder(DIR_content, transform=transform)
content_loader = DataLoader(content_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True,
                            pin_memory = True, num_workers=0)

In [None]:
# Style images. shuffle=True changes the content/style pairing from epoch to epoch.
# drop_last=True discards a final partial batch: AdaIN combines style statistics
# with content features sample by sample, so both batches need the same size.
style_dataset = ImageFolder(DIR_style, transform = transform)
style_loader = DataLoader(style_dataset, batch_size = BATCH_SIZE, shuffle=True, drop_last=True,
                          pin_memory = True, num_workers=0)

In [None]:
def test_transform(size, crop):
    """Build the preprocessing pipeline for evaluation images.

    Args:
        size: target size given to ``Resize`` (and ``CenterCrop``); 0 disables
            resizing.
        crop: if true, center-crop to ``size`` after resizing.

    Returns:
        A ``transforms.Compose`` that ends with ``ToTensor``.
    """
    transform_list = []
    if size != 0:
        transform_list.append(transforms.Resize(size))
        # Cropping only makes sense with a real target size: CenterCrop(0)
        # would produce an empty image.
        if crop:
            transform_list.append(transforms.CenterCrop(size))
    transform_list.append(transforms.ToTensor())
    return transforms.Compose(transform_list)
    
def denorm(x):
  """Undo ``Normalize(mean=0.5, std=0.5)``, mapping values from [-1, 1] back to [0, 1].

  Works on tensors as well as plain numbers.
  """
  mean, std = (0.5,0.5,0.5),(0.5,0.5,0.5)
  # Multiply by the std and add the mean (the previous code read both values
  # from the mean tuple). All channels share the same statistics, so the first
  # entry of each tuple is enough.
  return x*std[0] + mean[0]

# Preprocessing for the two monitoring images: the evaluation resize/crop followed
# by the same mean 0.5 / std 0.5 normalisation the training batches receive, so the
# encoder sees inputs in the range it was trained with (denorm undoes it for display).
transform = transforms.Compose([
    test_transform(2048, True),
    transforms.Normalize(mean = (0.5,0.5,0.5), std = (0.5,0.5,0.5)),
])

# Images used to monitor the quality of the style transfer.
# convert('RGB') guarantees three channels (a PNG may carry an alpha channel or a palette).
content_image = Image.open('./content.png').convert('RGB')
style_image = Image.open('./style.jpg').convert('RGB')
# Use the loaded images (the previous names `cont` / `style` were undefined here)
# and DEVICE instead of a hard-coded .cuda(), so the CPU fallback keeps working.
content = transform(content_image).unsqueeze(0).to(DEVICE)
style = transform(style_image).unsqueeze(0).to(DEVICE)

In [None]:
# Train the decoder. Every 500 iterations: log the losses, save a checkpoint and
# render the monitoring content/style pair.
iteration = 0
for epoch in tqdm(range(4)):
    # zip() stops at the shorter loader; passing its length gives tqdm a total.
    paired_batches = zip(style_loader, content_loader)
    batches_per_epoch = min(len(style_loader), len(content_loader))
    for (batch_st, _) , (batch_cont, _) in tqdm(paired_batches, total=batches_per_epoch):
      # AdaIN pairs style and content samples one to one, so skip a pair whose
      # batch sizes differ (a final partial batch of one loader).
      if batch_st.size(0) != batch_cont.size(0):
        continue

      net.decoder.train()
      iteration +=1

      batch_st = batch_st.to(DEVICE)
      batch_cont = batch_cont.to(DEVICE)
      loss_cont, loss_st = net(batch_cont, batch_st)
      # The style term is weighted 10x relative to the content term.
      loss = loss_cont + 10 * loss_st

      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

      if iteration % 500 == 0 :
        print('Iteration: {:d} Loss style: {:.4f}, Loss content: {:.4f}, Full loss: {:.4f}'.format(
            iteration, loss_st.item(), loss_cont.item(), loss.item()))

        # Stylise the monitoring pair; the whole forward pass runs without autograd.
        net.decoder.eval()
        with torch.no_grad():
          content_f = net.encoder(content)
          style_f = net.encoder(style)
          normalized = AdaIN(content_f,  style_f)
          out = net.decoder(normalized)

        torch.save({
          'decoder':net.decoder.state_dict(),
          'optim':optimizer.state_dict()}, './check.tar')

        # Clamp to [0, 1] before the uint8 conversion: out-of-range values would wrap around.
        preview = denorm(out).clamp(0, 1).squeeze(0).cpu().numpy().transpose(1,2,0)

        plt.subplot(1,3,1)
        plt.title('content')
        plt.imshow(content_image)
        plt.subplot(1,3,2)
        plt.title('style')
        plt.imshow(style_image)
        plt.subplot(1,3,3)
        plt.title('stylised')
        plt.imshow((preview * 255).astype(np.uint8))
        plt.show()