In [None]:
import os
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
import torchvision.transforms.v2 as tt
import torch
import torchvision
import torch.nn as nn
import cv2
from tqdm.notebook import tqdm
from torchvision.utils import save_image
from torchvision.utils import make_grid
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
import PIL
import math
from IPython import display
import pandas as pd
import requests
from PIL import Image
from io import BytesIO
import os.path
from torch import linalg as LA
%matplotlib inline

sns.set(style='darkgrid', font_scale=1.2)

In [None]:
!pip install onnx
!pip install onnxscript

# Environment Variables

In [None]:
# Directory that is walked recursively to collect the content (training) images.
coco_dataset_path = "/kaggle/input/coco-2017-dataset/coco2017/train2017"
# URL of the style reference image; it is downloaded at access time by StyleDataset.
style_image_url = "https://uploads4.wikiart.org/00142/images/vincent-van-gogh/the-starry-night.jpg!Large.jpg"
# Where the optimized style-representation image is written as a JPEG.
style_representation_file_path = "style.jpg"
# Where the trained model is exported in ONNX format.
model_file_path = "model.onnx"

# Data Preparation

In [None]:
class StyleDataset(torch.utils.data.Dataset):
    """Dataset that downloads style images from URLs each time an item is accessed.

    Args:
        urls: Sequence of image URLs.
        transforms: Optional callable applied to each downloaded PIL image.
            When omitted, the notebook-level ``transforms`` pipeline is looked
            up at access time (the original behaviour).
        timeout: Seconds to wait for the HTTP response before giving up.
    """

    def __init__(self, urls, transforms=None, timeout=30):
        super().__init__()

        self.urls = urls
        self.transforms = transforms
        self.timeout = timeout

    def __getitem__(self, index):
        """Download the image at ``urls[index]`` and return it transformed."""
        response = requests.get(self.urls[index], timeout=self.timeout)
        # Fail with a clear HTTP error instead of handing an error page to PIL.
        response.raise_for_status()

        # Force three channels so the 3-channel Normalize step always applies.
        img = Image.open(BytesIO(response.content)).convert("RGB")

        # Inside this method the bare name `transforms` is the notebook-level
        # pipeline; it is only used when no pipeline was passed to __init__.
        transform = self.transforms if self.transforms is not None else transforms
        return transform(img)

    def __len__(self):
        """Number of style image URLs."""
        return len(self.urls)

In [None]:
class ContentDataset(torch.utils.data.Dataset):
    """Dataset over every file found (recursively) under a root directory.

    Args:
        root_path: Directory that is walked recursively to collect file paths.
        transforms: Callable applied to each loaded PIL image.
    """

    def __init__(self, root_path, transforms):
        super().__init__()

        self.transforms = transforms
        self.paths = []

        for root, dirs, files in os.walk(root_path):
            # os.walk yields entries in arbitrary order; sorting makes the
            # index -> file mapping reproducible between runs.
            dirs.sort()
            for file in tqdm(sorted(files)):
                self.paths.append(os.path.join(root, file))

    def __getitem__(self, index):
        """Load the image at ``paths[index]`` as RGB and return it transformed."""
        # The context manager closes the file handle promptly. convert() loads
        # the pixels and returns an independent image, and it forces three
        # channels so grayscale/CMYK files do not break Normalize or batching.
        with Image.open(self.paths[index]) as img:
            img = img.convert("RGB")
        return self.transforms(img)

    def __len__(self):
        """Number of collected file paths."""
        return len(self.paths)

In [None]:
# Per-channel statistics used by Normalize below and reverted by denorm_and_permute.
# NOTE(review): presumably the ImageNet statistics expected by the pretrained VGG16 - confirm.
mean=[0.485, 0.456, 0.406]
std=[0.229, 0.224, 0.225]
# Side length (in pixels) used by the Resize and CenterCrop steps of `transforms2`.
img_size = 256

# Pipeline without resizing: convert to a float32 image tensor scaled to [0, 1],
# then normalize. StyleDataset picks this name up as a notebook-level global.
transforms = tt.Compose([
    tt.ToImage(),
    tt.ToDtype(dtype=torch.float32, scale=True),
    tt.Normalize(mean=mean, std=std)
])

# Pipeline for content images: resize, center-crop to img_size x img_size,
# convert to float32 in [0, 1] and normalize.
transforms2 = tt.Compose([
    tt.ToImage(),
    tt.Resize(img_size, antialias=True),
    tt.CenterCrop(img_size),
    tt.ToDtype(dtype=torch.float32, scale=True),
    tt.Normalize(mean=mean, std=std)
])

