# Imports

In [None]:
!pip install torchinfo

In [None]:
import collections
import glob
import io
import os
import pickle
import random
import shutil
import time
import warnings
import zipfile
warnings.simplefilter(action="ignore", category=FutureWarning)

import IPython
import IPython.display

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
%matplotlib inline

import skimage
import skimage.color
import skimage.io
import skimage.transform

import tqdm.auto as tqdm

import torch
import torch.nn.functional as F

import torchvision

import torchinfo

# Setting random seeds

In [None]:
# Seed every random number generator the notebook touches (Python, NumPy,
# PyTorch CPU and CUDA) with one shared value.
RANDOM_STATE = 42
random.seed(RANDOM_STATE)
# NOTE(review): PYTHONHASHSEED is normally read at interpreter start-up, so
# setting it here presumably only affects processes started later — confirm.
os.environ['PYTHONHASHSEED'] = str(RANDOM_STATE)
np.random.seed(RANDOM_STATE)
torch.manual_seed(RANDOM_STATE)
torch.cuda.manual_seed(RANDOM_STATE)
torch.cuda.manual_seed_all(RANDOM_STATE)

In [None]:
# Detect where the notebook is running; the download cell branches on RUNNER.
# os.environ.get() returns None (not "") for an unset variable, so the value is
# tested for truthiness: comparing it with "" would be true on every machine
# and RUNNER would always end up as "kaggle".
if os.environ.get("KAGGLE_KERNEL_RUN_TYPE"):
    RUNNER = "kaggle"
elif "google.colab" in str(IPython.get_ipython()):
    RUNNER = "colab"
else:
    # assume running on a local machine
    RUNNER = "local"

# Determining device

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(DEVICE)

# Setting globals

In [None]:
# Number of (content, style) pairs per batch produced by the DataLoaders.
BATCH_SIZE = 8
# Side length of the square random crop taken from every image.
IMAGE_SIZE = 256
# Per-channel statistics passed to Normalize and undone by Denormalize.
# NOTE(review): these look like the usual ImageNet statistics — confirm.
NORMALIZATION_PARAMS = {"mean": (0.485, 0.456, 0.406),
                        "std": (0.229, 0.224, 0.225)}
# Default size for every matplotlib figure in the notebook.
plt.rcParams["figure.figsize"] = (15, 15)
# When True, the download cell selects the smaller dataset.
USE_SMALLER_DATASET = False

# Helper functions and classes

In [None]:
class Denormalize():
    """Undo a per-channel ``Normalize(mean, std)`` transform.

    Computes ``tensor * std + mean`` along the channel axis. The input may be
    a single image of shape (C, H, W) or a batch of shape (B, C, H, W); the
    statistics are broadcast over the third axis from the end in both cases.

    Args:
        mean: sequence with one mean per channel.
        std: sequence with one standard deviation per channel.
    """
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """Return a denormalised copy of ``tensor``; the input is not modified."""
        # Shape the statistics as (C, 1, 1) so that they line up with the
        # channel axis whether or not a leading batch axis is present.
        mean = torch.as_tensor(self.mean, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)
        std = torch.as_tensor(self.std, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)
        return tensor * std + mean

class StyleTransferDataset(torch.utils.data.Dataset):
    """Dataset of (content image, style image) pairs read from two folders.

    The files of both folders are listed in sorted order and paired
    position by position; surplus files in the longer folder are ignored
    because ``zip`` stops at the shorter list.

    Args:
        content_dir: folder holding the content images.
        style_dir: folder holding the style images.
        transforms: optional callable applied separately to the content
            tensor and to the style tensor.
    """
    def __init__(self, content_dir, style_dir, transforms=None):
        # glob returns files in arbitrary order; sorting keeps the pairing
        # (and therefore the seeded train/validation split) reproducible.
        content_images = sorted(glob.glob(content_dir + "/*"))
        style_images = sorted(glob.glob(style_dir + "/*"))
        self.images = list(zip(content_images, style_images))
        self.transforms = transforms

    def __len__(self):
        """Return the number of (content, style) pairs."""
        return len(self.images)

    @staticmethod
    def _load_image(path):
        """Read one image file and return it as a three-channel tensor."""
        image = skimage.io.imread(path)
        if image.ndim < 3:
            # Greyscale file: replicate the single channel into RGB.
            image = skimage.color.gray2rgb(image)
        return torchvision.transforms.ToTensor()(image)

    def __getitem__(self, idx):
        """Return the ``idx``-th pair as ``(content_tensor, style_tensor)``."""
        content_path, style_path = self.images[idx]
        # Both images go through the same loader, so a greyscale style image
        # is expanded itself instead of the content image being converted.
        content_img = self._load_image(content_path)
        style_img = self._load_image(style_path)
        if self.transforms:
            content_img, style_img = self.transforms(content_img), self.transforms(style_img)
        return content_img, style_img

