<a href="https://colab.research.google.com/github/VadimFarutin/neural-style-transfer/blob/neuro-template/neural-style-transfer/notebooks/Multi_Style_Transfer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from __future__ import print_function

import numpy as np
from tqdm import tnrange, tqdm_notebook
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from PIL import Image
import matplotlib.pyplot as plt

import torchvision.transforms as transforms
import torchvision.models as models

import copy
import os

In [3]:
# Pick the compute device once: CUDA when PyTorch reports it as available,
# otherwise the CPU. The choice is printed so it is visible in the cell output.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cuda


In [4]:
# Convolutional feature stack (`.features`) of a pretrained VGG-19, moved to
# `device` and switched to eval mode. It is handed to BaseModel/fit further
# down, where the losses are attached to its conv layers.
cnn = models.vgg19(pretrained=True).features.to(device).eval()

Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/checkpoints/vgg19-dcbb9e9d.pth
100%|██████████| 548M/548M [00:18<00:00, 31.2MB/s]


In [5]:
# Download the train2014 image archive from images.cocodataset.org into the
# current working directory (saved as train2014.zip, ~13G per the log below).
!wget "http://images.cocodataset.org/zips/train2014.zip"

--2019-12-24 22:37:49--  http://images.cocodataset.org/zips/train2014.zip
Resolving images.cocodataset.org (images.cocodataset.org)... 52.216.10.35
Connecting to images.cocodataset.org (images.cocodataset.org)|52.216.10.35|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 13510573713 (13G) [application/zip]
Saving to: ‘train2014.zip’


2019-12-24 22:42:59 (41.5 MB/s) - ‘train2014.zip’ saved [13510573713/13510573713]



In [0]:
import zipfile

# Unpack every member of the downloaded archive into the "train2014" directory.
with zipfile.ZipFile("train2014.zip", mode='r') as zip_ref:
    zip_ref.extractall(path="train2014")

In [0]:
class ImageLoader():
    """Loads an image file as a square, batched float tensor on `device`.

    Every image is resized to IMAGE_SIZE x IMAGE_SIZE (the aspect ratio is not
    preserved) and converted with ``transforms.ToTensor``. The colour mode of
    the file is kept as-is, so the channel count of the result depends on the
    file; callers that need exactly three channels must check the shape.
    """

    # Side length, in pixels, of every loaded image. The original expression
    # chose 128 on both branches of a CUDA check, so it is a plain constant.
    IMAGE_SIZE = 128

    def __init__(self):
        self.loader = transforms.Compose([
            transforms.Resize((ImageLoader.IMAGE_SIZE, ImageLoader.IMAGE_SIZE)),
            transforms.ToTensor()]
        )

    def load(self, path):
        """Read the image at `path` and return it as a tensor.

        Args:
            path: path of the image file to open.

        Returns:
            A float tensor with a leading batch dimension of size 1, placed on
            the notebook-level `device`.
        """
        # Scope the file handle: the pixel data is fully converted to a tensor
        # inside the block, so the image can be closed right afterwards.
        with Image.open(path) as image:
            tensor = self.loader(image).unsqueeze(0)
        return tensor.to(device, torch.float)


In [0]:
def imshow(tensor, title):
    """Show an image tensor in a new matplotlib figure without axis ticks.

    Args:
        tensor: image tensor; a leading dimension of size 1 is squeezed away
            before conversion to a PIL image.
        title: title drawn above the figure.
    """
    unloader = transforms.ToPILImage()

    # Work on a detached CPU copy so displaying never touches the caller's
    # tensor or its autograd graph (the tensor may be a raw network output).
    image = tensor.detach().cpu().clone()
    image = image.squeeze(0)
    # A network output is not bounded to [0, 1]; values outside that range are
    # not representable in the 8-bit image and would show up as corrupted
    # pixels. Integer tensors are left untouched.
    if image.is_floating_point():
        image = image.clamp(0, 1)
    image = unloader(image)
    plt.figure()
    plt.imshow(image)
    plt.xticks([])
    plt.yticks([])
    plt.title(title)


