This notebook is a direct translation of: 
https://github.com/bethgelab/stylize-datasets
to simplify the stylization of images in our use case. 

We combined their scripts and adapted them slightly into a single runnable notebook to speed up iterations of our experiments.

## Get dataset

In [0]:
import os

if not os.path.isdir('./imagenette2-160.zip'):
  import gdown
  url = 'https://drive.google.com/uc?id=11MOFZF2dVjEu0PbSPmGVWbb2olSZ8eHV'
  output = 'imagenette2-160.zip'
  gdown.download(url, output, quiet=True)
  !unzip "imagenette2-160.zip" -d .
  %rm imagenette2-160.zip

## Get textures

In [0]:
import os

if not os.path.isdir('./best-textures.zip'):
  import gdown
  url = 'https://drive.google.com/uc?id=1bGHUaL7DzFzVsJKYv04o4ZdNZ9zCTAhY'
  output = 'best-textures.zip'
  gdown.download(url, output, quiet=True)
  !unzip "best-textures.zip" -d .
  %rm best-textures.zip

## Options

In [0]:
from pathlib import Path

# Directory containing the content images (searched recursively).
CONTENT_DIR = Path('./imagenette2-160').resolve()

# Directory containing the style images (searched recursively).
STYLE_DIR = Path('./best-textures').resolve()

# Directory to save the stylized images; the content directory's sub-folder
# layout is mirrored underneath it.
OUTPUT_DIR = Path('./stylized-imagenette2-160').resolve()

# Number of styles, sampled at random without replacement, applied to each content image.
NUM_STYLE = 1

# Degree of stylization, between 0 and 1: 1.0 = fully stylized features,
# 0.0 = the content features passed through unchanged.
ALPHA = 1.0

# Image file extensions (without the leading dot) to collect from the content and
# style directories. Matching is case-sensitive on POSIX filesystems.
EXTENSIONS = ['JPEG', 'jpg']

# Resize content images so their shorter side has this length; 0 keeps the original size.
CONTENT_SIZE = 0

# Resize style images so their shorter side has this length; 0 keeps the original size.
STYLE_SIZE = 512

# Side length of the center crop applied to content images after resizing; 0 disables cropping.
CROP = 0

assert CONTENT_DIR.is_dir(), 'Content directory not found: %s' % CONTENT_DIR
assert STYLE_DIR.is_dir(), 'Style directory not found: %s' % STYLE_DIR
assert len(EXTENSIONS) > 0, 'No file extensions specified'
# Checked here so a bad value fails before the models are loaded, not mid-run.
assert 0.0 <= ALPHA <= 1.0, 'ALPHA must be between 0 and 1'

## Get models

In [0]:
!mkdir models
%cd models/

In [0]:
!wget --no-check-certificate 'https://docs.google.com/uc?export=download&id=108uza-dsmwvbW2zv-G73jtVcMU_2Nb7Y' -O vgg_normalised.pth
!wget --no-check-certificate 'https://docs.google.com/uc?export=download&id=1w9r1NoYnn7tql1VYG3qDUzkbIks24RBQ' -O decoder.pth
%cd ..

## function.py

In [0]:
import torch

def calc_mean_std(feat, eps=1e-5):
    """Return the per-sample, per-channel mean and std of a 4-D feature map.

    Args:
        feat: tensor of shape (N, C, H, W).
        eps: small constant added to the variance before the square root, so the
            returned std is strictly positive (adain divides by it).

    Returns:
        (mean, std), each shaped (N, C, 1, 1) so they broadcast against ``feat``.
        The variance is torch's default (unbiased) estimate.
    """
    dims = feat.size()
    assert len(dims) == 4
    batch, channels = dims[0], dims[1]
    flat = feat.view(batch, channels, -1)
    std = (flat.var(dim=2) + eps).sqrt().view(batch, channels, 1, 1)
    mean = flat.mean(dim=2).view(batch, channels, 1, 1)
    return mean, std