In [None]:
# In theory, the style can be averaged over several images.
urls = [style_image_url]
style_dataset = StyleDataset(urls)
content_dataset = ContentDataset(coco_dataset_path, transforms=transforms2)

# Style images are not resized, so they are loaded one at a time.
style_loader = DataLoader(style_dataset, batch_size=1)
content_loader = DataLoader(content_dataset, batch_size=4, shuffle=True)

# Monitoring Utilities

In [None]:
def denorm_and_permute(img_tensors):
    """Revert the channel normalization and move channels last.

    Args:
        img_tensors: A single normalized image tensor laid out as (C, H, W).

    Returns:
        Tensor laid out as (H, W, C) with the notebook-level `mean`/`std`
        normalization undone.
    """
    # Build the statistics on the input's device so that GPU tensors work too
    # (a CPU-only constant would raise a device-mismatch error).
    channel_std = torch.tensor(std, device=img_tensors.device)
    channel_mean = torch.tensor(mean, device=img_tensors.device)
    return img_tensors.permute(1, 2, 0) * channel_std + channel_mean

In [None]:
def show_imgs(images):
    """Plot normalized (C, H, W) images side by side in a single row.

    Args:
        images: Sized sequence of image tensors accepted by denorm_and_permute.
    """
    plt.figure(figsize=(15,8))
    for i, img in enumerate(images):
        plt.subplot(1, len(images), i + 1)
        plt.imshow(denorm_and_permute(img))
        plt.axis('off')

    # Lay the figure out once, after every subplot exists, instead of
    # recomputing the layout on each loop iteration.
    if len(images) > 0:
        plt.tight_layout()
    plt.show()

In [None]:
class Watcher():
    """Grid of image axes that is re-rendered in place through one display handle.

    Args:
        rows: Number of subplot rows.
        columns: Number of subplot columns.
    """

    def __init__(self, rows=2, columns=10):
        self.max_img_numbers = rows * columns

        # squeeze=False always returns a 2-D array of Axes (even for a 1x1
        # grid), so no exception handling is needed to normalize the shape.
        self.fig, axes_grid = plt.subplots(rows, columns, figsize=(20,5), squeeze=False)
        self.axes = axes_grid.ravel()

        # Axes stay hidden until an image is drawn into them.
        for ax in self.axes:
            ax.set_visible(False)

        self.dh = display.display(self.fig, display_id=True)

    def show(self, images_cpu, labels=None, img_numbers=None, message=None):
        """Draw images into the grid and refresh the displayed figure.

        Args:
            images_cpu: Sequence of normalized (C, H, W) image tensors.
            labels: Currently unused.
            img_numbers: Optional cap on how many images are drawn.
            message: Optional figure title.
        """
        img_number = min(img_numbers, len(images_cpu)) if img_numbers is not None else len(images_cpu)
        # Never draw more images than there are axes in the grid.
        img_number = min(img_number, self.max_img_numbers)

        if message is not None:
            self.fig.suptitle(message)

        for i in range(img_number):
            ax = self.axes[i]
            ax.set_visible(True)
            ax.clear()
            ax.imshow(denorm_and_permute(images_cpu[i]))
            ax.axis('off')

        self.dh.update(self.fig)
        # Close this specific figure (not whichever one happens to be current)
        # so it is not rendered a second time at the end of the cell.
        plt.close(self.fig)


# Style and Content Feature Extraction


## Feature Extractor class

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# VGG16 with torchvision's default pretrained weights; FeatureExtractor uses its
# `features` stack.
vgg16 = torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT)
vgg16 = vgg16.to(device)

In [None]:
def get_gram(x):
    """Compute the normalized Gram matrix of every feature map in a batch.

    Args:
        x: Feature tensor of shape (N, C, H, W).

    Returns:
        Tensor of shape (N, C, C): for each sample, the channel-by-channel
        inner products of the flattened feature maps, divided by C * H * W.
    """
    N, c, h, w = x.size()
    # reshape (unlike view) also accepts non-contiguous inputs, e.g. transposed
    # or channels-last tensors; for contiguous inputs it is still copy-free.
    flat = x.reshape(N, c, h * w)
    return torch.bmm(flat, flat.transpose(1, 2)) / (c * h * w)


