In [1]:
import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
from torchvision.models import vgg19
from tqdm import tqdm

%matplotlib inline
%load_ext autoreload
%autoreload 2

## Create model classes

In [2]:
# Small conv layer wrapper, so that the next class is easier to write
class ConvLayer(nn.Module):
    """
    Thin wrapper around nn.Conv2d that pads every border by kernel_size // 2.

    With an odd kernel_size and stride=1 the output keeps the spatial size
    of the input.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1):
        super().__init__()
        # Half of the kernel, rounded down, on each side of the image
        half_kernel = kernel_size // 2
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=half_kernel,
        )

    def forward(self, x):
        """Apply the padded convolution to x."""
        return self.conv(x)

In [3]:
class BottleneckBlock(nn.Module):
    """
    Bottleneck block similar to the ResNet bottleneck block. InstanceNorm is
    used instead of BatchNorm because, when generating images, every image is
    normalized independently.

    (Batch norm computes the mean and std over the complete batch, while
    instance norm computes the mean and std of each image independently.) The
    generated images are independent of each other, so they should not be
    normalized with a common statistic.

    If you are confused about the bottleneck architecture, refer to the
    official PyTorch ResNet implementation and paper.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels of the output feature map. The two
            inner convolutions work on out_channels // 4 channels.
        kernel_size: kernel size of the middle convolution.
        stride: stride of the middle convolution and of the projection
            shortcut.
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1):
        super().__init__()
        self.in_c = in_channels
        self.out_c = out_channels
        self.stride = stride

        # Main branch: 1x1 reduce -> kxk (optionally strided) -> 1x1 expand
        self.identity_block = nn.Sequential(
            ConvLayer(in_channels, out_channels//4, kernel_size=1, stride=1),
            nn.InstanceNorm2d(out_channels//4),
            nn.ReLU(),
            ConvLayer(out_channels//4, out_channels//4, kernel_size, stride=stride),
            nn.InstanceNorm2d(out_channels//4),
            nn.ReLU(),
            ConvLayer(out_channels//4, out_channels, kernel_size=1, stride=1),
            nn.InstanceNorm2d(out_channels),
            nn.ReLU(),
        )

        # Projection shortcut, used whenever the input cannot be added to the
        # main branch as-is. It is always constructed so that the module's
        # parameter names do not depend on the arguments.
        self.shortcut = nn.Sequential(
            ConvLayer(in_channels, out_channels, 1, stride),
            nn.InstanceNorm2d(out_channels),
        )

    def forward(self, x):
        """Return relu(main_branch(x) + shortcut), where shortcut is x or its projection."""
        out = self.identity_block(x)
        # The raw input can only be reused when neither the channel count nor
        # the spatial size changes; a strided block with equal channel counts
        # still needs the projection to match the main branch's size.
        if self.in_c == self.out_c and self.stride == 1:
            residual = x
        else:
            residual = self.shortcut(x)
        # Out-of-place add: `out` is the output of the final ReLU above and
        # autograd may still need it unmodified for the backward pass.
        out = out + residual
        return F.relu(out)

In [4]:
# Not used in the implementation
class UpSample(nn.Module):
    """
    Strided 3x3 convolution followed by interpolation, instance norm and ReLU.

    The convolution runs with stride=2, so it shrinks the feature map before
    F.interpolate rescales it by `scale_factor` using `mode`.
    """
    def __init__(self, in_channels, out_channels, scale_factor, mode='bilinear'):
        super().__init__()
        self.scale_factor, self.mode = scale_factor, mode
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)
        self.norm = nn.InstanceNorm2d(out_channels)

    def forward(self, x):
        """Run conv -> interpolate -> instance norm -> ReLU on x."""
        resized = F.interpolate(
            self.conv(x),
            scale_factor=self.scale_factor,
            mode=self.mode,
            align_corners=False,
        )
        return F.relu(self.norm(resized))

In [5]:
# Helper functions for HRNet
def conv_down(in_c, out_c, stride=2):
    """Return a 3x3 nn.Conv2d with padding=1 that downsamples by `stride`."""
    downsample_conv = nn.Conv2d(
        in_channels=in_c,
        out_channels=out_c,
        kernel_size=3,
        stride=stride,
        padding=1,
    )
    return downsample_conv

def upsample(scale_factor):
    """
    Build a bilinear nn.Upsample module for the given scale factor.

    The return value is a module, not a tensor: call it on a feature map to
    get the resized map, e.g. `upsample(2)(x)`.
    """
    resize = nn.Upsample(mode='bilinear', scale_factor=scale_factor)
    return resize

In [6]:
class HRNet(nn.Module):
    """
    For model reference see Figure 2 of the paper https://arxiv.org/pdf/1904.11617v1.pdf.

    Naming convention used.
    I refer to vertical layers as a single layer, so from left to right we have 8 layers
    excluding the input image.
    E.g. layer 1 contains the 500x500x16 block
         layer 2 contains 500x500x32 and 250x250x32 blocks and so on

    self.layer{x}_{y}:
        x :- the layer number, as explained above
        y :- the index number for that function starting from 1. So if layer 3 has two
             downsample functions I write them as `downsample3_1`, `downsample3_2`

    Input:
        x :- (batch, 3, height, width). Height and width should be multiples of 4,
             otherwise the down-sampled branches do not line up with the full
             resolution branch again after up-sampling and the concatenation fails.
    Output:
        (batch, 3, height, width)
    """
    def __init__(self):
        super().__init__()
        self.layer1_1 = BottleneckBlock(3, 16)

        self.layer2_1 = BottleneckBlock(16, 32)
        self.downsample2_1 = conv_down(16, 32)

        self.layer3_1 = BottleneckBlock(32, 32)
        self.layer3_2 = BottleneckBlock(32, 32)
        self.downsample3_1 = conv_down(32, 32)
        self.downsample3_2 = conv_down(32, 32, stride=4)
        self.downsample3_3 = conv_down(32, 32)

        self.layer4_1 = BottleneckBlock(64, 64)
        self.layer5_1 = BottleneckBlock(192, 64)
        self.layer6_1 = BottleneckBlock(64, 32)
        self.layer7_1 = BottleneckBlock(32, 16)
        self.layer8_1 = conv_down(16, 3, stride=1)

        # `upsample(scale)` returns a resize module, it does not resize a tensor
        # itself. Build the two modules once here and apply them in forward().
        self.upsample_x2 = upsample(2)
        self.upsample_x4 = upsample(4)

    def forward(self, x):
        # Layer 1: full resolution, 16 channels
        map1_1 = self.layer1_1(x)

        # Layer 2: full resolution (32 ch) and 1/2 resolution (32 ch)
        map2_1 = self.layer2_1(map1_1)
        map2_2 = self.downsample2_1(map1_1)

        # Layer 3: every branch fuses both layer-2 maps -> 64 channels each,
        # at full, 1/2 and 1/4 resolution respectively
        map3_1 = torch.cat((self.layer3_1(map2_1), self.upsample_x2(map2_2)), 1)
        map3_2 = torch.cat((self.downsample3_1(map2_1), self.layer3_2(map2_2)), 1)
        map3_3 = torch.cat((self.downsample3_2(map2_1), self.downsample3_3(map2_2)), 1)

        # Layer 4: bring all branches back to full resolution -> 3 * 64 = 192 channels
        map4_1 = torch.cat(
            (self.layer4_1(map3_1), self.upsample_x2(map3_2), self.upsample_x4(map3_3)), 1
        )

        # Layers 5-8: reduce 192 -> 64 -> 32 -> 16 -> 3 channels
        out = self.layer5_1(map4_1)
        out = self.layer6_1(out)
        out = self.layer7_1(out)
        out = self.layer8_1(out)

        return out

## Create utility functions for image loading

In [7]:
def load_image(path, size=None):
    """
    Load an image file as a normalized tensor with a leading batch dimension.

    Input:
        path :- path of the image file
        size :- optional int; when given the image is resized to (size, size),
                ignoring the original aspect ratio
    Output:
        img :- (1, 3, height, width) float tensor normalized with the
               imagenet mean and std
    """
    # Image.open is lazy and keeps the file open; convert() loads the pixel
    # data, so the file can be closed as soon as the block is left.
    # Converting to RGB drops the alpha channel / expands greyscale images, so
    # the 3-channel Normalize below always receives exactly 3 channels.
    with Image.open(path) as raw_img:
        img = raw_img.convert('RGB')

    if size is not None:
        img = img.resize((size, size))

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    # Add the batch dimension expected by the models
    img = transform(img).unsqueeze(0)
    return img

In [8]:
def im_convert(img):
    """
    Turn a normalized pytorch image tensor into a numpy array that can be
    plotted: undo the imagenet normalization and clip the result to [0, 1].

    Input:
        img :- (1, channel, height, width)
    Output:
        img :- (height, width, channel)
    """
    imagenet_mean = np.array((0.485, 0.456, 0.406))
    imagenet_std = np.array((0.229, 0.224, 0.225))

    # Detached CPU copy with the batch dimension removed
    chw = img.to('cpu').clone().detach().numpy().squeeze(0)
    # Channels last, as expected by matplotlib
    hwc = np.transpose(chw, (1, 2, 0))
    return np.clip(hwc * imagenet_std + imagenet_mean, 0, 1)

In [9]:
def get_features(img, model, layers=None):
    """
    Run img through the child modules of model in order and collect the
    outputs of selected layers. Intended for the VGG19 feature extractor.

    Input:
        img    :- input tensor fed to the first child module
        model  :- module whose children are applied sequentially
        layers :- dict mapping child module name -> key used in the result;
                  defaults to the style / content layers listed below
    Output:
        features :- dict mapping the chosen keys to the layer outputs
    """
    if layers is None:
        layers = {
            '0': 'conv1_1',   # style layer
            '5': 'conv2_1',   # style layer
            '10': 'conv3_1',  # style layer
            '19': 'conv4_1',  # style layer
            '28': 'conv5_1',  # style layer

            '21': 'conv4_2'   # content layer
        }

    # Requested layers that exist in the model and are still to be captured
    pending = {name for name in layers if name in model._modules}

    features = {}
    x = img
    for name, layer in model._modules.items():
        # Everything requested has been captured: the remaining modules
        # cannot contribute to the result, so skip their computation.
        if not pending:
            break
        x = layer(x)
        if name in layers:
            features[layers[name]] = x
            pending.discard(name)

    return features

In [10]:
def get_gram_matrix(img):
    """
    Compute the gram matrix by flattening to a 2D tensor and taking the
    product with its own transpose.

    Input:
        img :- (batch, channel, height, width)
    Output:
        gram :- (batch*channel, batch*channel), not normalized.
                Batch and channel are flattened together, so for batch > 1
                the result also contains products between feature maps of
                different images.
    """
    b, c, h, w = img.size()
    # reshape (unlike view) also works for non-contiguous inputs, e.g. a
    # transposed or sliced feature map; for contiguous inputs it is still a view
    flat = img.reshape(b * c, h * w)
    return torch.mm(flat, flat.t())

## Write style_transfer.py

Refer to train_model.ipynb for continuation of this notebook

In [11]:
# For data, place your images in the `src/imgs` folder (see `Args.img_root`) and name them content.png and style.png
# You can also pass your own images directly (when using the .py script)

In [12]:
class Args:
    """Plain container for the settings read by the cells below."""
    def __init__(self):
        # Where the images live and what they are called
        self.img_root = 'src/imgs'
        self.content_img = 'content.png'
        self.style_img = 'style.png'

        # Batch settings; presumably `bs` is the batch size used when
        # `use_batch` is enabled -- TODO confirm against the training code
        self.use_batch = False
        self.bs = 16

        # Run on CUDA when True, on the CPU otherwise
        self.use_gpu = True


# Single shared settings instance used throughout the notebook
args = Args()

In [13]:
# Pick the device: CPU unless a GPU was requested, in which case CUDA must exist
if not args.use_gpu:
    device = torch.device('cpu')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    raise Exception('GPU is not available')

# Pretrained VGG19 convolutional part, used as a fixed feature extractor
vgg = vgg19(pretrained=True).features.to(device)
# VGG is never trained, so none of its parameters need gradients
for param in vgg.parameters():
    param.requires_grad_(False)

# The network that is actually trained
style_net = HRNet().to(device)

# Let cuDNN benchmark and pick its convolution algorithms
torch.backends.cudnn.benchmark = True

In [14]:
# `os` is imported in the import cell at the top of the notebook.
# The content image is resized to 500x500; the style image keeps its original size.
content_path = os.path.join(args.img_root, args.content_img)
content_img = load_image(content_path, size=500).to(device)

style_path = os.path.join(args.img_root, args.style_img)
style_img = load_image(style_path).to(device)

In [15]:
# Sanity check: (batch, channel, height, width) of both inputs
content_img.shape, style_img.shape

(torch.Size([1, 3, 500, 500]), torch.Size([1, 3, 800, 800]))