In [None]:
# Mount Google Drive at /content/gdrive/ so the following cells can use paths under it.
from google.colab import drive
drive.mount('/content/gdrive/')

In [2]:
import os

# Working directory for this notebook inside the mounted Google Drive.
PROJECT_DIR = "/content/gdrive/My Drive/Colab Notebooks/StyleTransfer"

# exist_ok=True makes the cell idempotent without a separate exists() check
# (and without the check-then-create race that pattern has).
os.makedirs(PROJECT_DIR, exist_ok=True)
os.chdir(PROJECT_DIR)

In [None]:
# List the contents of the current working directory.
!ls 

In [4]:

from __future__ import print_function
from PIL import Image
import os
import os.path
import numpy as np
import sys

import matplotlib
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
from tqdm.notebook import tqdm

import torch
from torch.optim import Adam
from torch.utils.data import DataLoader

import torchvision
from torchvision import datasets
import torchvision.transforms as transforms
from torchvision import models

from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F

# Use the first CUDA GPU when one is available; otherwise fall back to the CPU
# so that every later `.to(device)` call still works on a GPU-less runtime.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")



In [6]:
#Vgg16 Model
class Vgg(nn.Module):
    """Feature extractor made of the first 23 layers of torchvision's VGG16.

    The layers of ``vgg16(pretrained=True).features`` are split into four
    consecutive stages, ``f1`` .. ``f4`` (layer indices 0-3, 4-8, 9-15 and
    16-22). ``forward`` returns the output of every stage.

    Args:
        requires_grad: When falsy (the default) all parameters are frozen by
            setting ``requires_grad = False`` on each of them.
    """

    def __init__(self, requires_grad=False):
        super(Vgg, self).__init__()
        # Build the layer list once and slice it for each stage.
        layers = list(models.vgg16(pretrained=True).features.children())

        self.f1 = nn.Sequential(*layers[0:4])
        self.f2 = nn.Sequential(*layers[4:9])
        self.f3 = nn.Sequential(*layers[9:16])
        self.f4 = nn.Sequential(*layers[16:23])

        if not requires_grad:
            for p in self.parameters():
                p.requires_grad = False

    def forward(self, x):
        """Run ``x`` through the four stages in sequence.

        Returns:
            list: ``[relu1_2, relu2_2, relu3_3, relu4_3]`` - the outputs of
            ``f1``, ``f2``, ``f3`` and ``f4`` in that order.
        """
        relu1_2 = self.f1(x)
        relu2_2 = self.f2(relu1_2)
        relu3_3 = self.f3(relu2_2)
        relu4_3 = self.f4(relu3_3)
        return [relu1_2, relu2_2, relu3_3, relu4_3]
# Instantiate once so the cell displays the module layout.
Vgg()