In [None]:
class AdaptiveInstanceNorm2d(torch.nn.Module):
    """Re-normalise content features with the statistics of style features.

    For every sample and channel the content map is shifted and scaled so
    that its mean and standard deviation match those of the style map.

    Args:
        eps: constant added to every standard deviation.
    """
    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps

    def _get_mean(self, features):
        """Return the per-sample, per-channel mean shaped (N, C, 1, 1)."""
        n, channels = features.shape[:2]
        flat = features.reshape(n, channels, -1)
        return flat.mean(dim=2).unsqueeze(-1).unsqueeze(-1)

    def _get_std(self, features):
        """Return the per-sample, per-channel std plus ``eps``, shaped (N, C, 1, 1)."""
        n, channels = features.shape[:2]
        flat = features.reshape(n, channels, -1)
        spread = flat.std(dim=2) + self.eps
        return spread.unsqueeze(-1).unsqueeze(-1)

    def forward(self, content, style):
        """Return ``content`` carrying the channel statistics of ``style``."""
        content_mean = self._get_mean(content)
        content_std = self._get_std(content)
        style_mean = self._get_mean(style)
        style_std = self._get_std(style)
        centered = content - content_mean
        rescaled = style_std * centered / content_std
        return rescaled + style_mean

In [None]:
def fit_epoch(data_train, model, optimizer, criterion, epoch, epochs, max_outputs=None):
    """Train ``model`` for one pass over ``data_train``.

    Args:
        data_train: iterable yielding ``(content, style)`` batches.
        model: module whose call returns ``(output, t)`` and which exposes
            ``encoder(images, output_last=...)``.
        optimizer: optimizer stepped once per batch.
        criterion: callable taking ``(output_features, t, output_middle,
            style_middle)`` and returning a scalar loss tensor.
        epoch: zero-based epoch index, used only in the progress bar label.
        epochs: total number of epochs, used only in the progress bar label.
        max_outputs: upper bound on the number of styled images kept and
            returned. ``None`` keeps every image of the epoch, which needs
            host memory proportional to the size of the training set.

    Returns:
        ``(styled, train_loss)``: the kept styled images concatenated along
        the batch axis (on the CPU) and the sample-weighted mean loss.

    Raises:
        ValueError: if ``data_train`` yields no batches.
    """
    model.train()
    running_loss = 0.0
    processed_data = 0
    styled = []
    kept = 0
    try:
        for content, style in tqdm.tqdm(data_train, desc=f"Fitting epoch {epoch + 1}/{epochs}", unit="batch", unit_scale=False):
            content, style = content.to(DEVICE), style.to(DEVICE)
            optimizer.zero_grad()
            output, t = model(content, style)
            # One encoder pass returns every intermediate map; its last item
            # is what a second pass with output_last=True would return.
            output_middle = model.encoder(output, output_last=False)
            output_features = output_middle[-1]
            style_middle = model.encoder(style, output_last=False)
            loss = criterion(output_features, t, output_middle, style_middle)
            loss.backward()
            optimizer.step()
            batch_len = content.size(0)
            running_loss += loss.item() * batch_len
            processed_data += batch_len
            if max_outputs is None:
                styled.append(output.detach().cpu())
            elif kept < max_outputs:
                # Copy only as many images as are still wanted.
                preview = output.detach()[:max_outputs - kept].cpu()
                styled.append(preview)
                kept += preview.size(0)
    finally:
        # Release cached GPU blocks once per epoch, also when interrupted.
        torch.cuda.empty_cache()
    if processed_data == 0:
        raise ValueError("data_train yielded no batches; cannot compute the epoch loss.")
    train_loss = running_loss / processed_data
    styled_images = torch.cat(styled, dim=0) if styled else torch.empty(0)
    return styled_images, train_loss