In [0]:
class ContentLoss(nn.Module):
    """Pass-through layer that records the MSE between its input and a target.

    The layer returns its input unchanged; the most recent loss value is
    available afterwards as ``self.loss``.
    """

    def __init__(self, target):
        super().__init__()
        # The target is a fixed reference: cut it off from any autograd graph.
        self.target = target.detach()

    def forward(self, input):
        mismatch = F.mse_loss(input, self.target)
        self.loss = mismatch
        return input

In [0]:
class StyleLoss(nn.Module):
    """Pass-through layer that records the MSE between Gram matrices.

    The Gram matrix of the target feature map is computed once at
    construction; on every forward pass the Gram matrix of the input is
    compared to it and the result is stored in ``self.loss``. The input is
    returned unchanged.
    """

    @staticmethod
    def gram_matrix(input):
        """Return the normalised Gram matrix of a 4-D feature tensor.

        The tensor of size (b, c, h, w) is flattened to (b * c, h * w), so the
        result has shape (b * c, b * c); it is divided by the total number of
        elements b * c * h * w.

        Declared as a static method so it can be called both on the class and
        on an instance (without the decorator an instance call would pass
        `self` as `input`).
        """
        b, c, h, w = input.size()
        features = input.view(b * c, h * w)
        G = torch.mm(features, features.t())

        return G.div(b * c * h * w)

    def __init__(self, target_feature):
        super(StyleLoss, self).__init__()
        # Fixed reference: detach so no gradient flows into the target.
        self.target = StyleLoss.gram_matrix(target_feature).detach()

    def forward(self, input):
        G = StyleLoss.gram_matrix(input)
        self.loss = F.mse_loss(G, self.target)
        return input


In [0]:
class Normalization(nn.Module):
    """Channel-wise normalisation layer computing ``(img - mean) / std``.

    Args:
        mean: 1-D tensor of per-channel means; defaults to ``Normalization.MEAN``.
        std: 1-D tensor of per-channel standard deviations; defaults to
            ``Normalization.STD``.
    """

    # Default per-channel statistics, created on the notebook-level `device`.
    # NOTE(review): presumably the ImageNet statistics expected by the
    # pretrained torchvision VGG - confirm.
    MEAN = torch.tensor([0.485, 0.456, 0.406]).to(device)
    STD = torch.tensor([0.229, 0.224, 0.225]).to(device)

    def __init__(self, mean=None, std=None):
        super(Normalization, self).__init__()

        if mean is None:
            mean = Normalization.MEAN
        if std is None:
            std = Normalization.STD

        # Registered as buffers (not plain attributes) so that Module.to(),
        # .cuda() and .cpu() move them together with the rest of the model.
        # Reshaped to (C, 1, 1) so they broadcast over image height and width.
        self.register_buffer('mean', mean.clone().detach().view(-1, 1, 1))
        self.register_buffer('std', std.clone().detach().view(-1, 1, 1))

    def forward(self, img):
        return (img - self.mean) / self.std