Vgg(
  (f1): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
  )
  (f2): Sequential(
    (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (2): ReLU(inplace=True)
    (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU(inplace=True)
  )
  (f3): Sequential(
    (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (2): ReLU(inplace=True)
    (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU(inplace=True)
    (5): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
  )
  (f4): Sequential(
    (0

In [7]:
#Transformation Network

class ResidualBlock(nn.Module):
    """Residual unit: two 3x3 conv + instance-norm stages plus a skip connection.

    The input is added back to the output of the second stage, so the block
    maps a ``c``-channel tensor to a tensor of the same shape.

    Args:
        c: Number of channels going in and coming out of both convolutions.
    """

    def __init__(self, c):
        super(ResidualBlock, self).__init__()
        kernel = 3
        pad = kernel // 2  # with stride 1 this keeps height and width unchanged

        def conv_norm():
            # reflection pad -> 3x3 conv -> affine instance norm
            return nn.Sequential(
                nn.ReflectionPad2d(pad),
                nn.Conv2d(c, c, kernel_size=kernel, stride=1),
                nn.InstanceNorm2d(c, affine=True),
            )

        self.conv1 = conv_norm()
        self.conv2 = conv_norm()
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``conv2(relu(conv1(x))) + x``."""
        out = self.conv2(self.relu(self.conv1(x)))
        return out + x




class Transformer(nn.Module):
    """Image transformation network: conv encoder, residual trunk, conv decoder.

    Layout (channels):
        conv1  9x9, stride 1      3 -> 32
        conv2  3x3, stride 2     32 -> 64    (downsampling)
        conv3  3x3, stride 2     64 -> 128   (downsampling)
        res    5 x ResidualBlock(128)
        conv4  x2 nearest upsample + 3x3 conv  128 -> 64
        conv5  x2 nearest upsample + 3x3 conv   64 -> 32
        conv6  9x9, stride 1     32 -> 3     (no norm, no activation)

    Every convolution is preceded by reflection padding of ``kernel // 2``.
    """

    def __init__(self):
        super(Transformer, self).__init__()

        def conv_block(c_in, c_out, kernel, stride, upsample=False):
            # [upsample] -> reflection pad -> conv -> affine instance norm -> ReLU
            layers = []
            if upsample:
                layers.append(nn.Upsample(mode='nearest', scale_factor=2.0))
            layers += [
                nn.ReflectionPad2d(kernel // 2),
                nn.Conv2d(c_in, c_out, kernel_size=kernel, stride=stride),
                nn.InstanceNorm2d(c_out, affine=True),
                nn.ReLU(),
            ]
            return nn.Sequential(*layers)

        # Initial wide-kernel convolution.
        self.conv1 = conv_block(3, 32, kernel=9, stride=1)

        # Two stride-2 convolutions that downsample the feature map.
        self.conv2 = conv_block(32, 64, kernel=3, stride=2)
        self.conv3 = conv_block(64, 128, kernel=3, stride=2)

        # Five residual blocks at 128 channels.
        self.res = nn.Sequential(*[ResidualBlock(128) for _ in range(5)])

        # Two upsample-then-convolve stages that undo the downsampling.
        self.conv4 = conv_block(128, 64, kernel=3, stride=1, upsample=True)
        self.conv5 = conv_block(64, 32, kernel=3, stride=1, upsample=True)

        # Output projection back to 3 channels.
        self.conv6 = nn.Sequential(
            nn.ReflectionPad2d(9 // 2),
            nn.Conv2d(32, 3, kernel_size=9, stride=1),
        )

    def forward(self, x):
        """Apply conv1, conv2, conv3, res, conv4, conv5 and conv6 in that order."""
        stages = (
            self.conv1,
            self.conv2,
            self.conv3,
            self.res,
            self.conv4,
            self.conv5,
            self.conv6,
        )
        for stage in stages:
            x = stage(x)
        return x


# Instantiate once so the cell displays the module layout.
Transformer()

Transformer(
  (conv1): Sequential(
    (0): ReflectionPad2d((4, 4, 4, 4))
    (1): Conv2d(3, 32, kernel_size=(9, 9), stride=(1, 1))
    (2): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
    (3): ReLU()
  )
  (conv2): Sequential(
    (0): ReflectionPad2d((1, 1, 1, 1))
    (1): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2))
    (2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
    (3): ReLU()
  )
  (conv3): Sequential(
    (0): ReflectionPad2d((1, 1, 1, 1))
    (1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2))
    (2): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
    (3): ReLU()
  )
  (res): Sequential(
    (0): ResidualBlock(
      (conv1): Sequential(
        (0): ReflectionPad2d((1, 1, 1, 1))
        (1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
        (2): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      

In [8]:
def normalize(img):
    """Scale a 0-255 image tensor to 0-1 and standardise it per channel.

    The input is divided by 255, then the fixed per-channel mean is subtracted
    and the result divided by the fixed per-channel std (the constants below
    are the ones commonly paired with torchvision's pretrained models).

    The input tensor is left untouched: the division is done out of place, so
    callers can keep using the tensor they passed in. Integer inputs are
    promoted to floating point by the division.

    Args:
        img: Tensor whose third-from-last dimension has size 3
            (``(3, H, W)`` or ``(N, 3, H, W)``), with values in 0-255.

    Returns:
        A new tensor of the same shape holding the normalised values.
    """
    # Out-of-place divide: `div_` would overwrite the caller's data and fails
    # on integer tensors.
    scaled = img / 255.0

    # Build the constants from the *scaled* tensor so they share its floating
    # dtype and device; shape (3, 1, 1) broadcasts over height and width.
    im_mean = scaled.new_tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
    im_std = scaled.new_tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)

    return (scaled - im_mean) / im_std


In [9]:
def calculate_gram_matrix(input):
    """Compute the normalised Gram matrix of a batch of feature maps.

    Each ``(c, h, w)`` feature map is flattened to ``(c, h*w)`` and multiplied
    by its own transpose; the result is divided by ``c * h * w``.

    Args:
        input: 4-D tensor of shape ``(a, c, h, w)``. It does not need to be
            contiguous in memory.

    Returns:
        Tensor of shape ``(a, c, c)``.
    """
    a, c, h, w = input.size()

    # reshape (not view) so non-contiguous feature maps are accepted as well.
    feature = input.reshape(a, c, h * w)
    gram = feature.bmm(feature.transpose(1, 2))

    return gram / (c * h * w)

In [None]:
#Hyperparameters

# Input and output locations, relative to the current working directory.
dataset_path = "./train2014"
style_image_path = "./input/Mona_Lisa.jpg"
save_path = "./output"

# Training schedule.
batch_size = 4
epochs = 2

# Multipliers applied to the feature loss and the style loss before they are summed.
feature_w = 1e5
style_w = 1e10

# Fix the NumPy and PyTorch seeds for repeatable runs.
np.random.seed(50)
torch.manual_seed(50)


In [12]:
# Model training
def train():
    """Train a ``Transformer`` so its output matches the style of one image.

    Uses the module-level settings ``dataset_path``, ``style_image_path``,
    ``batch_size``, ``epochs``, ``feature_w``, ``style_w`` and ``device``.

    For every batch the loss is
    ``feature_w * MSE(relu3_3(output), relu3_3(input))`` plus
    ``style_w * sum over the four VGG stages of MSE(gram(output), gram(style))``.

    Returns:
        The trained ``Transformer`` (left on ``device``).
    """
    # Content images: resize + centre-crop to 3x256x256, values scaled to 0-255.
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(256),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])

    # Style image: kept at its own resolution, values scaled to 0-255.
    style_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])

    train_dataset = datasets.ImageFolder(dataset_path, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=batch_size)

    model = Transformer().to(device)
    vgg = Vgg(requires_grad=False).to(device)

    # convert("RGB") guarantees three channels even for grayscale, palette or
    # RGBA style files; normalize() and VGG both expect exactly three.
    style = Image.open(style_image_path).convert("RGB")
    style = style_transform(style).unsqueeze(0).to(device)

    # The style Gram matrices are constants: compute them once, for a single
    # copy of the image and without autograd bookkeeping. They have shape
    # (1, c, c) and are expanded to the batch size inside the loop.
    with torch.no_grad():
        style_gram = [calculate_gram_matrix(f) for f in vgg(normalize(style))]

    loss = nn.MSELoss()
    optimizer = Adam(model.parameters(), lr=1e-3)

    for e in range(epochs):
        print("start epoch ", e)
        model.train()

        # ImageFolder yields (images, labels); the labels are not used.
        for img, _ in tqdm(train_loader):
            optimizer.zero_grad()

            img = img.to(device)
            trans = model(img)

            # Normalize both images before feeding them to VGG.
            trans = normalize(trans)
            img = normalize(img)

            trans_feature = vgg(trans)
            img_feature = vgg(img)

            # Feature (content) loss on relu3_3.
            feature_loss = loss(trans_feature[2], img_feature[2])

            # Style loss: sum over relu1_2, relu2_2, relu3_3 and relu4_3.
            style_loss = 0.0
            for t_feat, s_gram in zip(trans_feature, style_gram):
                trans_gram = calculate_gram_matrix(t_feat)
                # expand_as handles a short final batch as well.
                style_loss += loss(trans_gram, s_gram.expand_as(trans_gram))

            total_loss = feature_w * feature_loss + style_w * style_loss

            total_loss.backward()
            optimizer.step()

    print("Training done")

    return model



In [None]:
# Start training. train() returns the trained network, which is kept in
# `model` for use by later cells.
model = train()

In [None]:
# Image style transformation using the trained model

content_path = "./input/bobo.jpg"
save_result_path = "./output.jpg"

# Convert the image to a tensor and scale it to 0-255; no resize or crop.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.mul(255))
])

# convert("RGB") guarantees three channels even for grayscale, palette or RGBA
# files; the network's first convolution takes exactly three.
content = Image.open(content_path).convert("RGB")
content = transform(content)
content = content.unsqueeze(0)  # add the batch dimension
content = content.to(device)

model.to(device)
model.eval()  # switch to inference mode before generating the output

with torch.no_grad():
    output = model(content)
    output = output.cpu()

# Clamp to the valid pixel range, move channels last and convert to 8-bit.
img = (output[0].clamp(0, 255).numpy().transpose(1, 2, 0)).astype("uint8")

plt.figure()
plt.imshow(img)

# Save the stylised image.
Image.fromarray(img).save(save_result_path)