def train_model(data_train, data_val, model, optimizer, criterion, epochs, start_epoch=0, checkpoint_cooldown=10):
    """Run the training loop, showing previews and saving periodic checkpoints.

    Args:
        data_train: loader of training batches, passed to ``fit_epoch``.
        data_val: loader whose first batch supplies the content and style
            images shown next to the styled previews.
        model: network to train.
        optimizer: optimizer used by ``fit_epoch``.
        criterion: loss used by ``fit_epoch``.
        epochs: index at which training stops (exclusive).
        start_epoch: zero-based index of the first epoch to run.
        checkpoint_cooldown: a training checkpoint is written every this
            many epochs.

    Returns:
        A list with one ``(train_loss, learning_rate)`` tuple per completed
        epoch. On ``KeyboardInterrupt`` the history gathered so far is
        returned.
    """
    preview_size = 6
    history = []
    start_time = time.time()
    with tqdm.tqdm(desc="Epoch", total=epochs, initial=start_epoch, unit="epoch", unit_scale=False) as pbar:
        for epoch in range(start_epoch, epochs):
            try:
                # Only the previewed images are kept, so an epoch does not
                # accumulate every styled training image in host memory.
                output, train_loss = fit_epoch(data_train, model, optimizer, criterion, epoch, epochs,
                                               max_outputs=preview_size)
                IPython.display.clear_output(wait=True)
                # No validation loss is computed, so only the train loss and
                # the current learning rate are recorded.
                history.append((train_loss, optimizer.param_groups[0]["lr"]))
                show_pics_train(data_val, output, history[-1], epoch, preview_size)
                pbar.update(1)
                pbar.refresh()

                if (epoch + 1) % checkpoint_cooldown == 0:
                    save_model(f"nst_model_{epoch + 1}.tar", mode="training", model=model, optimizer=optimizer, loss=criterion, history=history, epoch=epoch)
            except KeyboardInterrupt:
                tqdm.tqdm.write(f"Training interrupted at epoch {epoch + 1}. Returning history")
                return history
    end_time = time.time()
    train_time = end_time - start_time
    tqdm.tqdm.write(f"Overall training time: {train_time: 0.1f} seconds")
    return history

In [None]:
def show_pics_train(data_val, output, stats, epoch, sample_size):
    """Plot content, style and styled images in three rows.

    Args:
        data_val: loader whose first batch provides the content and style
            images.
        output: batch of styled images, shaped (B, C, H, W).
        stats: history record; ``stats[0]`` is shown as the train loss.
        epoch: zero-based epoch index shown in the title.
        sample_size: number of grid columns; fewer images are drawn when a
            batch holds fewer than this.
    """
    log_template = "Styled images on epoch {ep: 03d}.\n\
    Train loss: {t_loss: 0.4f}"
    content, style = next(iter(data_val))
    denorm = Denormalize(mean=NORMALIZATION_PARAMS["mean"], std=NORMALIZATION_PARAMS["std"])
    # Never index past the end of the smallest batch.
    shown = min(sample_size, len(content), len(style), len(output))
    rows = (("Content", content), ("Style", style), ("Styled", output))
    for i in range(shown):
        for row, (title, batch) in enumerate(rows):
            # Each image is denormalised on its own (C, H, W) copy, so the
            # per-channel statistics line up with the channel axis and the
            # caller's tensors stay untouched.
            image = denorm(batch[i].clone()).permute(1, 2, 0)
            plt.subplot(3, sample_size, i + 1 + row * sample_size)
            plt.imshow(np.clip(image.squeeze().numpy(), 0, 1))
            plt.title(title)
            plt.axis("off")
    plt.suptitle(log_template.format(ep=epoch + 1, t_loss=stats[0]))
    plt.show()

