# Imports

In [None]:
!pip install torchsummary

In [None]:
import glob
import os
import random
import time

import IPython.display

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
%matplotlib inline

import skimage
import skimage.io
import skimage.transform

import tqdm.auto as tqdm

import torch
import torch.nn.functional as F

import torchvision

import torchsummary

# Setting random seeds

In [None]:
RANDOM_STATE = 42

# Exported for Python processes launched after this point; changing it at
# runtime does not re-seed string hashing in the interpreter already running.
os.environ['PYTHONHASHSEED'] = str(RANDOM_STATE)

# Seed every in-process RNG the notebook relies on (Python, NumPy, PyTorch on
# CPU, the current CUDA device and all CUDA devices) with the same value.
for _seed_rng in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed,
    torch.cuda.manual_seed_all,
):
    _seed_rng(RANDOM_STATE)
del _seed_rng

# Determining device

In [None]:
# Run on the default CUDA device when PyTorch can see one, else on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print(DEVICE)

# Setting globals

In [None]:
# Training knobs; the trailing "@param" comments are Colab form annotations.
BATCH_SIZE = 512 # @param [512] {allow-input: true}
IMAGE_SIZE = 128 # @param [128] {allow-input: true}

# Per-channel (R, G, B) statistics for input normalization; the values match
# the ImageNet mean/std that torchvision's pretrained models use.
NORMALIZATION_PARAMS = dict(
    mean=(0.485, 0.456, 0.406),
    std=(0.229, 0.224, 0.225),
)

# Helper functions and classes

In [None]:
def denormalize_image(image, mean=None, std=None):
    """Undo per-channel normalization and clip the result into [0, 1].

    Inverts ``(x - mean) / std`` by computing ``image * std + mean``. The
    statistics broadcast against the LAST axis of ``image``, so a
    channels-last array such as (H, W, C) is expected; a channels-first
    tensor has to be permuted before calling this.

    Args:
        image: array-like (NumPy array, CPU tensor without grad, nested lists).
        mean: per-channel means; defaults to ``NORMALIZATION_PARAMS["mean"]``.
        std: per-channel standard deviations; defaults to
            ``NORMALIZATION_PARAMS["std"]``.

    Returns:
        np.ndarray with the denormalized values clipped to [0, 1].
    """
    if mean is None:
        mean = NORMALIZATION_PARAMS["mean"]
    if std is None:
        std = NORMALIZATION_PARAMS["std"]
    # asarray lets CPU tensors and plain lists through; NumPy arrays pass unchanged.
    image = np.asarray(image)
    return np.clip(image * np.asarray(std) + np.asarray(mean), 0, 1)