def adain(content_feat, style_feat):
    """Adaptive instance normalization (AdaIN).

    Normalizes every (sample, channel) slice of ``content_feat`` to zero mean and
    unit std, then rescales it with the mean and std of the matching slice of
    ``style_feat``.

    Both inputs are 4-D (N, C, H, W) tensors that must agree in N and C; their
    spatial sizes may differ. The result has the shape of ``content_feat``.
    """
    assert content_feat.size()[:2] == style_feat.size()[:2]
    shape = content_feat.size()
    s_mean, s_std = calc_mean_std(style_feat)
    c_mean, c_std = calc_mean_std(content_feat)
    normalized = (content_feat - c_mean.expand(shape)) / c_std.expand(shape)
    return normalized * s_std.expand(shape) + s_mean.expand(shape)


def _calc_feat_flatten_mean_std(feat):
    """Flatten a 3-channel image tensor and return its per-channel statistics.

    ``feat`` is expected to be a (3, H, W) CPU float tensor (both are asserted).

    Returns:
        (flat, mean, std): ``flat`` is shaped (3, H*W); ``mean`` and ``std`` are
        (3, 1) per-channel statistics of ``flat``.
    """
    assert feat.size()[0] == 3
    assert isinstance(feat, torch.FloatTensor)
    flat = feat.view(3, -1)
    return flat, flat.mean(dim=-1, keepdim=True), flat.std(dim=-1, keepdim=True)


def _mat_sqrt(x):
    """Return a matrix square root of ``x`` computed from its SVD.

    With ``x = U diag(S) V^T`` the result is ``U diag(sqrt(S)) V^T``. For the
    symmetric positive-definite matrices that coral passes in, this is the
    principal square root.
    """
    u, s, v = torch.svd(x)
    return u @ s.pow(0.5).diag() @ v.t()


def coral(source, target):
    """Align the channel statistics of ``source`` with those of ``target`` (CORAL).

    Both arguments are 3-channel (3, H, W) CPU float tensors; their H and W may
    differ. Each is flattened to (3, H*W) and standardized per channel. Then
    ``M = norm @ norm.T + I`` is formed for each, the standardized source is mapped
    through ``sqrt(M_target) @ inverse(sqrt(M_source))``, and the result is
    rescaled with the target's per-channel std and mean.

    Returns a tensor with the same shape as ``source``.
    """
    def standardize(img):
        # Flattened image standardized per channel, its stats, and norm @ norm.T + I.
        flat, mean, std = _calc_feat_flatten_mean_std(img)
        norm = (flat - mean.expand_as(flat)) / std.expand_as(flat)
        gram_eye = torch.mm(norm, norm.t()) + torch.eye(3)
        return norm, mean, std, gram_eye

    src_norm, _, _, src_gram = standardize(source)
    _, tgt_mean, tgt_std, tgt_gram = standardize(target)

    recolored = torch.mm(
        _mat_sqrt(tgt_gram),
        torch.mm(torch.inverse(_mat_sqrt(src_gram)), src_norm),
    )
    rescaled = recolored * tgt_std.expand_as(src_norm) + tgt_mean.expand_as(src_norm)
    return rescaled.view(source.size())

## net.py

In [0]:
import torch.nn as nn
from torch.autograd import Variable

# AdaIN decoder: maps the 512-channel features produced by the truncated
# net_vgg encoder (ending at relu4-1) back to a 3-channel image. Its three
# nn.Upsample(scale_factor=2) layers counter the three 2x2 max-pools that
# precede relu4-1 in net_vgg (up to rounding from their ceil_mode).
# Keep the layer order unchanged: models/decoder.pth is loaded with
# load_state_dict, which matches parameters to layers by their index in this
# nn.Sequential.
net_decoder = nn.Sequential(
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 256, (3, 3)),
    nn.ReLU(),
    nn.Upsample(scale_factor=2),  # 1st x2 upsampling
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 128, (3, 3)),
    nn.ReLU(),
    nn.Upsample(scale_factor=2),  # 2nd x2 upsampling
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 128, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 64, (3, 3)),
    nn.ReLU(),
    nn.Upsample(scale_factor=2),  # 3rd x2 upsampling
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 64, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 3, (3, 3)),  # output layer: 3 channels, no final activation
)