def plot_pics(pics, sample_size=6):
    """Draw the first ``sample_size`` images of ``pics`` in a two-row grid.

    Args:
        pics: indexable collection of channel-first image tensors.
        sample_size: number of images to draw.
    """
    columns = sample_size // 2 + 1
    for position in range(1, sample_size + 1):
        plt.subplot(2, columns, position)
        # matplotlib expects channels last.
        channels_last = pics[position - 1].permute(1, 2, 0)
        plt.imshow(channels_last.squeeze().numpy())
        plt.title("Images")
        plt.axis("off")
    plt.show()

def plot_loss(history):
    """Plot the train loss per epoch.

    Args:
        history: sequence of per-epoch records whose first item is the
            train loss; any further items are ignored.
    """
    # Take the first field by index so the plot does not depend on how many
    # fields each record carries.
    loss = [record[0] for record in history]
    plt.figure(figsize=(15, 9))
    plt.plot(loss, label="Train loss")
    plt.legend(loc="best")
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.show()

def plot_learn_rate(history):
    """Plot the learning rate per epoch.

    Args:
        history: sequence of per-epoch records whose last item is the
            learning rate; any preceding items are ignored.
    """
    # Take the last field by index so the plot does not depend on how many
    # fields each record carries.
    learn_rate = [record[-1] for record in history]
    plt.figure(figsize=(15, 9))
    plt.plot(learn_rate, label="Learn rate")
    plt.legend(loc="best")
    plt.xlabel("Epochs")
    plt.ylabel("Learn rate")
    plt.show()

In [None]:
def style_transfer(model, content, style):
    """Apply ``model`` to content and style images without tracking gradients.

    The model is switched to evaluation mode. Inputs that are not
    four-dimensional get a leading batch axis.

    Args:
        model: callable returning ``(styled, extra)`` for ``(content, style)``.
        content: content image tensor or batch.
        style: style image tensor or batch.

    Returns:
        The styled batch, detached and moved to the CPU.
    """
    def _as_batch(images):
        # Add a batch axis to anything that is not already four-dimensional.
        return images if images.dim() == 4 else images.unsqueeze(0)

    model.eval()
    try:
        content, style = content.to(DEVICE), style.to(DEVICE)
        content = _as_batch(content)
        style = _as_batch(style)
        with torch.no_grad():
            styled, _ = model(content, style)
        result = styled.detach().cpu()
    finally:
        # Drop the device copies and hand cached GPU memory back.
        content, style = content.cpu(), style.cpu()
        del content, style
        torch.cuda.empty_cache()
    return result

In [None]:
def save_model(path, model, mode="inference", **kwargs):
    """Write ``model`` to ``path``.

    Args:
        path: destination file.
        model: module whose ``state_dict()`` is stored.
        mode: ``"training"`` stores a full checkpoint; any other value
            stores the model state dict only.
        **kwargs: for ``"training"`` mode the keys ``epoch``, ``optimizer``,
            ``loss`` and ``history`` are required.
    """
    if mode != "training":
        torch.save(model.state_dict(), path)
        return
    checkpoint = {
        "epoch": kwargs["epoch"],
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": kwargs["optimizer"].state_dict(),
        "loss": kwargs["loss"],
        "history": kwargs["history"],
    }
    torch.save(checkpoint, path)

def load_model(path, model_arch, mode="inference", optim_class=None, optim_kwargs=None):
    """Rebuild a model (and optionally its optimizer) from a file on disk.

    Args:
        path: file written by ``save_model``.
        model_arch: zero-argument callable returning a fresh model.
        mode: ``"training"`` expects a full checkpoint; any other value
            expects a bare model state dict.
        optim_class: optimizer class, required in ``"training"`` mode.
        optim_kwargs: optional keyword arguments for ``optim_class``.

    Returns:
        ``(model, optim, epoch, loss, history)`` in ``"training"`` mode,
        otherwise just ``model``. The model is returned in evaluation mode.

    Raises:
        ValueError: if ``mode == "training"`` and ``optim_class`` is missing.
    """
    if mode == "training":
        if not optim_class:
            raise ValueError("Optimizer class required to load a model saved for training.")
        model = model_arch()
        # Map tensors to the CPU so that a checkpoint written on a GPU machine
        # also loads where CUDA is unavailable; load_state_dict copies the
        # values into the freshly built (CPU) model either way.
        # NOTE(review): the checkpoint also carries the pickled loss object
        # under "loss"; PyTorch releases that default to weights_only loading
        # may refuse to unpickle it — confirm against the installed version.
        checkpoint = torch.load(path, map_location="cpu")
        model.load_state_dict(checkpoint["model_state_dict"])
        optim = optim_class(model.parameters(), **(optim_kwargs or {}))
        # save_model stores the optimizer under "optimizer_state_dict"; the
        # shorter spelling is still accepted for files that happen to use it.
        state_key = "optimizer_state_dict" if "optimizer_state_dict" in checkpoint else "optim_state_dict"
        optim.load_state_dict(checkpoint[state_key])
        epoch = checkpoint["epoch"]
        loss = checkpoint["loss"]
        history = checkpoint["history"]
        model.eval()
        return model, optim, epoch, loss, history
    else:
        # The file holds a plain state dict written by torch.save.
        model = model_arch()
        model.load_state_dict(torch.load(path, map_location="cpu"))
        model.eval()
        return model