class PreprocessedDataset(torch.utils.data.Dataset):
    """Dataset of (content, style) image pairs read from resized copies on disk.

    On construction, each input directory ``d`` gets a sibling cache
    directory ``<d>_resized`` holding every readable image of ``d`` resized to
    ``IMAGE_SIZE`` x ``IMAGE_SIZE``. A cache directory that already exists is
    reused as-is. Pairs are formed by zipping the *sorted* file lists of the
    two caches, so the dataset length is the smaller of the two file counts.
    """

    def __init__(self, content_dir, style_dir, transforms=None):
        """
        Args:
            content_dir: directory containing the original content images.
            style_dir: directory containing the original style images.
            transforms: optional callable applied separately to the content
                and the style tensor returned by ``__getitem__``.
        """
        # Each cache is checked and built on its own, so a cache that exists
        # for only one of the two inputs is still created for the other.
        content_dir_resized = self._ensure_resized(content_dir)
        style_dir_resized = self._ensure_resized(style_dir)
        # glob order is filesystem-dependent; sort so the content/style
        # pairing and index order are reproducible.
        content_images = sorted(glob.glob(os.path.join(content_dir_resized, "*")))
        style_images = sorted(glob.glob(os.path.join(style_dir_resized, "*")))
        self.images = list(zip(content_images, style_images))
        self.transforms = transforms

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Load pair ``idx`` as float tensors and apply ``transforms`` if set."""
        content_path, style_path = self.images[idx]
        # NOTE(review): skimage.io.imread presumably yields channels-last
        # (H, W, C) arrays with 0-255 values here; confirm ``transforms``
        # expects that layout (torchvision tensor transforms are channels-first).
        content_img = torch.Tensor(skimage.io.imread(content_path))
        style_img = torch.Tensor(skimage.io.imread(style_path))
        if self.transforms:
            content_img, style_img = self.transforms(content_img), self.transforms(style_img)
        return content_img, style_img

    def _ensure_resized(self, source_dir):
        """Return ``<source_dir>_resized``, creating and filling it first if it is missing."""
        # normpath drops a trailing separator so "dir/" maps to "dir_resized"
        # rather than to a "_resized" folder inside the source directory.
        target_dir = os.path.normpath(source_dir) + "_resized"
        if not os.path.exists(target_dir):
            os.makedirs(target_dir)
            self._resize(source_dir, target_dir)
        return target_dir

    @staticmethod
    def _resize(source, target):
        """Resize every image in ``source`` into ``target``, keeping file names.

        Best-effort: a file that cannot be read, resized or written is skipped
        with a message instead of aborting the whole pass.
        """
        for name in tqdm.tqdm(os.listdir(source), desc="Resizing images", unit="images", unit_scale=False):
            try:
                image = skimage.io.imread(os.path.join(source, name))
                resized = skimage.transform.resize(image, (IMAGE_SIZE, IMAGE_SIZE), anti_aliasing=True)
                skimage.io.imsave(os.path.join(target, name), resized)
            except Exception as err:
                # Exception (not a bare except) so Ctrl+C still interrupts.
                tqdm.tqdm.write(f"Skipping {name}: {err}")

In [None]:
class AdaptiveInstanceNorm2d(torch.nn.Module):
    """Adaptive Instance Normalization (AdaIN).

    Gives ``content`` the per-sample, per-channel statistics of ``style``:
    ``style_std * (content - content_mean) / content_std + style_mean``.
    Statistics are taken over all dimensions after the first two, using
    torch's default (unbiased) standard deviation, and ``eps`` is added to
    every standard deviation so the division cannot hit zero.
    """

    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps

    def _get_mean(self, features):
        """Mean over the flattened trailing dims, shaped (N, C, 1, 1)."""
        n, channels = features.size()[:2]
        flat = features.reshape(n, channels, -1)
        return flat.mean(dim=2)[:, :, None, None]

    def _get_std(self, features):
        """Standard deviation over the flattened trailing dims plus eps, shaped (N, C, 1, 1)."""
        n, channels = features.size()[:2]
        flat = features.reshape(n, channels, -1)
        return flat.std(dim=2)[:, :, None, None] + self.eps

    def forward(self, content, style):
        """Return ``content`` re-normalized to the channel statistics of ``style``."""
        c_mean = self._get_mean(content)
        c_std = self._get_std(content)
        s_mean = self._get_mean(style)
        s_std = self._get_std(style)
        return s_std * (content - c_mean) / c_std + s_mean

In [None]:
def fit_epoch(data_train, model, optimizer, criterion, epoch, epochs):
    """Run one training pass over ``data_train`` and return the mean loss.

    Args:
        data_train: iterable of ``(content, style)`` batches.
        model: called as ``model(content, style)`` and must return ``(output, t)``.
        optimizer: stepped once per batch after ``loss.backward()``.
        criterion: called as ``criterion(output, style, t)``; must return a
            single-element loss tensor.
        epoch, epochs: only used in the progress-bar label.

    Returns:
        Sample-weighted mean training loss over all batches, or
        ``float("nan")`` if ``data_train`` yields no batches.
    """
    model.train()
    running_loss = 0.0
    processed_data = 0
    for content, style in tqdm.tqdm(data_train, desc=f"Fitting epoch {epoch}/{epochs}", unit="batches", unit_scale=False):
        try:
            content, style = content.to(DEVICE), style.to(DEVICE)
            optimizer.zero_grad()
            output, t = model(content, style)
            loss = criterion(output, style, t)
            loss.backward()
            optimizer.step()
            batch_size = content.size(0)
            # Weight by batch size so a short final batch is averaged fairly.
            running_loss += loss.item() * batch_size
            processed_data += batch_size
        finally:
            # Release this batch's device tensors before returning cached
            # blocks to the allocator; no host copy is needed to free them.
            del content, style
            torch.cuda.empty_cache()
    # The average is taken once, after every batch has been processed.
    if processed_data == 0:
        return float("nan")
    return running_loss / processed_data

def eval_epoch(data_val, model, criterion):
    """Evaluate ``model`` on ``data_val`` without tracking gradients.

    Args:
        data_val: iterable of ``(content, style)`` batches.
        model: must provide ``model.generate(content, style)`` returning ``(styled, t)``.
        criterion: called as ``criterion(styled, style, t)``; must return a
            single-element loss tensor.

    Returns:
        ``(styled, mean_loss)``: ``styled`` is the generated output of the
        LAST batch (``None`` if ``data_val`` is empty). ``mean_loss`` is the
        sample-weighted mean loss over ALL batches (``float("nan")`` if empty).
    """
    model.eval()
    running_loss = 0.0
    processed_data = 0
    styled = None  # stays None when data_val yields no batches

    for content, style in tqdm.tqdm(data_val, desc="Validating", unit="batches", unit_scale=False):
        try:
            content, style = content.to(DEVICE), style.to(DEVICE)
            with torch.no_grad():
                styled, t = model.generate(content, style)
                loss = criterion(styled, style, t)

            batch_size = content.size(0)
            # Accumulate across batches (weighted by batch size), not overwrite.
            running_loss += loss.item() * batch_size
            processed_data += batch_size
        finally:
            del content, style
            torch.cuda.empty_cache()
    if processed_data == 0:
        return styled, float("nan")
    return styled, running_loss / processed_data

def train_model(data_train, data_val, model, optimizer, criterion, epochs, batch_size,
                scheduler=None, start_epoch=0, checkpoint_cooldown=10, sample_size=16):
    """Train ``model`` for ``epochs`` epochs and return the per-epoch history.

    Each epoch does the following:
      * runs ``fit_epoch`` and ``eval_epoch``;
      * optionally steps ``scheduler`` with the validation loss;
      * clears the cell output and redraws samples via ``show_pics_train``;
      * every ``checkpoint_cooldown`` epochs, saves a checkpoint via ``save_model``.

    A KeyboardInterrupt during an epoch (e.g. interrupting the kernel) stops
    training and returns the history collected so far.

    Args:
        data_train: training batches, forwarded to ``fit_epoch``.
        data_val: validation batches, forwarded to ``eval_epoch`` and ``show_pics_train``.
        model, optimizer, criterion: forwarded unchanged to the epoch helpers.
        epochs: number of epochs to run.
        batch_size: not used by this function; kept so existing calls still work.
        scheduler: optional scheduler, called as ``scheduler.step(val_loss)``
            (the metric-taking form used by ``ReduceLROnPlateau``).
        start_epoch: currently ignored. NOTE(review): presumably intended for
            resuming from a checkpoint — decide whether ``epochs`` is a total
            or a count before wiring it into the epoch range.
        checkpoint_cooldown: save a checkpoint after every this-many epochs.
        sample_size: forwarded to ``show_pics_train``, whose signature
            requires it; 16 is only a display default.

    Returns:
        List of ``(train_loss, val_loss, learning_rate)`` tuples, one per
        completed epoch.
    """
    history = []
    start_time = time.time()
    with tqdm.tqdm(desc="Epoch", total=epochs, unit="epoch", unit_scale=False) as pbar:
        for epoch in range(epochs):
            try:
                train_loss = fit_epoch(data_train, model, optimizer, criterion, epoch, epochs)
                _, val_loss = eval_epoch(data_val, model, criterion)
                if scheduler is not None:
                    scheduler.step(val_loss)
                IPython.display.clear_output(wait=True)
                # Learning rate is recorded after the scheduler step.
                history.append((train_loss, val_loss, optimizer.param_groups[0]["lr"]))
                show_pics_train(data_val, model, history[-1], epoch, epochs, sample_size)
                pbar.update(1)
                pbar.refresh()

                if (epoch + 1) % checkpoint_cooldown == 0:
                    save_model(mode="training", model=model, optimizer=optimizer, loss=criterion, history=history)
            except KeyboardInterrupt:
                tqdm.tqdm.write(f"Training interrupted at epoch {epoch + 1}. Returning history")
                return history
    train_time = time.time() - start_time
    tqdm.tqdm.write(f"Overall training time: {train_time: 0.1f} seconds")
    return history

In [None]:
def show_pics_train(data_val, model, history, epoch, epochs, sample_size):
    log_template = "Styled images on epoch {ep: 03d}/{epochs: 03d}.\n\
    Train loss: {t_loss: 0.4f}, validation loss: {v_loss: 0.4f}"
    try:
        nrow = int(np.ceil(np.sqrt(sample_size)))
        styled, _ = 