# VGG-19-layout encoder (16 3x3 convolutions in 5 groups, reflection padding)
# whose weights are loaded from models/vgg_normalised.pth.
# Only layers [0:31] (up to and including relu4-1 at index 30) are used: the
# stylization code truncates the model to them and Net slices them into four
# stages. Keep the layer order unchanged: load_state_dict matches parameters
# to layers by their index in this nn.Sequential.
net_vgg = nn.Sequential(
    nn.Conv2d(3, 3, (1, 1)),  # 1x1 conv on the RGB input; presumably folds in the pretrained model's input preprocessing - TODO confirm
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(3, 64, (3, 3)),
    nn.ReLU(),  # relu1-1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 64, (3, 3)),
    nn.ReLU(),  # relu1-2
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 128, (3, 3)),
    nn.ReLU(),  # relu2-1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 128, (3, 3)),
    nn.ReLU(),  # relu2-2
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 256, (3, 3)),
    nn.ReLU(),  # relu3-1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),  # relu3-2
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),  # relu3-3
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),  # relu3-4
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 512, (3, 3)),
    nn.ReLU(),  # relu4-1, this is the last layer used (index 30)
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu4-2
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu4-3
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu4-4
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu5-1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu5-2
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu5-3
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU()  # relu5-4
)


class Net(nn.Module):
    """AdaIN training wrapper: a frozen encoder plus a trainable decoder.

    The encoder's first 31 children are split into four stages ending at
    relu1_1, relu2_1, relu3_1 and relu4_1 (the slice indices assume the
    ``net_vgg`` layer layout). ``forward`` returns the content and style losses
    used to train the decoder.
    """

    def __init__(self, encoder, decoder):
        super(Net, self).__init__()
        enc_layers = list(encoder.children())
        self.enc_1 = nn.Sequential(*enc_layers[:4])  # input -> relu1_1
        self.enc_2 = nn.Sequential(*enc_layers[4:11])  # relu1_1 -> relu2_1
        self.enc_3 = nn.Sequential(*enc_layers[11:18])  # relu2_1 -> relu3_1
        self.enc_4 = nn.Sequential(*enc_layers[18:31])  # relu3_1 -> relu4_1
        self.decoder = decoder
        self.mse_loss = nn.MSELoss()

        # Freeze the encoder: only the decoder is trained. Without this, the
        # encoder features used as loss targets require grad and the
        # `target.requires_grad is False` assertions in the loss helpers fail.
        # Note: this also freezes the passed-in encoder's layers (they are shared).
        for name in ('enc_1', 'enc_2', 'enc_3', 'enc_4'):
            for param in getattr(self, name).parameters():
                param.requires_grad = False

    def encode_with_intermediate(self, input):
        """Return [relu1_1, relu2_1, relu3_1, relu4_1] features of ``input``."""
        results = [input]
        for i in range(4):
            func = getattr(self, 'enc_{:d}'.format(i + 1))
            results.append(func(results[-1]))
        return results[1:]

    def encode(self, input):
        """Return only the relu4_1 features of ``input``."""
        for i in range(4):
            input = getattr(self, 'enc_{:d}'.format(i + 1))(input)
        return input

    def calc_content_loss(self, input, target):
        """MSE between ``input`` and a constant (non-grad) ``target`` of the same size."""
        assert (input.size() == target.size())
        assert (target.requires_grad is False)
        return self.mse_loss(input, target)

    def calc_style_loss(self, input, target):
        """Sum of MSEs between the channel means and the channel stds of the two feature maps."""
        assert (input.size() == target.size())
        assert (target.requires_grad is False)
        input_mean, input_std = calc_mean_std(input)
        target_mean, target_std = calc_mean_std(target)
        return self.mse_loss(input_mean, target_mean) + \
               self.mse_loss(input_std, target_std)

    def forward(self, content, style):
        """Return (content_loss, style_loss) for one content/style batch pair."""
        style_feats = self.encode_with_intermediate(style)
        t = adain(self.encode(content), style_feats[-1])

        # Decode a detached copy of the AdaIN target: gradients reach the decoder
        # only, never the encoder pass that produced t (modern form of
        # Variable(t.data, requires_grad=True)).
        g_t = self.decoder(t.detach().requires_grad_())
        g_t_feats = self.encode_with_intermediate(g_t)

        loss_c = self.calc_content_loss(g_t_feats[-1], t)
        loss_s = self.calc_style_loss(g_t_feats[0], style_feats[0])
        for i in range(1, 4):
            loss_s += self.calc_style_loss(g_t_feats[i], style_feats[i])
        return loss_c, loss_s