In [0]:
class BaseModel(nn.Module):
    """Loss network: a copy of `cnn` with content/style loss layers inserted.

    The layers of `cnn` are copied one by one into a new ``nn.Sequential``
    that starts with a ``Normalization`` layer. Each copied layer is given a
    name derived from the index of the most recent convolution (``conv_<n>``,
    ``relu_<n>``, ``pool_<n>``, ``bn_<n>``, ``unknown_<n>``). After every layer
    whose name appears in `content_layers` / `style_layers`, a ``ContentLoss``
    / ``StyleLoss`` module is appended whose target is the activation of
    `content_img` / `style_img` at that depth. Everything after the last loss
    layer is dropped.

    Args:
        cnn: feature-extractor module whose children are iterated; it is
            deep-copied, so the caller's instance is not modified.
        style_img: image tensor used to compute the style targets.
        content_img: image tensor used to compute the content targets.
        content_layers: collection of layer names that receive a content loss.
        style_layers: collection of layer names that receive a style loss.
    """

    def __init__(self, cnn, style_img, content_img, content_layers, style_layers):
        super(BaseModel, self).__init__()

        cnn = copy.deepcopy(cnn)
        normalization = Normalization().to(device)
        model = nn.Sequential(normalization)

        content_losses = []
        style_losses = []

        # Length of `model` right after the most recently added loss layer;
        # used at the end to cut off the layers that follow it.
        last_loss_layer = 0
        conv_cnt = 0

        for layer in cnn.children():
            if isinstance(layer, nn.Conv2d):
                conv_cnt += 1
                name = 'conv_{}'.format(conv_cnt)
            elif isinstance(layer, nn.ReLU):
                name = 'relu_{}'.format(conv_cnt)
                # Swap in an out-of-place ReLU so the activations seen by the
                # loss layers inserted before it are not overwritten.
                layer = nn.ReLU(inplace=False)
            elif isinstance(layer, nn.MaxPool2d):
                name = 'pool_{}'.format(conv_cnt)
            elif isinstance(layer, nn.BatchNorm2d):
                name = 'bn_{}'.format(conv_cnt)
            else:
                name = 'unknown_{}'.format(conv_cnt)

            model.add_module(name, layer)

            if name in content_layers:
                # Targets are constants: compute them without recording an
                # autograd graph (this constructor may run once per training
                # step, so the saved work is not negligible).
                with torch.no_grad():
                    target = model(content_img)
                content_loss = ContentLoss(target)
                model.add_module("content_loss_{}".format(conv_cnt), content_loss)
                content_losses.append(content_loss)
                last_loss_layer = len(model)

            if name in style_layers:
                with torch.no_grad():
                    target_feature = model(style_img)
                style_loss = StyleLoss(target_feature)
                model.add_module("style_loss_{}".format(conv_cnt), style_loss)
                style_losses.append(style_loss)
                last_loss_layer = len(model)

        # Layers after the last loss module cannot influence any loss.
        model = model[:last_loss_layer]

        self.model = model
        self.content_losses = content_losses
        self.style_losses = style_losses

    def get_content_losses(self):
        """Return the list of ContentLoss modules, in insertion order."""
        return self.content_losses

    def get_style_losses(self):
        """Return the list of StyleLoss modules, in insertion order."""
        return self.style_losses

    def forward(self, input):
        """Run `input` through the truncated network.

        As a side effect every loss module updates its ``loss`` attribute.
        """
        return self.model(input)

In [0]:
class GramMatrix(nn.Module):
    """Batched Gram matrix of a feature map.

    The input is indexed as a 4-D tensor (assumed layout: batch, channels,
    height, width - TODO confirm with callers). For every sample the channel
    by channel inner products over the flattened last two dimensions are
    returned, divided by channels * height * width.
    """

    def forward(self, y):
        batch, channels, height, width = y.size()[:4]
        flat = y.view(batch, channels, height * width)
        gram = flat.bmm(flat.transpose(1, 2))
        return gram / (channels * height * width)

In [0]:
class InspirationLayer(nn.Module):
    """Layer that mixes feature channels using a learned weight and a target matrix.

    The layer holds a learnable weight of shape (1, C, C) and a target matrix
    ``G`` (set through ``setTarget``). On every forward pass it forms
    ``P = weight @ G`` and multiplies the transposed ``P`` with the input
    flattened to (batch, C, -1); the result is reshaped like the input.

    Args:
        C: number of feature channels of the input.
    """

    def __init__(self, C):
        super(InspirationLayer, self).__init__()
        # Allocated uninitialised, then filled by reset_parameters() below.
        self.weight = nn.Parameter(torch.Tensor(1, C, C), requires_grad=True)
        # Placeholder target until setTarget() is called. Zero-initialised
        # rather than built from uninitialised memory, so a forward pass made
        # before setTarget() is deterministic instead of operating on garbage.
        self.G = torch.zeros(1, C, C, requires_grad=True)
        self.C = C
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw the weight uniformly from [0.0, 0.02)."""
        self.weight.data.uniform_(0.0, 0.02)

    def setTarget(self, target):
        """Replace the target matrix ``G`` used by subsequent forward passes."""
        self.G = target

    def forward(self, X):
        # P has the batch size of G; it is expanded below to the batch size
        # of X, which requires G's batch size to be 1 or equal to X's.
        self.P = torch.bmm(self.weight.expand_as(self.G), self.G)
        mixing = self.P.transpose(1, 2).expand(X.size(0), self.C, self.C)
        flat = X.view(X.size(0), X.size(1), -1)
        return torch.bmm(mixing, flat).view_as(X)

In [0]:
class ConvLayer(torch.nn.Module):
    """2-D convolution preceded by reflection padding.

    The padding width is ``floor(kernel_size / 2)`` on every side; the
    convolution itself is created without padding.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super().__init__()
        pad_width = int(kernel_size // 2)
        self.reflection_padding = nn.ReflectionPad2d(pad_width)
        self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride)

    def forward(self, x):
        padded = self.reflection_padding(x)
        return self.conv2d(padded)

