## Implementing Style Transfer using AdaIN

In this notebook we are going to implement Adaptive Instance Normalization (AdaIN) style transfer, based on the paper "Arbitrary Style Transfer in Real-time with Adaptive Instance Normalization" by Huang et al. Everything is implemented in PyTorch. Let's dive in.

We first import all the necessary packages.

In [None]:
## Importing necessary packages ##

import torch
import torchvision
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import make_grid
from torchvision.models import vgg19
from torchvision.transforms import transforms

from tqdm import tqdm
import PIL
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

First up: loading the data.

Since this is a style transfer model, every training sample consists of a content image and a style image. Hence, we load two image folders and pair them up in a single custom dataset. Let's do that in the next step.

In [None]:
## Making our custom dataset ##

class CustomDataset(Dataset):
    
    def __init__(self , content_dir , style_dir):
        
        super().__init__()
        
        self.content_dir = content_dir
        
        self.content_imgs = os.listdir(content_dir)
        
        self.style_dir = style_dir
        
        self.style_imgs = os.listdir(style_dir)
        
        # Build the transform once instead of on every __getitem__ call
        self.aug = transforms.Compose([
            transforms.Resize((256 , 256)),
            transforms.ToTensor()          # also scales pixel values to [0, 1]
        ])
        
    def __getitem__(self , idx):
        
        # Wrap the index so that the (usually smaller) style set is cycled
        content_idx = idx % len(self.content_imgs)
        
        style_idx = idx % len(self.style_imgs)
        
        # convert('RGB') guards against grayscale or RGBA files, which would otherwise
        # produce 1- or 4-channel tensors and break batching
        content_img = PIL.Image.open(os.path.join(self.content_dir , self.content_imgs[content_idx])).convert('RGB')
        
        style_img = PIL.Image.open(os.path.join(self.style_dir , self.style_imgs[style_idx])).convert('RGB')
        
        return self.aug(content_img) , self.aug(style_img)
    
    def __len__(self):
        
        # One epoch is one pass over the content images
        return len(self.content_imgs)

## Defining our dataset instance ##

dataset = CustomDataset('../input/style-content-data/mscoco_resized/train2014','../input/style-content-data/Abstract_gallery')
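`__getitem__` pairs content and style images by wrapping the index with a modulo, and `__len__` reports the number of content images, so an epoch visits every content image once while the style images are cycled. A minimal pure-Python sketch of that index arithmetic, using a hypothetical `paired_indices` helper and made-up dataset sizes:

```python
def paired_indices(idx, n_content, n_style):
    """Mirror CustomDataset.__getitem__: wrap idx into both file lists."""
    return idx % n_content, idx % n_style

# Hypothetical sizes: 6 content images and 4 style images.
# len(dataset) equals n_content, so idx runs over range(6).
pairs = [paired_indices(i, 6, 4) for i in range(6)]
print(pairs)  # [(0, 0), (1, 1), (2, 2), (3, 3), (4, 0), (5, 1)]
```

Note that the pairing is fixed: content image `i` is always matched with style image `i % n_style`. Shuffling in the dataloader only changes the order in which these fixed pairs are visited.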

As always, it's a good idea to visualize a data point first. We are going to do that next.

In [None]:
random_idx = np.random.randint(low = 0 , high = len(dataset))

content , style = dataset[random_idx]

print('Content Image --> Maximum value : {} , Minimum value : {}'.format(torch.max(content) , torch.min(content)))

print('Style Image --> Maximum value : {} , Minimum value : {}'.format(torch.max(style) , torch.min(style)))

plt.imshow(content.permute(1 , 2 , 0))

plt.show()

plt.imshow(style.permute(1 , 2 , 0))

plt.show()

Both images load correctly, and `ToTensor` has scaled the pixel values into the [0, 1] range.

Let's now make a dataloader that feeds the data to the model in batches of 32 samples. We also write a utility function that shows a batch as an image grid.

In [None]:
## Making our dataloader ##

st_dataloader = DataLoader(dataset , batch_size = 32 , shuffle = True)


## Visualization utility function ##

def visualize(imgs):
    
    fig , ax = plt.subplots(figsize = (8 , 4))
    
    ax.set_xticklabels([])
    
    ax.set_yticklabels([])
    
    # Decoder outputs are not bounded, so clip to [0, 1] before plotting
    plt.imshow(make_grid(imgs.detach().to('cpu').clamp(0 , 1) , 4).permute(1 , 2 , 0))
    
    plt.show()
    

## Visualizing a batch of data ##

for content , style in st_dataloader:
    
    visualize(content)
    
    visualize(style)
    
    break
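`make_grid(imgs, 4)` lays the batch out with 4 images per row and, by default, 2 pixels of padding between and around the images. A small sketch of the resulting canvas size, assuming the default `padding=2` and the 3 × 256 × 256 tensors produced above (`grid_shape` is a hypothetical helper that mirrors the size computation in torchvision's `make_grid`):

```python
import math

def grid_shape(n_images, height, width, nrow=4, padding=2):
    """Canvas size (H, W) of make_grid for n equally sized images."""
    cols = min(nrow, n_images)
    rows = math.ceil(n_images / cols)
    return (rows * (height + padding) + padding,
            cols * (width + padding) + padding)

print(grid_shape(32, 256, 256))  # a full batch of 32 -> (2066, 1034)
```

With 8 rows of 4 images, a full batch gives a grid roughly twice as tall as it is wide.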

With that, the dataset and dataloader are set up.

However, we want to train on the GPU, while the dataloader produces its batches on the CPU. Moving each batch to the GPU takes a little more work. Let's do that now!

In [None]:
## Setting the device to cpu or gpu ##

def set_device():
    
    if torch.cuda.is_available():
        
        return torch.device('cuda')
    
    return torch.device('cpu')

device = set_device()


## Utility function to transfer data to the specified device ##

def transfer_data(data , device):
    
    if isinstance(data , (tuple , list)):
        
        return [transfer_data(each_data , device) for each_data in data]
    
    return data.to(device)


## GPU dataloader class ##

class GPUDataloader:
    
    def __init__(self , dl , device):
        
        self.dl = dl
        
        self.device = device
        
    def __iter__(self):
        
        for batch in self.dl:
            
            yield transfer_data(batch , self.device)
            
    def __len__(self):
        
        return len(self.dl)

With all that out of the way, let's wrap the dataloader and, as always, visualize a batch.

In [None]:
## Creating dataloader instance ##

styletransfer_dl = GPUDataloader(st_dataloader , device)


## Visualizing a batch of dataloader ##

for content , style in styletransfer_dl: 

    visualize(content)
    
    visualize(style)
    
    break

With the groundwork done, let's get our hands dirty and build the network.

First off, the encoder: a pretrained VGG19, truncated at `relu4_1` and kept frozen.

In [None]:
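## Setting our encoder part ##

from torchvision.models import VGG19_Weights   # weights API, torchvision >= 0.13

class Encoder(nn.Module):
    
    def __init__(self):
        
        super().__init__()
        
        # Pretrained VGG19 truncated after relu4_1 (features[0:21]); keeping only
        # `features` also keeps the unused classifier out of model.parameters()
        self.features = vgg19(weights = VGG19_Weights.IMAGENET1K_V1 , progress = True).features[:21]
        
        # The encoder is fixed: only the decoder is trained
        for param in self.features.parameters():
            
            param.requires_grad = False
        
    def forward(self , x):
        
        out = self.features(x)
        
        return out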
    
## Testing our encoder ##

test_img = torch.randn(1 , 3 , 256 , 256)

test_encoder = Encoder()

test_out = test_encoder(test_img)

print('The output shape is :' , test_out.shape)

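The encoder maps a 3 × 256 × 256 image to a 512 × 32 × 32 feature map: 512 channels at `relu4_1`, with three 2 × 2 max-pools shrinking each side by a factor of 8.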
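Why `features[:21]`? In torchvision's VGG19 (without batch norm), `features` is a flat sequence in which every convolution is followed by a ReLU and each block ends with a max-pool. The sketch below rebuilds the layer names from the VGG19 configuration list (the list mirrors torchvision's configuration "E"; the `conv{block}_{i}` naming follows the paper's convention, and `vgg19_feature_names` is a hypothetical helper) and confirms that index 20 is `relu4_1`.

```python
# VGG19 configuration: numbers are conv output channels, "M" is a 2x2 max-pool.
VGG19_CFG = [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M",
             512, 512, 512, 512, "M", 512, 512, 512, 512, "M"]

def vgg19_feature_names(cfg=VGG19_CFG):
    """Name each entry of vgg19().features in order (conv, relu, pool)."""
    names, block, idx = [], 1, 1
    for v in cfg:
        if v == "M":
            names.append(f"pool{block}")
            block, idx = block + 1, 1
        else:
            names.append(f"conv{block}_{idx}")
            names.append(f"relu{block}_{idx}")
            idx += 1
    return names

names = vgg19_feature_names()
print(names[20])                                      # relu4_1, last layer kept
print(sum(n.startswith("pool") for n in names[:21]))  # 3 pools -> 256 / 8 = 32
```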

Next up, let's build the Adaptive Instance Normalization module.
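For reference, AdaIN aligns the per-channel mean and standard deviation of the content features $x$ with those of the style features $y$, where $\mu$ and $\sigma$ are computed over the spatial dimensions of each channel of each sample:

$$
\mathrm{AdaIN}(x, y) = \sigma(y)\left(\frac{x - \mu(x)}{\sigma(x)}\right) + \mu(y)
$$

$$
\mu_{nc}(x) = \frac{1}{HW}\sum_{h=1}^{H}\sum_{w=1}^{W} x_{nchw}, \qquad
\sigma_{nc}(x) = \sqrt{\frac{1}{HW}\sum_{h=1}^{H}\sum_{w=1}^{W}\left(x_{nchw} - \mu_{nc}(x)\right)^2 + \epsilon}
$$

The module below follows this formula with two small differences: `torch.std` uses the unbiased estimator (dividing by $HW - 1$), and the constant `1e-5` is added to the content standard deviation rather than inside the square root.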

In [None]:
## Creating AdaIN ##

class AdaIN(nn.Module):
    
    def __init__(self , eps = 1e-5):
        
        super().__init__()
        
        self.eps = eps   # avoids division by zero for constant channels
        
    def forward(self , content , style):
        
        # Per-sample, per-channel statistics over the spatial dims -> shape (N, C, 1, 1)
        content_mean = torch.mean(content , dim = [2 , 3] , keepdim = True)
        
        content_std = torch.std(content , dim = [2 , 3] , keepdim = True)
        
        style_mean = torch.mean(style , dim = [2 , 3] , keepdim = True)
        
        style_std = torch.std(style , dim = [2 , 3] , keepdim = True)
        
        # Normalize the content features, then re-scale and shift them with the style statistics
        out = style_std * ((content - content_mean) / (content_std + self.eps)) + style_mean
        
        # The style statistics are returned expanded to the content shape, as the loss code expects
        return out , style_mean.expand_as(content) , style_std.expand_as(content)
    
    
## Testing our AdaIN ##

test_content = torch.randn(1 , 512 , 32 , 32)

test_style = torch.randn(1 , 512 , 32 , 32)

test_adain = AdaIN()

test_out , _ , _ = test_adain(test_content , test_style)

print('Output Shape :' , test_out.shape)
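A quick way to sanity-check AdaIN is to verify the property it is built on: after the transform, every channel of the output carries the style's mean and standard deviation. The sketch below reimplements the same computation in NumPy (a hypothetical `adain_np` helper, not part of the notebook) and checks that property on random features.

```python
import numpy as np

def adain_np(content, style, eps=1e-5):
    """NumPy version of the AdaIN forward pass for (N, C, H, W) arrays."""
    # Per-sample, per-channel statistics over the spatial axes (ddof=1 matches torch.std).
    c_mean = content.mean(axis=(2, 3), keepdims=True)
    c_std = content.std(axis=(2, 3), ddof=1, keepdims=True)
    s_mean = style.mean(axis=(2, 3), keepdims=True)
    s_std = style.std(axis=(2, 3), ddof=1, keepdims=True)
    return s_std * (content - c_mean) / (c_std + eps) + s_mean

rng = np.random.default_rng(0)
content = rng.normal(size=(2, 4, 8, 8))
style = rng.normal(loc=3.0, scale=2.0, size=(2, 4, 8, 8))
out = adain_np(content, style)

# Largest deviation from the style statistics (both should be close to zero).
print(np.abs(out.mean(axis=(2, 3)) - style.mean(axis=(2, 3))).max())
print(np.abs(out.std(axis=(2, 3), ddof=1) - style.std(axis=(2, 3), ddof=1)).max())
```

Because the statistics are reduced to one value per channel before broadcasting, the style features do not need the same spatial size as the content features.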

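Now we build the decoder. It mirrors the VGG19 encoder in reverse, with each max-pooling layer replaced by a 2× nearest-neighbour upsampling layer (`nn.Upsample` defaults to `mode = 'nearest'`).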
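The decoder undoes the encoder's three poolings, so a 512 × 32 × 32 AdaIN output comes back as a 3 × 256 × 256 image. The following pure-Python sketch traces channel and spatial sizes through the decoder's layer list: 3 × 3 convolutions with stride 1 and padding 1 keep the spatial size, and each `nn.Upsample(scale_factor = 2)` doubles it. The `("conv", out_channels)` / `("up", None)` encoding and the `trace` helper are hypothetical, for illustration only.

```python
def conv_out(size, kernel=3, stride=1, padding=1):
    """Spatial output size of a Conv2d layer."""
    return (size + 2 * padding - kernel) // stride + 1

# Decoder layers as (kind, out_channels); ReLUs are omitted since they keep the shape.
DECODER = [("conv", 256), ("up", None), ("conv", 256), ("conv", 256), ("conv", 256),
           ("conv", 128), ("up", None), ("conv", 128),
           ("conv", 64), ("up", None), ("conv", 64),
           ("conv", 3)]

def trace(channels=512, size=32):
    """Return the (C, H, W) shape after every layer, starting from the input."""
    shapes = [(channels, size, size)]
    for kind, out_ch in DECODER:
        if kind == "conv":
            channels, size = out_ch, conv_out(size)
        else:  # nearest-neighbour upsampling by a factor of 2
            size *= 2
        shapes.append((channels, size, size))
    return shapes

for shape in trace():
    print(shape)
```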

In [None]:
## Creating the VGG19 decoder module ##

class Decoder(nn.Module):
    
    def __init__(self , in_dim = 512 , out_dim = 3 , kernel = 3, k_stride = 1 , 
                 padding = 1 , scale = 2):
        
        super().__init__()
        
        self.net = nn.Sequential(nn.Conv2d(in_channels = in_dim , out_channels = in_dim // 2 , kernel_size = kernel,
                                           stride = k_stride , padding = padding),
                                 nn.Upsample(scale_factor = scale),
                                 nn.ReLU(),
                                 nn.Conv2d(in_channels = in_dim // 2 , out_channels = in_dim // 2 , kernel_size = kernel,
                                           stride = k_stride , padding = padding),
                                 nn.ReLU(),
                                 nn.Conv2d(in_channels = in_dim // 2 , out_channels = in_dim // 2 , kernel_size = kernel,
                                           stride = k_stride , padding = padding),
                                 nn.ReLU(),
                                 nn.Conv2d(in_channels = in_dim // 2 , out_channels = in_dim // 2 , kernel_size = kernel,
                                           stride = k_stride , padding = padding),
                                 nn.ReLU(),
                                 nn.Conv2d(in_channels = in_dim // 2 , out_channels = in_dim // 4 , kernel_size = kernel,
                                           stride = k_stride , padding = padding),
                                 nn.Upsample(scale_factor = scale),
                                 nn.ReLU(),
                                 nn.Conv2d(in_channels = in_dim // 4 , out_channels = in_dim // 4 , kernel_size = kernel,
                                           stride = k_stride , padding = padding),
                                 nn.ReLU(),
                                 nn.Conv2d(in_channels = in_dim // 4 , out_channels = in_dim // 8 , kernel_size = kernel,
                                           stride = k_stride , padding = padding),
                                 nn.Upsample(scale_factor = scale),
                                 nn.ReLU(),
                                 nn.Conv2d(in_channels = in_dim // 8 , out_channels = in_dim // 8 , kernel_size = kernel,
                                           stride = k_stride , padding = padding),
                                 nn.ReLU(),
                                 nn.Conv2d(in_channels = in_dim // 8 , out_channels = out_dim , kernel_size = kernel,
                                           stride = k_stride , padding = padding)
                                )
        
    def forward(self , x):
        
        out = self.net(x)
        
        return out
    

## Testing our decoder ##

test_inp = torch.randn(1 , 512 , 32 , 32)

test_decoder = Decoder()

test_out = test_decoder(test_inp)

print('Shape of output :' , test_out.shape)

All right, our modules are done.

Now let's put them together into the final model.

In [None]:
## Joining the pieces and making our final model ##

class StyleTransfer(nn.Module):
    
    def __init__(self):
        
        super().__init__()
        
        self.encoder = Encoder()
        
        self.decoder = Decoder()
        
        self.adain = AdaIN()
        
    def forward(self , content , style):
        
        # Encode both images with the frozen VGG19 (up to relu4_1)
        content_feature = self.encoder(content)
        
        style_feature = self.encoder(style)
        
        # AdaIN target t, plus the style statistics used by the style loss
        stylized_feature , style_mean , style_std = self.adain(content_feature , style_feature)
        
        # Decode t into an image, then re-encode it for the losses
        stylized_img = self.decoder(stylized_feature)
        
        content_final_map = self.encoder(stylized_img)
        
        # Per-channel statistics of the re-encoded output, expanded like the style statistics
        final_mean = torch.mean(content_final_map , dim = [2 , 3] , keepdim = True).expand_as(style_mean)
        
        final_std = torch.std(content_final_map , dim = [2 , 3] , keepdim = True).expand_as(style_std)
        
        return stylized_img , content_final_map , stylized_feature , style_mean , style_std , final_mean , final_std
    
## Testing our model ##

test_content = torch.randn(1 , 3 , 256 , 256)

test_style = torch.randn(1 , 3 , 256 , 256)

test_model = StyleTransfer()

(stylized_img , content_final_map , stylized_feature ,
 style_mean , style_std , final_mean , final_std) = test_model(test_content , test_style)

print('Stylized Image shape :' , stylized_img.shape)

print('Content final map shape :' , content_final_map.shape)

print('Stylized map shape :' , stylized_feature.shape)

print('Style mean shape :' , style_mean.shape)

print('Style std shape :' , style_std.shape)

print('Final mean shape :' , final_mean.shape)

print('Final std shape :' , final_std.shape)

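All shapes are as expected: the stylized image has the same 3 × 256 × 256 shape as the inputs, and every feature map and statistic is 512 × 32 × 32.

Finally, let's create the model and move it to the device.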

In [None]:
## Creating our model object ##

model = StyleTransfer()

model = transfer_data(model , device)

Now let's set up the loss function. Both the content loss and the style loss use the mean squared error, so a single `nn.MSELoss` instance is enough.
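Written out, the objective used here compares the encoder features $f(\cdot)$ of the decoded image $g(t)$ against the AdaIN target $t = \mathrm{AdaIN}(f(c), f(s))$ for content, and against the per-channel statistics of the style features $f(s)$ for style, with $\lambda$ weighting the style term (0.1 in the training loop below):

$$
\mathcal{L} = \mathcal{L}_c + \lambda\,\mathcal{L}_s, \qquad
\mathcal{L}_c = \mathrm{MSE}\big(f(g(t)),\; t\big)
$$

$$
\mathcal{L}_s = \mathrm{MSE}\big(\mu(f(s)),\; \mu(f(g(t)))\big) + \mathrm{MSE}\big(\sigma(f(s)),\; \sigma(f(g(t)))\big)
$$

This is a simplification of the paper's loss: the paper uses Euclidean distances and sums the style term over the `relu1_1`, `relu2_1`, `relu3_1` and `relu4_1` layers, whereas this notebook matches the `relu4_1` statistics only.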

In [None]:
## Defining loss function ##

mse_loss = nn.MSELoss()

## Checking loss ##

test_lc = mse_loss(content_final_map , stylized_feature)

test_ls = mse_loss(style_mean , final_mean) + mse_loss(style_std , final_std)

total_l = test_lc + 0.1 * test_ls   # same style weight as in the training loop

print('Content Loss : {}\nStyle Loss : {}\nTotal Loss : {}'.format(test_lc.item() , test_ls.item() , total_l.item()))

So, both loss terms compute without errors.
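One detail worth checking: `style_mean`, `style_std`, `final_mean` and `final_std` are expanded to the full 512 × 32 × 32 feature shape before `nn.MSELoss` sees them. Every spatial position holds a copy of the same per-channel value, so the mean over the expanded tensor equals the mean over the per-channel statistics: the expansion costs memory but does not change the loss value. A small NumPy sketch with random statistics (the arrays `a` and `b` are stand-ins for the real statistics):

```python
import numpy as np

rng = np.random.default_rng(1)
a = rng.normal(size=(2, 512))   # stand-in for per-channel style means, shape (N, C)
b = rng.normal(size=(2, 512))   # stand-in for per-channel means of the stylized features

def mse(x, y):
    """Mean squared error, as computed by nn.MSELoss with the default 'mean' reduction."""
    return np.mean((x - y) ** 2)

# Broadcast each statistic over a 32 x 32 spatial grid, as expand_as does.
a_full = np.broadcast_to(a[:, :, None, None], (2, 512, 32, 32))
b_full = np.broadcast_to(b[:, :, None, None], (2, 512, 32, 32))

print(mse(a, b), mse(a_full, b_full))  # equal up to floating-point rounding
```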

Now let's define our optimizer.

In [None]:
## Setting our optimizer ##

optim = torch.optim.Adam(model.decoder.parameters() , lr = 1e-3)   # the encoder is frozen, so only the decoder is optimized

Done.

All that's left is to train and watch.

In [None]:
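## Training our model ##

num_epochs = 10

for epoch in range(num_epochs):
    
    # Create the progress bar inside the epoch loop: a tqdm object is closed after one
    # full pass, so reusing a single one would hide the bar from epoch 2 onwards
    loop = tqdm(styletransfer_dl , desc = 'Epoch {}'.format(epoch + 1))
    
    for content , style in loop:
        
        (stylized_img , content_final_map , stylized_feature ,
         style_mean , style_std , final_mean , final_std) = model(content , style)
        
        # Content loss: encoder features of the output vs. the AdaIN target
        lc = mse_loss(content_final_map , stylized_feature)

        # Style loss: match the per-channel mean and std of the style features
        ls = mse_loss(style_mean , final_mean) + mse_loss(style_std , final_std)

        total_loss = lc + 0.1 * ls
        
        optim.zero_grad()
        
        total_loss.backward()
        
        optim.step()
        
        loop.set_postfix(loss = total_loss.item())
        
    # Loss of the last batch in the epoch
    print('Epoch : {} / {} --> Loss is {:.2f}'.format(epoch + 1 , num_epochs , total_loss.item()))
    
    visualize(content)
    
    visualize(style)
    
    visualize(stylized_img)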