class FeatureExtractor(nn.Module):
    """Frozen VGG16 feature stack tapped at four layers for content/style features."""

    # Indices into `vgg16.features` whose outputs are collected.
    # NOTE(review): presumably relu1_2, relu2_2, relu3_3 and relu4_3 of
    # torchvision's VGG16 - confirm against the model definition.
    TAP_LAYER_INDICES = (3, 8, 15, 22)

    def __init__(self):
        super().__init__()
        # Not referenced by any method of this class; kept for compatibility.
        self._image_mean = torch.tensor([0.485, 0.456, 0.406], device=device)
        self._image_std = torch.tensor([0.229, 0.224, 0.225], device=device)

        self.layers = vgg16.features

        # The extractor is a fixed loss network: its weights are never trained.
        for param in self.layers.parameters():
            param.requires_grad = False

    def extract_features(self, images, active_style_layers=(1, 1, 1, 1), active_content_layers=(0, 0, 1, 0)):
        """Run `images` through the stack and collect features at the tap layers.

        Args:
            images: Batch of images of shape (N, C, H, W).
            active_style_layers: Four flags/weights; a non-zero entry enables
                the Gram matrix of the corresponding tap layer.
            active_content_layers: Four flags/weights; a non-zero entry enables
                the raw activation of the corresponding tap layer.

        Returns:
            ``(content_features, style_features)``: two lists with one entry
            per tap layer. Disabled entries are ``torch.zeros(1)`` placeholders
            so that both lists always line up with the tap layers.
        """
        content_features = []
        style_features = []
        tap_count = len(self.TAP_LAYER_INDICES)

        for i, layer in enumerate(self.layers):
            slot = len(content_features)
            # Stop as soon as every tap has been collected; deeper layers are unused.
            if slot == tap_count:
                break

            images = layer(images)
            if i not in self.TAP_LAYER_INDICES:
                continue

            if active_content_layers[slot] != 0:
                content_features.append(images)
            else:
                # Placeholders live on the input's device, not on a global one.
                content_features.append(torch.zeros(1, device=images.device))

            if active_style_layers[slot] != 0:
                style_features.append(get_gram(images))
            else:
                style_features.append(torch.zeros(1, device=images.device))

        return content_features, style_features

    def extract_style_features(self, images, active_style_layers=(1, 1, 1, 1)):
        """Return only the style (Gram) features; content taps are disabled."""
        return self.extract_features(
            images,
            active_style_layers=active_style_layers,
            active_content_layers=(0, 0, 0, 0))[1]

    def extract_content_features(self, images, active_content_layers=(0, 0, 1, 0)):
        """Return only the content features; style taps are disabled."""
        return self.extract_features(
            images,
            active_style_layers=(0, 0, 0, 0),
            active_content_layers=active_content_layers)[0]


In [None]:
# Shared extractor instance used for the style reference, the losses and training.
feature_extractor = FeatureExtractor().to(device)

## Reference Style Definition

In [None]:
# Running per-layer sum of the style Gram matrices; `None` until the first
# style image has been processed.
style_features_mean = [None for i in range(4)]
count = 0

# The reference style is a constant target, so no autograd graph is needed.
with torch.no_grad():
    for style_image in tqdm(style_loader):
        count += 1
        style_image = style_image.to(device)
        # (1, 3, H, W) -> [(1, C1, C1), (1, C2, C2), (1, C3, C3), (1, C4, C4)]
        style_feature = feature_extractor.extract_style_features(style_image)
        for i in range(4):
            if style_features_mean[i] is None:
                style_features_mean[i] = style_feature[i].detach()
            else:
                style_features_mean[i] += style_feature[i].detach()

# Fail early instead of silently producing an unusable style reference.
if count == 0:
    raise ValueError("style_loader produced no images; cannot build the style reference")

# Turn the per-layer sums into means over all style images.
for i in range(4):
    style_features_mean[i] /= count

# Loss Functions

In [None]:
def content_loss_fn(generated_content_features, content_features, content_wl):
    """Weighted sum of per-layer MSE between generated and target content features.

    Args:
        generated_content_features: Per-layer features of the generated images.
        content_features: Per-layer target features of the content images.
        content_wl: Per-layer weights; layers whose weight is 0 (or that have
            no weight at all) are skipped.

    Returns:
        Tensor of shape (1,) holding the weighted content loss.
    """
    # Accumulate on the device of the features themselves; the global `device`
    # is only a fallback for an empty feature list.
    if len(generated_content_features) > 0:
        loss_device = generated_content_features[0].device
    else:
        loss_device = device
    content_loss = torch.zeros(1, device=loss_device)

    for i in range(len(content_features)):
        if i >= len(content_wl) or content_wl[i] == 0:
            continue

        layer_loss = nn.functional.mse_loss(content_features[i], generated_content_features[i])
        content_loss = content_loss + layer_loss * content_wl[i]

    return content_loss