def save_history(path, history):
    """Pickle ``history`` into the file at ``path``, replacing its content."""
    with open(path, "wb") as sink:
        pickle.Pickler(sink).dump(history)

def load_history(path):
    """Return the object pickled in the file at ``path``.

    The file is unpickled as-is, so it should only be used on files this
    notebook wrote itself.
    """
    with open(path, "rb") as source:
        return pickle.Unpickler(source).load()

# Downloading data

How the required data is downloaded depends on where the notebook runs, and each option has its own prerequisite.

If you are:
* Running the notebook in a Kaggle session: add the `shaorrran/coco-wikiart-nst-dataset-512-100000` dataset to the session.
* Running the notebook in Google Colab: upload your `kaggle.json` file into the session.
* Running the notebook locally: have the Kaggle API installed and your API token configured.

In [None]:
DATASET_NAME = "shaorrran/cocowikiart-nst-dataset-small" if USE_SMALLER_DATASET \
else "shaorrran/coco-wikiart-nst-dataset-512-100000"
if RUNNER == "kaggle":
    name = DATASET_NAME.split("/")[1]
    DATASET_PATH = f"/kaggle/input/{name}"
elif RUNNER == "colab":
    from google.colab import files
    files.upload()
    os.makedirs("/root/.kaggle", exist_ok=True)
    !mv kaggle.json /root/.kaggle/kaggle.json
    !chmod 600 /root/.kaggle/kaggle.json
    !kaggle datasets download -d {DATASET_NAME} --unzip
    DATASET_PATH = os.getcwd()
else:
    if not os.exists("content.zip") or os.exists("style.zip") \
    and not (os.path.isdir("content") and os.path.isdir("style")):
        !kaggle datasets download -d {DATASET_NAME} --unzip
    else:
        with zipfile.ZipFile("content.zip", "r") as archive:
            for member in tqdm.tqdm(archive.namelist(), desc="Extracting", unit="files", unit_scale=False):
                archive.extract(member, os.getcwd())
        with zipfile.ZipFile("style.zip", "r") as archive:
            for member in tqdm.tqdm(archive.namelist(), desc="Extracting", unit="files", unit_scale=False):
                archive.extract(member, os.getcwd())
        !rm content.zip
        !rm style.zip
    DATASET_PATH = os.getcwd()

# Creating Dataset and Dataloaders

In [None]:
# Build the paired dataset; every image is randomly cropped to
# IMAGE_SIZE x IMAGE_SIZE and normalised with NORMALIZATION_PARAMS.
data = StyleTransferDataset(os.path.join(DATASET_PATH, "content"), os.path.join(DATASET_PATH, "style"), 
                            transforms=torchvision.transforms.Compose([
                                torchvision.transforms.RandomCrop(IMAGE_SIZE),
                                torchvision.transforms.Normalize(
                                mean=NORMALIZATION_PARAMS["mean"],
                                std=NORMALIZATION_PARAMS["std"],
                            )]))
# 95% / 5% train/validation split, seeded so that it is repeatable.
data_train, data_val = torch.utils.data.random_split(data, 
                                                     [int(0.95 * len(data)), len(data) - int(0.95 * len(data))], 
                                                     generator=torch.Generator().manual_seed(RANDOM_STATE))
