In [25]:
from __future__ import print_function

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from PIL import Image
import matplotlib.pyplot as plt

import torchvision.transforms as transforms
import torchvision.models as models

import copy

from torch.autograd import Variable
import numpy as np

import os
import time

from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms


In [26]:
# Report whether CUDA is usable, then select the device used by the rest of the notebook.
print(torch.cuda.is_available())
if torch.cuda.is_available():
    device = torch.device("cuda")
    # Release unused memory held by PyTorch's CUDA caching allocator.
    torch.cuda.empty_cache()
else:
    device = torch.device("cpu")

True


In [27]:
"""Image Transform Network Blocks as defined in Perceptual Losses for Real-Time Style Transfer and Super-Resolution"""
"""This part of the code is adapted from https://github.com/dxyang/StyleTransfer"""
# Conv Layer
class ConvLayer(nn.Module):
    """Convolution preceded by reflection padding.

    The input is reflection-padded by ``kernel_size // 2`` on every side and
    then passed through an ``nn.Conv2d`` that applies no padding of its own.

    Args:
        in_channels: number of channels of the input.
        out_channels: number of channels produced by the convolution.
        kernel_size: size of the convolution kernel (also determines the padding).
        stride: stride of the convolution.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super(ConvLayer, self).__init__()
        # Padding is handled by the reflection layer, so the convolution gets none.
        self.reflection_pad = nn.ReflectionPad2d(kernel_size // 2)
        self.conv2d = nn.Conv2d(in_channels, out_channels,
                                kernel_size=kernel_size, stride=stride)

    def forward(self, x):
        """Reflection-pad ``x`` and return its convolution."""
        return self.conv2d(self.reflection_pad(x))

class UpsampleConvLayer(nn.Module):
    """Optional nearest-neighbour upsampling followed by a reflection-padded convolution.

    Args:
        in_channels: number of channels of the input.
        out_channels: number of channels produced by the convolution.
        kernel_size: size of the convolution kernel; the reflection padding is
            ``kernel_size // 2``.
        stride: stride of the convolution.
        upsample: scale factor passed to ``nn.Upsample``. When falsy, no
            upsampling is performed.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, upsample=None):
        super(UpsampleConvLayer, self).__init__()
        if upsample:
            self.upsample = nn.Upsample(scale_factor=upsample, mode='nearest')
        else:
            # Keep the falsy value so forward() can skip the upsampling step.
            self.upsample = upsample
        self.reflection_pad = nn.ReflectionPad2d(kernel_size // 2)
        self.conv2d = nn.Conv2d(in_channels, out_channels,
                                kernel_size=kernel_size, stride=stride)

    def forward(self, x):
        """Upsample ``x`` if configured, then pad and convolve it."""
        resized = self.upsample(x) if self.upsample else x
        return self.conv2d(self.reflection_pad(resized))
    
class ResidualBlock(nn.Module):
    """Residual block built from two 3x3 ``ConvLayer`` + instance-norm stages.

    A ReLU follows the first stage; the block input is added to the output of
    the second stage and a final ReLU is applied to the sum. The channel count
    is the same on input and output.

    Args:
        channels: number of input/output channels.
    """

    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        # First stage
        self.conv1 = ConvLayer(channels, channels, kernel_size=3, stride=1)
        self.in1 = nn.InstanceNorm2d(channels, affine=True)
        self.relu = nn.ReLU()
        # Second stage
        self.conv2 = ConvLayer(channels, channels, kernel_size=3, stride=1)
        self.in2 = nn.InstanceNorm2d(channels, affine=True)

    def forward(self, x):
        """Apply both stages to ``x``, add the skip connection and rectify."""
        hidden = self.conv1(x)
        hidden = self.in1(hidden)
        hidden = self.relu(hidden)
        hidden = self.conv2(hidden)
        hidden = self.in2(hidden)
        # Skip connection: add the untouched block input before the last ReLU.
        return self.relu(hidden + x)

# Image Transform Network
class ImageTransformNet(nn.Module):
    """Encoder / residual / decoder network that maps an image batch to a stylised batch.

    Layout:
      * encoder: three ``ConvLayer`` + instance-norm + ReLU stages
        (3 -> 32 -> 64 -> 128 channels, the last two with stride 2);
      * trunk: five ``ResidualBlock(128)`` modules, ``res1`` .. ``res5``;
      * decoder: two ``UpsampleConvLayer`` stages with ``upsample=2``
        (128 -> 64 -> 32 channels) followed by a final 9x9 layer to 3 channels.

    ``tanh`` and ``in1_d`` are created but not applied in ``forward``; the
    output of ``deconv1`` is returned as is. ``in1_d`` is still a registered
    sub-module, so its affine parameters are part of the state dict.
    """

    def __init__(self):
        super(ImageTransformNet, self).__init__()

        # nonlinearity
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()

        # encoder
        self.conv1 = ConvLayer(3, 32, kernel_size=9, stride=1)
        self.in1_e = nn.InstanceNorm2d(32, affine=True)

        self.conv2 = ConvLayer(32, 64, kernel_size=3, stride=2)
        self.in2_e = nn.InstanceNorm2d(64, affine=True)

        self.conv3 = ConvLayer(64, 128, kernel_size=3, stride=2)
        self.in3_e = nn.InstanceNorm2d(128, affine=True)

        # residual trunk: registers res1 .. res5 in order
        for index in range(1, 6):
            setattr(self, "res{}".format(index), ResidualBlock(128))

        # decoder
        self.deconv3 = UpsampleConvLayer(128, 64, kernel_size=3, stride=1, upsample=2)
        self.in3_d = nn.InstanceNorm2d(64, affine=True)

        self.deconv2 = UpsampleConvLayer(64, 32, kernel_size=3, stride=1, upsample=2)
        self.in2_d = nn.InstanceNorm2d(32, affine=True)

        self.deconv1 = UpsampleConvLayer(32, 3, kernel_size=9, stride=1)
        self.in1_d = nn.InstanceNorm2d(3, affine=True)

    def forward(self, x):
        """Run ``x`` through the encoder, the residual trunk and the decoder."""
        y = x

        # encode: conv -> instance norm -> ReLU
        encoder = ((self.conv1, self.in1_e),
                   (self.conv2, self.in2_e),
                   (self.conv3, self.in3_e))
        for conv, norm in encoder:
            y = self.relu(norm(conv(y)))

        # residual layers
        for block in (self.res1, self.res2, self.res3, self.res4, self.res5):
            y = block(y)

        # decode: upsampling conv -> instance norm -> ReLU
        decoder = ((self.deconv3, self.in3_d),
                   (self.deconv2, self.in2_d))
        for deconv, norm in decoder:
            y = self.relu(norm(deconv(y)))

        # The last layer is returned without normalisation or tanh.
        return self.deconv1(y)

In [28]:
"""propagation_layers = ['relu_1', 'relu_2', 'relu_3', 'relu_4']
class Vgg19(nn.Module):
    def __init__(self):
        super(Vgg19, self).__init__()
        features = models.vgg19(pretrained=True).features
        self.layers = []
        for i in propagation_layers:
            self.layers.append( nn.Sequential() )
        
        checkpoints = [4,9,16,23]
        layer_index = 0
        for i in range(checkpoints[-1]):
            if i in checkpoints:
                layer_index += 1
            self.layers[layer_index].add_module(str(i), features[i])
            
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, image):
        print(self.layers)
        state_vector = []
        inter = self.layers[0](image)
        state_vector.append(inter)
        for i in range(1,len(self.layers)):
            inter = self.layers[i](inter)
            state_vector.append(inter)
        return state_vector"""

"propagation_layers = ['relu_1', 'relu_2', 'relu_3', 'relu_4']\nclass Vgg19(nn.Module):\n    def __init__(self):\n        super(Vgg19, self).__init__()\n        features = models.vgg19(pretrained=True).features\n        self.layers = []\n        for i in propagation_layers:\n            self.layers.append( nn.Sequential() )\n        \n        checkpoints = [4,9,16,23]\n        layer_index = 0\n        for i in range(checkpoints[-1]):\n            if i in checkpoints:\n                layer_index += 1\n            self.layers[layer_index].add_module(str(i), features[i])\n            \n        for param in self.parameters():\n            param.requires_grad = False\n\n    def forward(self, image):\n        print(self.layers)\n        state_vector = []\n        inter = self.layers[0](image)\n        state_vector.append(inter)\n        for i in range(1,len(self.layers)):\n            inter = self.layers[i](inter)\n            state_vector.append(inter)\n        return state_vector"

In [29]:
class Vgg19(nn.Module):
    """Frozen feature extractor returning four intermediate activations.

    NOTE: despite the class name, the backbone loaded here is
    ``models.vgg16(pretrained=True)``. Its ``features`` layers are split at
    indices 4, 9, 16 and 23 into four consecutive stages. The stage attribute
    names presumably refer to the ReLU layer each stage ends at.

    All parameters have ``requires_grad`` set to ``False``.
    """

    def __init__(self):
        super(Vgg19, self).__init__()
        features = models.vgg16(pretrained=True).features

        # (attribute name, first layer index, one-past-last layer index)
        stages = (
            ("to_relu_1_2", 0, 4),
            ("to_relu_2_2", 4, 9),
            ("to_relu_3_3", 9, 16),
            ("to_relu_4_3", 16, 23),
        )
        for attr_name, start, stop in stages:
            stage = nn.Sequential()
            for idx in range(start, stop):
                # Keep the original layer index as the sub-module name.
                stage.add_module(str(idx), features[idx])
            setattr(self, attr_name, stage)

        # Only the activations are needed, never gradients w.r.t. the weights.
        for weight in self.parameters():
            weight.requires_grad = False

    def forward(self, x):
        """Return a 4-tuple with the output of each stage, in stage order."""
        activations = []
        h = x
        for stage in (self.to_relu_1_2, self.to_relu_2_2,
                      self.to_relu_3_3, self.to_relu_4_3):
            h = stage(h)
            activations.append(h)
        return tuple(activations)

In [30]:
def normalize_imagenet():
    """Return a ``transforms.Normalize`` with this notebook's fixed per-channel statistics.

    The same mean/std values are used by ``save_image`` to undo the
    normalisation (the function name indicates they are ImageNet statistics).
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    return transforms.Normalize(mean=channel_mean, std=channel_std)

In [31]:
# Working resolution: 512 when CUDA is available, 256 otherwise.
common_size = 512 if torch.cuda.is_available() else 256
width = common_size
height = common_size

# Tensor -> PIL image converter (used by display_image).
convert = transforms.ToPILImage()

# Content-image pipeline: resize to a square, centre-crop, convert to a
# tensor and normalise with the notebook's mean/std.
transformer = transforms.Compose([
    transforms.Resize((common_size, common_size)),
    transforms.CenterCrop(common_size),
    transforms.ToTensor(),
    normalize_imagenet(),
])

def load_image(image_name):
    """Open an image file and return it as a 3-channel RGB PIL image.

    The image is converted to RGB so that grayscale, palette or RGBA files can
    be fed to the 3-channel normalisation used by the notebook's transform
    pipelines. The source file is closed before returning.

    Args:
        image_name: path (or file object) accepted by ``Image.open``.

    Returns:
        A fully loaded ``PIL.Image.Image`` in mode ``"RGB"``.
    """
    with Image.open(image_name) as source:
        # convert() loads the pixel data and returns a new image, so the
        # result stays valid after the file is closed.
        return source.convert("RGB")

def transform( orig_image ):
    """Turn a PIL image into a normalised float batch of size 1 on ``device``.

    The image goes through the module-level ``transformer`` pipeline and gets
    a leading batch dimension. If the result does not have exactly 3 channels,
    it is concatenated with itself three times along the channel axis.
    """
    batch = transformer(orig_image).unsqueeze(0)
    batch = batch.to(device, torch.float)
    channel_count = batch.size(1)
    if channel_count != 3:
        batch = torch.cat((batch, batch, batch), dim=1)
    return batch

def display_image( image ):
    """Show a tensor image with matplotlib and print its PIL size.

    The tensor is copied to the CPU, its leading dimension is squeezed if it
    has size 1, and it is converted with the module-level ``convert``
    (``ToPILImage``). No de-normalisation is applied here.
    """
    host_copy = image.cpu().clone().squeeze(0)
    picture = convert(host_copy)
    print(picture.size)
    plt.imshow(picture)

def save_image(filename, data):
    """Undo the mean/std normalisation of a ``(3, H, W)`` tensor and save it as an image.

    Args:
        filename: destination path handed to ``PIL.Image.Image.save``.
        data: image tensor laid out as (channel, height, width), normalised
            with the mean/std below. It may live on any device and may
            require grad; it is detached and moved to the CPU first.
    """
    std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
    mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
    # detach() + cpu() make the conversion safe for CUDA tensors and for
    # tensors that are still attached to an autograd graph.
    pixels = data.detach().cpu().numpy()
    # De-normalise, move channels last, scale to 0..255 and clamp before the uint8 cast.
    pixels = ((pixels * std + mean).transpose(1, 2, 0) * 255.0).clip(0, 255).astype("uint8")
    Image.fromarray(pixels).save(filename)

In [10]:
def calc_content_loss( y_c_features, y_hat_features):
    """Mean-squared error between the second feature map of each sequence.

    Args:
        y_c_features: feature maps of the content image; only index 1 is used.
        y_hat_features: feature maps of the generated image; only index 1 is used.

    Returns:
        Scalar tensor with the MSE (mean reduction).
    """
    content_features = y_c_features[1]
    generated_features = y_hat_features[1]
    return torch.nn.functional.mse_loss(content_features, generated_features)

In [11]:
def gram_matrix(image):
    """Compute the normalised Gram matrix of each item in a 4-D batch.

    Args:
        image: tensor of size ``(b, c, h, w)``.

    Returns:
        Tensor of size ``(b, c, c)``: for each batch item, the product of the
        flattened ``(c, h*w)`` features with their transpose, divided by
        ``c * h * w``.
    """
    batch, channels, rows, cols = image.size()
    flat = image.view(batch, channels, cols * rows)
    gram = torch.bmm(flat, flat.transpose(1, 2))
    return gram / (channels * rows * cols)

In [12]:
def calc_style_loss(y_hat_features, style_gram):
    """Sum of MSE losses between generated-image Gram matrices and the style targets.

    Args:
        y_hat_features: sequence of feature maps of the generated image; a
            Gram matrix is computed for every entry with ``gram_matrix``.
        style_gram: sequence of target Gram matrices; entry ``j`` is compared
            with the Gram matrix of ``y_hat_features[j]``.

    Returns:
        The accumulated loss, starting from ``0.0`` (so an empty
        ``style_gram`` yields the float ``0.0``).
    """
    criterion = torch.nn.MSELoss()
    generated_grams = list(map(gram_matrix, y_hat_features))
    per_layer = (criterion(generated_grams[idx], target)
                 for idx, target in enumerate(style_gram))
    return sum(per_layer, 0.0)

In [13]:
# Learning rate handed to Adam by Optimizer().
lr = 1e-3


def Optimizer(params):
    """Create the Adam optimizer for the given parameters, using the module-level ``lr``.

    (Original author's note: LBFGS is an alternative worth trying.)
    """
    return optim.Adam(params, lr=lr)

In [32]:
# Number of training images per optimisation step.
batch_size = 1

# Style image pipeline: no resizing, only tensor conversion and normalisation.
style_transform = transforms.Compose([transforms.ToTensor(), normalize_imagenet()])

# Training image pipeline: square resize, tensor conversion, normalisation.
dataset_transform = transforms.Compose(
    [
        transforms.Resize((common_size, common_size)),
        transforms.ToTensor(),
        normalize_imagenet(),
    ]
)

def _stylize_with_saved_model(content_image_path, model_path, output):
    """Load trained weights from ``model_path`` and write the stylised content image to ``output``."""
    style_model = ImageTransformNet().to(device)
    # map_location lets weights load regardless of the device they were saved from.
    style_model.load_state_dict(torch.load(model_path, map_location=device))
    style_model.eval()

    content_image = transformer(load_image(content_image_path))
    content_image = content_image.unsqueeze(0).to(device, torch.float)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        stylized = style_model(content_image).cpu()
    save_image(output, stylized[0])


def _train_transform_net(style_image_path, dataset_path, trained_model_path,
                         epoch, lamda_style, lamda_content):
    """Train an ``ImageTransformNet`` on ``dataset_path`` for one style image and save its weights."""
    image_transformer = ImageTransformNet().to(device)
    optimizer = Optimizer(image_transformer.parameters())

    vgg = Vgg19().to(device)

    train_dataset = datasets.ImageFolder(dataset_path, dataset_transform)
    train_loader = DataLoader(train_dataset, batch_size=batch_size)

    # Style targets are computed once, for a single copy of the style image.
    style = style_transform(load_image(style_image_path))
    style = style.unsqueeze(0).to(device, torch.float)
    style_gram = [gram_matrix(fmap) for fmap in vgg(style)]

    image_transformer.train()
    for e in range(epoch):
        for batch_num, (x, label) in enumerate(train_loader):
            optimizer.zero_grad()

            x = x.to(device, torch.float)
            y_hat = image_transformer(x)

            y_c_features = vgg(x)
            y_hat_features = vgg(y_hat)

            # Expand the targets to the actual batch size so a smaller final
            # batch still matches the generated Gram matrices.
            batch_style_gram = [g.expand(x.size(0), -1, -1) for g in style_gram]

            style_loss = calc_style_loss(y_hat_features, batch_style_gram)
            content_loss = calc_content_loss(y_c_features, y_hat_features)

            loss = lamda_style * style_loss + lamda_content * content_loss
            print(batch_num)
            loss.backward()
            optimizer.step()

    # Save CPU weights so they can be loaded on any device.
    image_transformer.eval()
    image_transformer.cpu()
    torch.save(image_transformer.state_dict(), trained_model_path)


def transfer_style(content_image_path, style_image_path, model_path, output,
                       epoch=2, lamda_style=10000, lamda_content=1,
                       dataset_path=None, trained_model_path="pretrained_model.model"):
    """Stylise an image with a saved model, or train a model when none exists.

    If ``model_path`` is an existing file, its weights are loaded into an
    ``ImageTransformNet``, the content image is stylised and the result is
    written to ``output``. Errors in that path (unreadable image, incompatible
    weights, ...) propagate instead of silently triggering a training run.

    If ``model_path`` does not exist, a new network is trained against the
    style image and its weights are saved to ``trained_model_path``. No
    stylised image is produced by the training path; call the function again
    with ``model_path`` pointing at the saved weights.

    NOTE(review): the default ``trained_model_path`` differs from
    ``model_path``, so repeating the same call retrains unless the saved file
    is moved or ``trained_model_path=model_path`` is passed.

    Args:
        content_image_path: image to stylise (inference path only).
        style_image_path: style image (training path only).
        model_path: file with trained ``ImageTransformNet`` weights.
        output: destination file for the stylised image.
        epoch: number of passes over the training set.
        lamda_style: weight of the style loss.
        lamda_content: weight of the content loss.
        dataset_path: root folder for ``datasets.ImageFolder``. When ``None``,
            the module-level ``dataset`` variable is used.
        trained_model_path: where the trained weights are written.
    """
    if os.path.isfile(model_path):
        _stylize_with_saved_model(content_image_path, model_path, output)
        return

    print("No model found, switched to model training")
    if dataset_path is None:
        # Backward compatible: fall back to the notebook-level global.
        dataset_path = dataset
    _train_transform_net(style_image_path, dataset_path, trained_model_path,
                         epoch, lamda_style, lamda_content)

In [51]:
tic = time.time()

# Inputs for this run; re-run the cell with a different content image to
# stylise further pictures.
content_image_path = "./images/content_images/uni2.jpeg"
style_image_path = "./images/style_images/mosaic.jpeg"
output = "out.png"

# Trained weights and the training set (read by transfer_style when no model is found).
model_path = "./models/mosaic.model"
dataset = "./train_coco_big"

transfer_style(content_image_path=content_image_path,
               style_image_path=style_image_path,
               model_path=model_path,
               output=output)

toc = time.time()
elapsed = toc - tic
print("Elapsed time {} second".format(elapsed))



Elapsed time 0.601315975189209 second