def style_loss_fn(generated_style_features, style_features, style_wl):
    """Weighted sum of squared Frobenius distances between Gram matrices.

    Args:
        generated_style_features: Per-layer Gram matrices of the generated
            images, each of shape (N, C, C).
        style_features: Per-layer reference Gram matrices, each of shape
            (1, C, C); they are broadcast over the batch.
        style_wl: Per-layer weights; layers whose weight is 0 (or that have no
            weight at all) are skipped.

    Returns:
        Tensor of shape (1,) holding the weighted style loss summed over the batch.
    """
    # Accumulate on the device of the features themselves; the global `device`
    # is only a fallback for an empty feature list.
    if len(generated_style_features) > 0:
        loss_device = generated_style_features[0].device
    else:
        loss_device = device
    style_loss = torch.zeros(1, device=loss_device)

    for i in range(len(style_features)):
        # Same skip rule as content_loss_fn.
        if i >= len(style_wl) or style_wl[i] == 0:
            continue

        # Broadcasting replaces the explicit tile of the reference over the
        # batch, and summing squares directly equals the squared Frobenius norm
        # without a sqrt round trip.
        diff = generated_style_features[i] - style_features[i]
        style_loss = style_loss + torch.sum(diff ** 2) * style_wl[i]
    return style_loss
        
def loss_fn_extended(generated_image, generated_image_features, content_features, content_wl, style_features, style_wl, alpha=1e-3):
    """Combine the style loss with the alpha-weighted content loss.

    Args:
        generated_image: Output of the transfer model (not used by this loss).
        generated_image_features: Pair ``(content_features, style_features)``
            extracted from the generated images.
        content_features: Target content features.
        content_wl: Per-layer content weights.
        style_features: Reference style features.
        style_wl: Per-layer style weights.
        alpha: Weight of the content term relative to the style term.

    Returns:
        ``(total_loss, style_value, weighted_content_value)``: the
        differentiable total plus both components as Python floats.
    """
    generated_content = generated_image_features[0]
    generated_style = generated_image_features[1]

    content_term = content_loss_fn(generated_content, content_features, content_wl)
    style_term = style_loss_fn(generated_style, style_features, style_wl)

    total = style_term + alpha * content_term
    return total, style_term.item(), alpha * content_term.item()

# Getting the expected Style's Representation

In [None]:
def show_representation(reference_style_feature, loss_fn, iterations=500, wl=[1, 1, 1, 1]):
    """Optimize a random image so that its style features match a reference.

    Args:
        reference_style_feature: Per-layer reference style features.
        loss_fn: Callable ``(style_features, reference, weights) -> loss``.
        iterations: Number of optimizer steps.
        wl: Per-layer style weights.

    Returns:
        The optimized image as a (3, 512, 512) tensor on the CPU.
    """
    torch.cuda.empty_cache()

    layer_weights = torch.tensor(wl, device=device)

    # The pixels themselves are the parameters being optimized.
    canvas = torch.rand((1, 3, 512, 512), device=device, requires_grad=True)
    pixel_optimizer = torch.optim.Adam([canvas], lr=8e-2)

    monitor = Watcher(columns=1, rows=1)

    for step in tqdm(range(iterations)):
        pixel_optimizer.zero_grad()

        loss = loss_fn(
            feature_extractor.extract_style_features(canvas),
            reference_style_feature,
            layer_weights,
        )
        loss.backward()
        pixel_optimizer.step()

        # Refresh the preview only on every 100th step.
        if step % 100 != 0:
            continue
        monitor.show([canvas[0].detach().cpu()], message=f"loss: {loss.item()}")

    return canvas[0].detach().cpu()

In [None]:
# Optimize an image towards the averaged reference style features.
image = show_representation(style_features_mean, style_loss_fn, iterations=5000)
# Undo the normalization; the result is laid out as (H, W, C).
img = denorm_and_permute(image)
# Scale to 0..255, restore the (C, H, W) layout and clamp before casting to uint8.
img = (img * 255).permute(2, 0, 1).clamp(0, 255).to(dtype=torch.uint8)
torchvision.io.write_jpeg(img, style_representation_file_path)

# Style Transfer Model

## Residual Block