# Both loaders use one worker per CPU core and drop the last incomplete batch;
# only the training loader shuffles.
train_loader = torch.utils.data.DataLoader(data_train, batch_size=int(BATCH_SIZE), shuffle=True, 
                                           num_workers=torch.multiprocessing.cpu_count(), 
                                           pin_memory=True, 
                                           drop_last=True)
val_loader = torch.utils.data.DataLoader(data_val, batch_size=int(BATCH_SIZE), shuffle=False, 
                                           num_workers=torch.multiprocessing.cpu_count(), 
                                           pin_memory=True, 
                                           drop_last=True)

# Defining models

In [None]:
class VGGEncoder(torch.nn.Module):
    """Frozen feature extractor cut from the first 21 layers of VGG-19.

    The pretrained ``features`` stack is split into four consecutive slices
    so that the activation after each slice can be returned.
    NOTE(review): the slice ends presumably correspond to relu1_1, relu2_1,
    relu3_1 and relu4_1 of torchvision's VGG-19 layout — confirm.

    Device and dtype moves rely on the stock ``torch.nn.Module.to``, which
    already converts every registered parameter; it is deliberately not
    overridden so that all of its call forms (``to(dtype=...)``,
    ``to(device, non_blocking=True)``, ...) keep working.
    """
    def __init__(self):
        super().__init__()
        vgg = torchvision.models.vgg19(pretrained=True).features
        self.slice0 = vgg[:2]
        self.slice1 = vgg[2:7]
        self.slice2 = vgg[7:12]
        self.slice3 = vgg[12:21]
        # The encoder only provides features; its weights are never trained.
        for p in self.parameters():
            p.requires_grad = False

    def forward(self, images, output_last=False):
        """Encode ``images``.

        Args:
            images: input batch fed to the first slice.
            output_last: when True return only the output of the last
                slice, otherwise the outputs of all four slices.

        Returns:
            The last feature map, or the tuple ``(h0, h1, h2, h3)``.
        """
        h0 = self.slice0(images)
        h1 = self.slice1(h0)
        h2 = self.slice2(h1)
        h3 = self.slice3(h2)
        if output_last:
            return h3
        return h0, h1, h2, h3

class Decoder(torch.nn.Module):
    """Convolutional decoder mapping 512-channel features to a 3-channel image.

    Nine reflect-padded 3x3 convolutions are interleaved with three 2x
    upsampling steps, so the output is eight times larger than the input in
    each spatial dimension.

    The last convolution has no activation: the network is trained on
    mean/std-normalised images, whose values are negative wherever a pixel
    is darker than the channel mean, and a trailing ReLU would make those
    values unreachable. Parameter names are unaffected by this, because an
    activation holds no weights.
    """
    def __init__(self):
        super().__init__()
        self.layers = torch.nn.Sequential(collections.OrderedDict([
            ("conv0", torch.nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1, padding_mode="reflect")),
            ("act0", torch.nn.ReLU(True)),
            ("upsample0", torch.nn.Upsample(scale_factor=2)),
            ("conv1", torch.nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, padding_mode="reflect")),
            ("act1", torch.nn.ReLU(True)),
            ("conv2", torch.nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, padding_mode="reflect")),
            ("act2", torch.nn.ReLU(True)),
            ("conv3", torch.nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, padding_mode="reflect")),
            ("act3", torch.nn.ReLU(True)),
            ("conv4", torch.nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1, padding_mode="reflect")),
            ("act4", torch.nn.ReLU(True)),
            ("upsample1", torch.nn.Upsample(scale_factor=2)),
            ("conv5", torch.nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, padding_mode="reflect")),
            ("act5", torch.nn.ReLU(True)),
            ("conv6", torch.nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1, padding_mode="reflect")),
            ("act6", torch.nn.ReLU(True)),
            ("upsample2", torch.nn.Upsample(scale_factor=2)),
            ("conv7", torch.nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, padding_mode="reflect")),
            ("act7", torch.nn.ReLU(True)),
            # Linear output layer: no activation after the last convolution.
            ("conv8", torch.nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1, padding_mode="reflect")),
            ]))

    def forward(self, x):
        """Decode the feature batch ``x`` into an image batch."""
        return self.layers(x)