## stylise.py

In [0]:
from PIL import Image
import random
import torch
import torch.nn as nn
import torchvision.transforms
from torchvision.utils import save_image
from tqdm import tqdm

def input_transform(size, crop):
    """Build the preprocessing pipeline applied to a PIL image.

    Args:
        size: passed to ``Resize`` unless it equals 0 (0 keeps the original size).
        crop: passed to ``CenterCrop`` unless it equals 0 (0 disables cropping).

    Returns:
        A ``Compose`` of the optional resize, the optional crop, then ``ToTensor``.
    """
    tfm = torchvision.transforms
    steps = [tfm.Resize(size)] if size != 0 else []
    if crop != 0:
        steps.append(tfm.CenterCrop(crop))
    steps.append(tfm.ToTensor())
    return tfm.Compose(steps)

def style_transfer(vgg, decoder, content, style, alpha=1.0):
    """Run one AdaIN stylization pass and return the decoder output.

    Encodes ``content`` and ``style`` with ``vgg``, re-normalizes the content
    features to the style statistics (``adain``), and blends the result with the
    raw content features. ``alpha`` = 1.0 gives the fully stylized features and
    0.0 the content features unchanged. The blend is then decoded.
    """
    assert 0.0 <= alpha <= 1.0
    content_feats = vgg(content)
    stylized_feats = adain(content_feats, vgg(style))
    blended = stylized_feats * alpha + content_feats * (1 - alpha)
    return decoder(blended)

def unpickle(file):
    """Load and return the object pickled in ``file``.

    ``encoding='bytes'`` makes 8-bit strings from Python 2 pickles load as
    ``bytes`` under Python 3.

    SECURITY: unpickling can execute arbitrary code; only call this on trusted files.
    """
    import pickle
    with open(file, 'rb') as fo:
        # Local renamed from `dict` so the builtin is not shadowed.
        data = pickle.load(fo, encoding='bytes')
    return data

def extract_cifar(path, pattern='*'):
    """Unpickle every file under ``path`` (recursively) and return them as a list.

    Args:
        path: directory to scan; a ``str`` or ``pathlib.Path``.
        pattern: glob pattern selecting which files to load (default: all files).

    Returns:
        The unpickled objects, in sorted path order so the result is deterministic.
    """
    from pathlib import Path
    # rglob also yields directories, which cannot be unpickled; keep files only.
    files = sorted(p for p in Path(path).rglob(pattern) if p.is_file())
    return [unpickle(p) for p in files]

def _collect_images(directory, extensions):
    """Return every file under ``directory`` (recursive) with one of ``extensions``, sorted."""
    found = []
    for ext in extensions:
        found += directory.rglob('*.' + ext)
    return sorted(found)


def _load_models(device):
    """Load the pretrained AdaIN weights and return (vgg, decoder) on ``device``.

    The encoder is truncated after index 30 (relu4_1), the last layer AdaIN uses.
    """
    decoder = net_decoder
    vgg = net_vgg
    decoder.eval()
    vgg.eval()
    # map_location='cpu' lets checkpoints saved from GPU tensors load on CUDA-less
    # machines; the modules are moved to `device` below.
    decoder.load_state_dict(torch.load('models/decoder.pth', map_location='cpu'))
    vgg.load_state_dict(torch.load('models/vgg_normalised.pth', map_location='cpu'))
    vgg = nn.Sequential(*list(vgg.children())[:31])
    vgg.to(device)
    decoder.to(device)
    return vgg, decoder