In [None]:
class ResidualBlock(nn.Module):
    """Two 3x3 same-padded convolutions wrapped in an identity skip connection.

    Args:
        channels: Number of input (and output) channels.
    """

    def __init__(self, channels):
        super().__init__()

        conv_kwargs = dict(kernel_size=3, padding='same')

        # The first convolution is followed by a ReLU; the second stays linear.
        self.conv0 = nn.Sequential(
            nn.Conv2d(channels, channels, **conv_kwargs),
            nn.ReLU(inplace=True),
        )
        self.conv1 = nn.Conv2d(channels, channels, **conv_kwargs)

    def forward(self, x):
        """Return ``x`` plus the output of the two-convolution branch."""
        residual = self.conv1(self.conv0(x))
        return x + residual

## Convolution and Deconvolution layers

In [None]:
class ConvLayer(nn.Module):
    """Reflection-padded convolution followed by optional instance normalization.

    Args:
        in_dim: Number of input channels.
        out_dim: Number of output channels.
        kernel_size: Convolution kernel size; the padding is ``kernel_size // 2``.
        stride: Convolution stride.
        norm: ``"instance"`` adds an affine InstanceNorm2d; any other value
            disables normalization.
    """

    def __init__(self, in_dim, out_dim, kernel_size, stride, norm = "instance"):
        super().__init__()
        padding_size = kernel_size // 2
        self.reflection_pad = nn.ReflectionPad2d(padding_size)

        self.conv_layer = nn.Conv2d(in_dim, out_dim, kernel_size, stride)

        if norm == "instance":
            self.norm_layer = nn.InstanceNorm2d(out_dim, affine = True)
        else:
            # Must be stored under `norm_layer`: forward() reads that attribute.
            # The former `morn_layer` typo made forward() raise AttributeError.
            self.norm_layer = nn.Identity()

    def forward(self, x):
        """Pad, convolve and normalize ``x``."""
        x = self.reflection_pad(x)
        x = self.conv_layer(x)
        out = self.norm_layer(x)
        return out

In [None]:
class DeconvLayer(nn.Module):
    """Transposed convolution followed by optional instance normalization.

    Args:
        in_dim: Number of input channels.
        out_dim: Number of output channels.
        kernel_size: Kernel size; the padding is ``kernel_size // 2``.
        stride: Stride of the transposed convolution.
        output_padding: Extra size added to one side of the output.
        norm: ``"instance"`` adds an affine InstanceNorm2d; any other value
            disables normalization.
    """

    def __init__(self, in_dim, out_dim, kernel_size, stride, output_padding, norm="instance"):
        super().__init__()

        self.conv_transpose = nn.ConvTranspose2d(
            in_dim,
            out_dim,
            kernel_size,
            stride=stride,
            padding=kernel_size // 2,
            output_padding=output_padding,
        )

        use_instance_norm = norm == "instance"
        self.norm_layer = nn.InstanceNorm2d(out_dim, affine=True) if use_instance_norm else nn.Identity()

    def forward(self, x):
        """Upsample ``x`` with the transposed convolution, then normalize."""
        return self.norm_layer(self.conv_transpose(x))

## Model

In [None]:
class StyleTransferModel(nn.Module):
    """Image-to-image network: three conv stages, five residual blocks, three deconv stages."""

    def __init__(self):
        super().__init__()

        # Encoder: one stride-1 stage followed by two stride-2 stages.
        self.conv0 = nn.Sequential(ConvLayer(3, 32, 9, 1), nn.ReLU(inplace=True))
        self.conv1 = nn.Sequential(ConvLayer(32, 64, 3, 2), nn.ReLU(inplace=True))
        self.conv2 = nn.Sequential(ConvLayer(64, 128, 3, 2), nn.ReLU(inplace=True))

        # Bottleneck operating at 128 channels.
        self.residuals = nn.ModuleList([ResidualBlock(128) for _ in range(5)])

        # Decoder: two stride-2 transposed stages, then an un-normalized 9x9
        # stage whose output is squashed by tanh.
        self.deconv0 = nn.Sequential(DeconvLayer(128, 64, 3, 2, 1), nn.ReLU(inplace=True))
        self.deconv1 = nn.Sequential(DeconvLayer(64, 32, 3, 2, 1), nn.ReLU(inplace=True))
        self.deconv2 = nn.Sequential(DeconvLayer(32, 3, 9, 1, 0, norm=None), nn.Tanh())

    def forward(self, x):
        """Apply every stage in order and return the generated image batch."""
        stages = (
            self.conv0,
            self.conv1,
            self.conv2,
            *self.residuals,
            self.deconv0,
            self.deconv1,
            self.deconv2,
        )
        for stage in stages:
            x = stage(x)
        return x