class StyleTransferCNN(torch.nn.Module):
    """Style transfer network: encoder, adaptive instance norm, decoder.

    The encoder and decoder are registered sub-modules, so the stock
    ``torch.nn.Module`` machinery behind ``to()``, ``cuda()``, ``float()``
    and friends already reaches them; no ``_apply`` override is needed.
    """
    def __init__(self):
        super().__init__()
        self.encoder = VGGEncoder()
        self.decoder = Decoder()
        self.instnorm = AdaptiveInstanceNorm2d()

    def forward(self, content, style, alpha=1.0):
        """Stylise ``content`` with the statistics of ``style``.

        Args:
            content: batch of content images.
            style: batch of style images.
            alpha: blend factor; 1.0 uses the re-normalised features only,
                0.0 uses the unmodified content features.

        Returns:
            ``(image, t)``: the decoded image batch and the blended feature
            map it was decoded from.
        """
        content_features = self.encoder(content, output_last=True)
        style_features = self.encoder(style, output_last=True)
        t = alpha * self.instnorm(content_features, style_features) + (1 - alpha) * content_features
        return self.decoder(t), t

In [None]:
# Build the network and move it to the selected device.
model = StyleTransferCNN().to(DEVICE)

In [None]:
# Summarise the model for one content batch and one style batch of the training shape.
torchinfo.summary(model, ((int(BATCH_SIZE), 3, IMAGE_SIZE, IMAGE_SIZE), (int(BATCH_SIZE), 3, IMAGE_SIZE, IMAGE_SIZE)))

In [None]:
# Run the nvidia-smi command-line tool to show the current GPU state.
!nvidia-smi

# Defining loss

In [None]:
class StyleTransferLoss(torch.nn.Module):
    """Content loss plus ``lam`` times style loss.

    Args:
        lam: weight of the style term.
    """
    def __init__(self, lam=10):
        super().__init__()
        self.lam = lam

    def _style_loss(self, content_middle, style_middle):
        """Sum, over paired feature maps, the MSE between their channel means and stds."""
        stats = AdaptiveInstanceNorm2d()
        total = 0
        for generated, reference in zip(content_middle, style_middle):
            mean_term = F.mse_loss(stats._get_mean(generated), stats._get_mean(reference))
            std_term = F.mse_loss(stats._get_std(generated), stats._get_std(reference))
            total = total + (mean_term + std_term)
        return total

    def _content_loss(self, content, t):
        """Return the mean squared error between ``content`` and ``t``."""
        return F.mse_loss(content, t)

    def forward(self, content, t, content_middle, style_middle):
        """Return the combined loss.

        Args:
            content: feature map compared against ``t``.
            t: target feature map for the content term.
            content_middle: sequence of feature maps whose statistics are
                matched to those of ``style_middle``.
            style_middle: sequence of reference feature maps.
        """
        content_term = self._content_loss(content, t)
        style_term = self._style_loss(content_middle, style_middle)
        return content_term + self.lam * style_term

# Training model

In [None]:
# Adam over the model's parameters with a learning rate of 5e-5, and the
# style transfer loss with its default style weight.
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)
criterion = StyleTransferLoss()

In [None]:
# Train for max_epochs epochs and keep the returned history for plotting.
max_epochs = 25
history_nst = train_model(train_loader, val_loader, model, optimizer, criterion, max_epochs)

In [None]:
# Plot the training loss curve from the recorded history.
plot_loss(history_nst)

# Validating model

In [None]:
# Style the first validation batch and display the denormalised results.
test_content, test_style = next(iter(val_loader))
test_styled = style_transfer(model, test_content, test_style)
plot_pics(Denormalize(mean=NORMALIZATION_PARAMS["mean"], std=NORMALIZATION_PARAMS["std"])(test_styled))

# Saving model

In [None]:
# Store only the model weights (inference mode), not a training checkpoint.
save_model(f"nst_vgg+adaptive_instnorm.pth", model, mode="inference")

# Results

Retraining the model on a smaller dataset for 25 epochs yields only blurred blots instead of images, so I will have to look for a different approach.