In [0]:
class UpConvLayer(torch.nn.Module):
    """Optional upsampling, optional reflection padding, then a 2-D convolution.

    Args:
        in_channels, out_channels, kernel_size, stride: forwarded to ``nn.Conv2d``.
        upsample: scale factor for ``torch.nn.Upsample``; when falsy no
            upsampling layer is created or applied.

    Reflection padding of ``floor(kernel_size / 2)`` is only created and
    applied when that width is non-zero.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, upsample=None):
        super().__init__()
        self.upsample = upsample
        if upsample:
            self.upsample_layer = torch.nn.Upsample(scale_factor=upsample)
        self.reflection_padding_size = int(kernel_size // 2)
        if self.reflection_padding_size != 0:
            self.reflection_padding = nn.ReflectionPad2d(self.reflection_padding_size)
        self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride)

    def forward(self, x):
        out = self.upsample_layer(x) if self.upsample else x
        if self.reflection_padding_size:
            out = self.reflection_padding(out)
        return self.conv2d(out)

In [0]:
class PreResBlock(nn.Module):
    """Residual block: norm -> ReLU -> conv, three times (1x1, 3x3, 1x1).

    The main branch maps `inplanes` channels to ``planes * 4`` channels; the
    3x3 stage is a reflection-padded ``ConvLayer`` that carries `stride`.
    When `downsample` is truthy the shortcut is a strided 1x1 convolution to
    ``planes * 4`` channels; otherwise the input itself is the shortcut, so
    it must already have the output shape.
    """

    def __init__(self, inplanes, planes, stride=1, downsample=None, norm_layer=nn.BatchNorm2d):
        super().__init__()
        self.expansion = 4
        self.downsample = downsample
        out_planes = planes * self.expansion

        if self.downsample:
            self.residual_layer = nn.Conv2d(inplanes, out_planes, kernel_size=1, stride=stride)

        stages = [
            norm_layer(inplanes),
            nn.ReLU(inplace=True),
            nn.Conv2d(inplanes, planes, kernel_size=1, stride=1),
            norm_layer(planes),
            nn.ReLU(inplace=True),
            ConvLayer(planes, planes, kernel_size=3, stride=stride),
            norm_layer(planes),
            nn.ReLU(inplace=True),
            nn.Conv2d(planes, out_planes, kernel_size=1, stride=1),
        ]
        self.conv_block = nn.Sequential(*stages)

    def forward(self, x):
        shortcut = self.residual_layer(x) if self.downsample else x
        return shortcut + self.conv_block(x)

In [0]:
class UpResBlock(nn.Module):
    """Upsampling residual block built from ``UpConvLayer``.

    Both branches upsample by `stride`: the shortcut is a 1x1 ``UpConvLayer``
    to ``planes * 4`` channels, and the main branch is norm -> ReLU -> conv
    three times (1x1, upsampling 3x3, 1x1) ending at ``planes * 4`` channels.
    """

    def __init__(self, inplanes, planes, stride=2, norm_layer=nn.BatchNorm2d):
        super().__init__()
        self.expansion = 4
        out_planes = planes * self.expansion

        self.residual_layer = UpConvLayer(inplanes, out_planes, kernel_size=1, stride=1, upsample=stride)

        stages = [
            norm_layer(inplanes),
            nn.ReLU(inplace=True),
            nn.Conv2d(inplanes, planes, kernel_size=1, stride=1),
            norm_layer(planes),
            nn.ReLU(inplace=True),
            UpConvLayer(planes, planes, kernel_size=3, stride=1, upsample=stride),
            norm_layer(planes),
            nn.ReLU(inplace=True),
            nn.Conv2d(planes, out_planes, kernel_size=1, stride=1),
        ]
        self.conv_layers = nn.Sequential(*stages)

    def forward(self, x):
        shortcut = self.residual_layer(x)
        return shortcut + self.conv_layers(x)

In [0]:
class Net(nn.Module):
    """Image transformation network conditioned on a style target.

    Structure, in order:
      1. ``model1``: 7x7 ``ConvLayer`` to 64 channels, norm, ReLU and two
         stride-2 ``PreResBlock`` s ending at ``ngf * 4`` channels.
      2. ``ins``: an ``InspirationLayer`` over ``ngf * 4`` channels.
      3. `n_blocks` stride-1 ``PreResBlock`` s.
      4. Two ``UpResBlock`` s (upsample factor 2 each), norm, ReLU and a 7x7
         ``ConvLayer`` to `output_nc` channels.

    Args:
        input_nc: number of input image channels.
        output_nc: number of output image channels.
        ngf: base channel width of the middle of the network.
        norm_layer: normalisation layer class used throughout.
        n_blocks: number of residual blocks after the inspiration layer.
        gpu_ids: stored on the instance as ``self.gpu_ids``; not used
            elsewhere in this class. ``None`` means an empty list.
    """

    def __init__(self, input_nc=3, output_nc=3, ngf=64, norm_layer=nn.InstanceNorm2d, n_blocks=6, gpu_ids=None):
        super(Net, self).__init__()
        # A fresh list per instance: a mutable `[]` default would be shared
        # by every Net created without an explicit gpu_ids argument.
        self.gpu_ids = [] if gpu_ids is None else gpu_ids
        self.gram = GramMatrix()

        block = PreResBlock
        upblock = UpResBlock

        self.model1 = nn.Sequential(ConvLayer(input_nc, 64, kernel_size=7, stride=1),
                                    norm_layer(64),
                                    nn.ReLU(inplace=True),
                                    block(64, 32, 2, 1, norm_layer),
                                    block(128, ngf, 2, 1, norm_layer))

        model = []
        self.ins = InspirationLayer(ngf * 4)
        model += [self.model1]
        model += [self.ins]

        for _ in range(n_blocks):
            model += [block(ngf * 4, ngf, 1, None, norm_layer)]

        model += [upblock(ngf * 4, 32, 2, norm_layer),
                  upblock(128, 16, 2, norm_layer),
                  norm_layer(64),
                  nn.ReLU(inplace=True),
                  ConvLayer(64, output_nc, kernel_size=7, stride=1)]

        self.model = nn.Sequential(*model)

    def setTarget(self, Xs):
        """Set the style target from the style image(s) `Xs`.

        `Xs` is encoded with ``model1``; the Gram matrix of the resulting
        features becomes the target of the inspiration layer.
        """
        # Named `features` (not `F`) to avoid shadowing torch.nn.functional.
        features = self.model1(Xs)
        self.ins.setTarget(self.gram(features))

    def forward(self, input):
        return self.model(input)

In [0]:
# Names of the loss-network layers at which the content loss and the style
# losses are attached; they are passed to fit() below.
content_layer_names = ['conv_4']
style_layer_names = ['conv_{}'.format(idx) for idx in range(1, 6)]

In [0]:
class Optimizer():
    """Thin wrapper around an Adam optimizer (learning rate 1e-3).

    Args:
        model: module whose ``parameters()`` are optimised.
    """

    def __init__(self, model):
        trainable = model.parameters()
        self.optimizer = optim.Adam(trainable, lr=1e-3)

    def zero_grad(self):
        """Clear the gradients of all optimised parameters."""
        self.optimizer.zero_grad()

    def step(self, optimizer_step):
        """Perform one optimisation step, passing `optimizer_step` as the closure."""
        self.optimizer.step(closure=optimizer_step)


In [0]:
def fit(dir, model, epoch_cnt, style_weight, content_weight, cnn, content_layer_names, style_layer_names, style_dir="./styles/"):
    """Train `model` on every (content image, style image) pair.

    For each epoch, every file in `dir` is loaded as a content image (files
    that do not load with exactly three channels are skipped). For each style
    image the style target of `model` is set, and one optimiser step is taken
    on the weighted sum of the style and content losses measured by a
    ``BaseModel`` built from `cnn`.

    Args:
        dir: directory containing the content images.
        model: transformation network; must provide ``setTarget``.
        epoch_cnt: number of passes over the content images.
        style_weight: multiplier of the summed style losses.
        content_weight: multiplier of the summed content losses.
        cnn: feature extractor handed to ``BaseModel``.
        content_layer_names: layer names that receive a content loss.
        style_layer_names: layer names that receive a style loss.
        style_dir: directory containing the style images; every file in it is
            loaded once up front.

    Returns:
        List with the loss value (a Python float) of every optimiser step.
    """
    loader = ImageLoader()
    loss_values = []
    optimizer = Optimizer(model)

    style_imgs = [loader.load(os.path.join(style_dir, style_file))
                  for style_file in os.listdir(style_dir)]

    def optimizer_step(input_img, style_img):
        """Closure body: compute the loss for one pair and back-propagate it."""
        input_img.data.clamp_(0, 1)

        optimizer.zero_grad()
        # The loss network is rebuilt for every pair because its targets
        # depend on the current content and style images.
        base = BaseModel(cnn, style_img, input_img, content_layer_names, style_layer_names)
        base = base.to(device)

        # Forward pass of the stylised image; fills in every `.loss`.
        base(model(input_img))

        style_score = sum(sl.loss for sl in base.get_style_losses())
        content_score = sum(cl.loss for cl in base.get_content_losses())

        loss = style_score * style_weight + content_score * content_weight
        loss.backward()
        loss_values.append(loss.item())

        return loss.item()

    for _ in tqdm(range(epoch_cnt)):
        for file in tqdm(os.listdir(dir)):
            content_img = loader.load(os.path.join(dir, file))
            # Skip images that did not load as three-channel tensors.
            if content_img.shape[1] != 3:
                continue

            for style_img in style_imgs:
                style_img_copy = style_img.clone().detach()
                content_img_copy = content_img.clone().detach()
                model.setTarget(style_img_copy)

                # The lambda is invoked inside step(), i.e. within this same
                # iteration, so it sees the current pair.
                optimizer.step(lambda: optimizer_step(content_img_copy, style_img))

    return loss_values


In [0]:
def plot_loss_values(loss_values):
    """Plot the recorded loss values against their iteration index.

    Args:
        loss_values: sequence of loss values, one per optimiser step.
    """
    iterations = np.arange(len(loss_values))
    plt.plot(iterations, loss_values, color='blue')
    plt.title("Loss values")
    plt.xlabel("iteration")
    plt.ylabel("loss")
    plt.show()

In [0]:
# Training hyper-parameters.
epoch_cnt = 2
content_weight = 1  # alpha
style_weight = 1000000  # beta

# Transformation network with three residual blocks, placed on `device`.
model = Net(ngf=64, n_blocks=3).to(device)

In [0]:
# Train the transformation network and plot the per-step loss.
loss_values = fit("./train2014/train2014", model, epoch_cnt, style_weight, content_weight, cnn, content_layer_names, style_layer_names)
plot_loss_values(loss_values)

loader = ImageLoader()
# Images from Gatys
style_img = loader.load("./data/the-starry-night.jpg")
content_img = loader.load("./data/tubingen.jpg")
# Demo forward pass only: no gradients are needed, so do not record an
# autograd graph for the style target or the stylised output.
with torch.no_grad():
    model.setTarget(style_img)
    output = model(content_img)

imshow(style_img, title='Style Image')
imshow(content_img, title='Content Image')
imshow(output, title='Image')

In [0]:
# Calling fit() again on the same `model` continues training from its current
# weights (fit creates a fresh optimiser each time); the loss of this second
# run is plotted separately.
loss_values = fit("./train2014/train2014", model, epoch_cnt, style_weight, content_weight, cnn, content_layer_names, style_layer_names)
plot_loss_values(loss_values)

loader = ImageLoader()
# Images from Gatys
style_img = loader.load("./data/the-starry-night.jpg")
content_img = loader.load("./data/tubingen.jpg")
# Demo forward pass only: no gradients are needed, so do not record an
# autograd graph for the style target or the stylised output.
with torch.no_grad():
    model.setTarget(style_img)
    output = model(content_img)

imshow(style_img, title='Style Image')
imshow(content_img, title='Content Image')
imshow(output, title='Image')