# Training

In [None]:
def train(
    model,
    content_loader,
    reference_style_feature,
    loss_fn,
    lr=1e-3,
    alpha=1e-3,
    content_wl_list=(0, 0, 1, 0),
    style_wl_list=(1/4, 1/4, 1/4, 1/4),
    file_name="model.pt",
    watch=False,
    save_model=False
):
    """Train the style transfer model for one pass over `content_loader`.

    Args:
        model: Transfer network being optimized.
        content_loader: Iterable of content image batches.
        reference_style_feature: Per-layer reference style features.
        loss_fn: Callable with the signature of `loss_fn_extended`; returns
            ``(total_loss, style_value, content_value)``.
        lr: Adam learning rate.
        alpha: Weight of the content term, forwarded to `loss_fn`.
        content_wl_list: Per-layer content weights (0 disables a layer).
        style_wl_list: Per-layer style weights (0 disables a layer).
        file_name: Checkpoint path used when `save_model` is set.
        watch: When true, show a preview every 100 iterations.
        save_model: When true, save the model every 500 iterations and once
            more after the last batch.
    """
    torch.cuda.empty_cache()

    content_wl = torch.tensor(content_wl_list, device=device)
    style_wl = torch.tensor(style_wl_list, device=device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # Only create (and display) the monitoring figure when previews are
    # requested; otherwise an empty figure would be rendered for nothing.
    watcher = Watcher() if watch else None
    # Drawn once on first use so that successive previews show the same images.
    preview_batch = None

    for iteration, batch in enumerate(tqdm(content_loader)):
        torch.cuda.empty_cache()

        # The preview branch below switches to eval mode, so re-enable training.
        model.train()

        optimizer.zero_grad()

        batch = batch.to(device)

        # Targets are constants; tensors created under no_grad carry no graph.
        with torch.no_grad():
            content_features = feature_extractor.extract_content_features(batch, active_content_layers=content_wl_list)

        # Call the module itself (not .forward) so registered hooks still run.
        generated = model(batch)

        generated_features = feature_extractor.extract_features(generated, active_style_layers=style_wl_list, active_content_layers=content_wl_list)

        loss, style_loss, content_loss = loss_fn(
            generated,
            generated_features,
            content_features,
            content_wl,
            reference_style_feature,
            style_wl,
            alpha=alpha
        )

        loss.backward()
        optimizer.step()

        if watch and iteration % 100 == 0:
            if preview_batch is None:
                preview_batch = next(iter(content_loader))[0:3]

            model.eval()
            torch.cuda.empty_cache()
            with torch.no_grad():
                fixed_batch = preview_batch.to(device)
                fixed_result = model(fixed_batch)

            watcher.show([*fixed_batch.cpu(), *fixed_result.detach().cpu()], message=f"content loss: {content_loss / len(batch)}, style loss: {style_loss / len(batch)}")

            del fixed_result, fixed_batch

        if save_model and iteration % 500 == 0:
            torch.save(model, file_name)

    # Persist the final weights as well; the periodic checkpoint alone can be
    # up to 499 iterations stale.
    if save_model:
        torch.save(model, file_name)
        
        

# Training process

In [None]:
# Fresh, untrained transfer network on the selected device.
model = StyleTransferModel().to(device)

In [None]:
# A `;` on a line of its own is a syntax error in Python, and train() returns
# None, so there is no output to suppress anyway.
train(
    model,
    content_loader,
    style_features_mean,
    loss_fn_extended,
    lr=1e-3,
    # Content-loss weight; values around 1e-2 work well.
    alpha=3e-2,
    content_wl_list=[0, 0, 1, 0],
    style_wl_list=[1/4, 1/4, 1/4, 1/4],
    file_name="model.pt",
    watch=False
)

# Model saving

In [None]:
# Export from the CPU with the model in eval mode.
model.cpu()
model.eval()

# Example input handed to the exporter; batch size, height and width are
# declared dynamic below, so the exported model is not tied to 1 x 256 x 256.
model_input = torch.randn(1, 3, 256, 256, dtype=torch.float32)
torch.onnx.export(
    model,
    model_input,
    model_file_path,
    input_names=["image"],
    output_names=["result"],
    dynamic_axes={
        "image": {0: "batch_size", 2: "image_height", 3: "image_width"},
        "result": {0: "batch_size", 2: "image_height", 3: "image_width"},
    }
)