def _stylize_one(content_path, styles, vgg, decoder, content_tf, style_tf, device):
    """Write NUM_STYLE stylized copies of one content image into the mirrored OUTPUT_DIR tree."""
    with Image.open(content_path) as img:
        content_img = img.convert('RGB')
    # The content tensor does not depend on the style, so build it only once.
    content = content_tf(content_img).to(device).unsqueeze(0)

    rel_path = content_path.relative_to(CONTENT_DIR)
    out_dir = OUTPUT_DIR.joinpath(rel_path.parent)

    for style_path in random.sample(styles, NUM_STYLE):
        with Image.open(style_path) as img:
            style_img = img.convert('RGB')
        style = style_tf(style_img).to(device).unsqueeze(0)
        with torch.no_grad():
            output = style_transfer(vgg, decoder, content, style, ALPHA)
        # Move to host memory for saving; works whether `device` is CUDA or CPU.
        output = output.cpu()

        out_dir.mkdir(parents=True, exist_ok=True)
        out_filename = content_path.stem + '-stylized-' + style_path.stem + content_path.suffix
        save_image(output, out_dir.joinpath(out_filename), padding=0)  # default image padding is 2.


def main():
    """Stylize every content image with NUM_STYLE randomly chosen style images (AdaIN).

    Outputs mirror CONTENT_DIR's sub-folders under OUTPUT_DIR and are named
    '<content stem>-stylized-<style stem><content suffix>'. Content images that
    fail with an OSError or RuntimeError are skipped and listed in
    OUTPUT_DIR/skipped_imgs.txt.
    """
    content_paths = _collect_images(CONTENT_DIR, EXTENSIONS)
    # str() is required: str + Path raises TypeError and would mask the message.
    assert len(content_paths) > 0, \
        'No images with specified extensions found in content directory ' + str(CONTENT_DIR)
    print('Found %d content images in %s' % (len(content_paths), CONTENT_DIR))

    styles = _collect_images(STYLE_DIR, EXTENSIONS)
    assert len(styles) > 0, \
        'No images with specified extensions found in style directory ' + str(STYLE_DIR)
    print('Found %d style images in %s' % (len(styles), STYLE_DIR))
    # Otherwise random.sample raises an uncaught ValueError in the middle of the run.
    assert NUM_STYLE <= len(styles), \
        'NUM_STYLE (%d) exceeds the number of style images (%d)' % (NUM_STYLE, len(styles))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    vgg, decoder = _load_models(device)

    content_tf = input_transform(CONTENT_SIZE, CROP)
    style_tf = input_transform(STYLE_SIZE, 0)

    # disable decompression bomb errors
    Image.MAX_IMAGE_PIXELS = None
    skipped_imgs = []

    # actual style transfer as in AdaIN
    for content_path in tqdm(content_paths):
        try:
            _stylize_one(content_path, styles, vgg, decoder, content_tf, style_tf, device)
        except (OSError, RuntimeError) as err:
            print('Skipping stylization of %s due to an error: %s' % (content_path, err))
            skipped_imgs.append(content_path)

    if skipped_imgs:
        # OUTPUT_DIR does not exist yet if every image failed.
        OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
        with open(OUTPUT_DIR.joinpath('skipped_imgs.txt'), 'w') as f:
            f.writelines('%s\n' % item for item in skipped_imgs)


if __name__ == '__main__':  # true when run as a notebook cell or as a script
    main()

In [0]:
import shutil

# Archive the stylized images into ./output.zip. main() writes them to
# OUTPUT_DIR, so that directory (kept as the top-level folder) is what we zip.
shutil.make_archive('output', 'zip', root_dir=str(OUTPUT_DIR.parent), base_dir=OUTPUT_DIR.name)