In [None]:
# references:
# https://github.com/jcjohnson/fast-neural-style
# https://github.com/pytorch/examples/tree/master/fast_neural_style
# https://github.com/lengstrom/fast-style-transfer
# "Perceptual Losses for Real-Time Style Transfer and Super-Resolution" by Justin Johnson, Alexandre Alahi, and Li Fei-Fei
# https://github.com/pytorch/examples/blob/master/fast_neural_style


In [17]:
import os
import sys
import numpy as np
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms

from vgg import Vgg16
from PIL import Image


In [18]:

class TransformerNet(torch.nn.Module):
    """Feed-forward image transformation network with per-style normalization.

    Layout: three ConvLayer stages (the second and third use stride 2), five
    ResidualBlocks at 128 channels, two UpsampleConvLayer stages (upsample=2)
    and a final 9x9 ConvLayer producing 3 channels. Every stage except the
    residual blocks and the last convolution is followed by a
    batch_InstanceNorm2d (selected per sample through ``style_id``) and a ReLU.

    Args:
        style_num: number of styles; forwarded to every batch_InstanceNorm2d.
    """

    def __init__(self, style_num=4):
        super().__init__()
        # Attribute names double as state_dict keys, so they must stay stable.
        # Encoder stages.
        self.conv1 = ConvLayer(3, 32, kernel_size=9, stride=1)
        self.in1 = batch_InstanceNorm2d(style_num, 32)

        self.conv2 = ConvLayer(32, 64, kernel_size=3, stride=2)
        self.in2 = batch_InstanceNorm2d(style_num, 64)

        self.conv3 = ConvLayer(64, 128, kernel_size=3, stride=2)
        self.in3 = batch_InstanceNorm2d(style_num, 128)

        # Residual stack (not conditioned on the style id).
        self.res1 = ResidualBlock(128)
        self.res2 = ResidualBlock(128)
        self.res3 = ResidualBlock(128)
        self.res4 = ResidualBlock(128)
        self.res5 = ResidualBlock(128)

        # Decoder stages.
        self.deconv1 = UpsampleConvLayer(128, 64, kernel_size=3, stride=1, upsample=2)
        self.in4 = batch_InstanceNorm2d(style_num, 64)
        self.deconv2 = UpsampleConvLayer(64, 32, kernel_size=3, stride=1, upsample=2)
        self.in5 = batch_InstanceNorm2d(style_num, 32)
        self.deconv3 = ConvLayer(32, 3, kernel_size=9, stride=1)
        self.relu = torch.nn.ReLU()

    def forward(self, X, style_id):
        """Run the network on batch ``X``.

        Args:
            X: input batch; conv1 expects 3 input channels.
            style_id: per-sample style indices, handed to every
                batch_InstanceNorm2d layer.

        Returns:
            The output of the last convolution (no activation is applied).
        """
        encoder = (
            (self.conv1, self.in1),
            (self.conv2, self.in2),
            (self.conv3, self.in3),
        )
        decoder = (
            (self.deconv1, self.in4),
            (self.deconv2, self.in5),
        )

        out = X
        for conv, norm in encoder:
            out = self.relu(norm(conv(out), style_id))

        for block in (self.res1, self.res2, self.res3, self.res4, self.res5):
            out = block(out)

        for conv, norm in decoder:
            out = self.relu(norm(conv(out), style_id))

        return self.deconv3(out)


In [19]:
class ResidualBlock(torch.nn.Module):
    """Two 3x3 ConvLayers with affine instance norm and an identity skip.

    Args:
        channels: number of input and output channels of both convolutions.
    """

    def __init__(self, channels):
        super().__init__()
        # Attribute names are state_dict keys; keep them unchanged.
        self.conv1 = ConvLayer(channels, channels, kernel_size=3, stride=1)
        self.in1 = torch.nn.InstanceNorm2d(channels, affine=True)
        self.conv2 = ConvLayer(channels, channels, kernel_size=3, stride=1)
        self.in2 = torch.nn.InstanceNorm2d(channels, affine=True)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        """Return ``in2(conv2(relu(in1(conv1(x))))) + x``."""
        branch = self.conv1(x)
        branch = self.in1(branch)
        branch = self.relu(branch)
        branch = self.conv2(branch)
        branch = self.in2(branch)
        # No activation after the addition.
        return branch + x

In [20]:
# NOTE(review): batch_InstanceNorm2d is defined twice in this cell; the later
# definition is the one bound to the name once the cell has run.
class batch_InstanceNorm2d(torch.nn.Module):
    """Conditional instance normalization: one affine InstanceNorm2d per style.

    Args:
        style_num: number of InstanceNorm2d modules to create.
        in_channels: channel count passed to each InstanceNorm2d.
    """

    def __init__(self, style_num, in_channels):
        super().__init__()
        per_style = [
            torch.nn.InstanceNorm2d(in_channels, affine=True)
            for _ in range(style_num)
        ]
        self.inns = torch.nn.ModuleList(per_style)

    def forward(self, x, style_id):
        """Normalize sample ``i`` of ``x`` with the module ``style_id[i]``.

        Only the first ``len(style_id)`` samples of ``x`` are processed.
        """
        normalized = []
        for i in range(len(style_id)):
            sample = x[i].unsqueeze(0)  # restore a batch axis of size 1
            norm = self.inns[style_id[i]]
            normalized.append(norm(sample).squeeze(0))
        return torch.stack(normalized)

# NOTE(review): ConvLayer is defined twice in this cell; the later definition
# is the one bound to the name once the cell has run.
class ConvLayer(torch.nn.Module):
    """Reflection padding followed by a 2-D convolution.

    The padding is ``kernel_size // 2`` on every side, so with stride 1 and an
    odd kernel the spatial size is preserved.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super().__init__()
        self.reflection_pad = torch.nn.ReflectionPad2d(kernel_size // 2)
        self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride)

    def forward(self, x):
        """Pad ``x`` by reflection, then convolve it."""
        return self.conv2d(self.reflection_pad(x))

class Upsample(nn.Module):
    """Module wrapper around ``torch.nn.functional.interpolate``.

    Args:
        scale_factor: forwarded to ``interpolate`` as ``scale_factor``.
        mode: forwarded to ``interpolate`` as ``mode`` (default ``'nearest'``).
    """

    def __init__(self, scale_factor=None, mode='nearest'):
        super().__init__()
        self.scale_factor = scale_factor
        self.mode = mode

    def forward(self, x):
        """Resize ``x`` with the stored scale factor and interpolation mode."""
        resized = nn.functional.interpolate(
            x,
            scale_factor=self.scale_factor,
            mode=self.mode,
        )
        return resized



class UpsampleConvLayer(torch.nn.Module):
    """Optional upsampling followed by reflection padding and a convolution.

    Used in place of ``ConvTranspose2d``.

    Args:
        in_channels: input channels of the convolution.
        out_channels: output channels of the convolution.
        kernel_size: convolution kernel size.
        stride: convolution stride.
        upsample: scale factor for the upsampling step; falsy disables it.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, upsample=None):
        super(UpsampleConvLayer, self).__init__()
        self.upsample = upsample
        if upsample:
            self.upsample_layer = Upsample(mode='nearest', scale_factor=upsample)
        # Pad by half the kernel (as ConvLayer does) so a stride-1 convolution
        # keeps the spatial size. Padding by the full kernel size made every
        # layer grow its output by 4 pixels, so the network output no longer
        # matched the input resolution.
        reflection_padding = kernel_size // 2
        self.reflection_pad = torch.nn.ReflectionPad2d(reflection_padding)
        self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride)

    def forward(self, x):
        """Upsample ``x`` (if configured), reflection-pad it and convolve."""
        x_in = x
        if self.upsample:
            x_in = self.upsample_layer(x_in)
        out = self.reflection_pad(x_in)
        out = self.conv2d(out)
        return out
    
# NOTE(review): this is the second definition of batch_InstanceNorm2d in this
# cell; it rebinds the name, so this is the version the rest of the notebook uses.
class batch_InstanceNorm2d(torch.nn.Module):
    """Conditional Instance Normalization.

    Holds ``style_num`` affine InstanceNorm2d modules and applies, for each
    sample, the one selected by that sample's style index.
    """

    def __init__(self, style_num, in_channels):
        super().__init__()
        self.inns = torch.nn.ModuleList(
            [torch.nn.InstanceNorm2d(in_channels, affine=True) for _ in range(style_num)]
        )

    def forward(self, x, style_id):
        """Normalize ``x[i]`` with ``self.inns[style_id[i]]`` and re-stack.

        Only the first ``len(style_id)`` samples of ``x`` are processed.
        """
        per_sample = []
        for idx in range(len(style_id)):
            single = x[idx].unsqueeze(0)  # InstanceNorm2d is fed a batch of one
            per_sample.append(self.inns[style_id[idx]](single).squeeze(0))
        return torch.stack(per_sample)
    
# NOTE(review): this is the second definition of ConvLayer in this cell; it
# rebinds the name, so this is the version the rest of the notebook uses.
class ConvLayer(torch.nn.Module):
    """Convolution preceded by reflection padding of ``kernel_size // 2``.

    With stride 1 and an odd kernel the output keeps the input's spatial size.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super().__init__()
        half_kernel = kernel_size // 2
        self.reflection_pad = torch.nn.ReflectionPad2d(half_kernel)
        self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride)

    def forward(self, x):
        """Apply the padding and then the convolution to ``x``."""
        padded = self.reflection_pad(x)
        return self.conv2d(padded)

In [21]:
def load_image(filename, size=None, scale=None):
    """Open an image as RGB and optionally resize it.

    Args:
        filename: path of the image to open.
        size: if given, resize to a ``size`` x ``size`` square (aspect ratio
            is not preserved).
        scale: used only when ``size`` is None; both dimensions are divided
            by ``scale``.

    Returns:
        A PIL image in RGB mode.
    """
    # The network's first convolution expects 3 channels, so grayscale/RGBA/
    # palette images are converted up front.
    img = Image.open(filename).convert('RGB')
    # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter
    # under its current name.
    if size is not None:
        img = img.resize((size, size), Image.LANCZOS)
    elif scale is not None:
        img = img.resize((int(img.size[0] / scale), int(img.size[1] / scale)), Image.LANCZOS)
    return img


def save_image(filename, data):
    """Write a channel-first image tensor to ``filename``.

    Args:
        filename: destination path; missing parent directories are created.
        data: tensor laid out as (channels, height, width) with values on a
            0-255 scale. Values outside that range are clamped. The tensor
            itself is not modified.
    """
    out_dir = os.path.dirname(filename)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # detach()/cpu() let this accept CUDA tensors and tensors that still track
    # gradients; .numpy() rejects both.
    pixels = data.detach().cpu().clamp(0, 255).numpy()
    # Move the channel axis last, which is the layout PIL expects.
    pixels = pixels.transpose(1, 2, 0).astype("uint8")
    Image.fromarray(pixels).save(filename)


def gram_matrix(y):
    """Return the per-sample Gram matrix of a 4-D feature tensor.

    Args:
        y: tensor of size (b, ch, h, w); must be viewable as (b, ch, h*w).

    Returns:
        Tensor of size (b, ch, ch): the channel-by-channel inner products of
        the flattened feature maps, divided by ``ch * h * w``.
    """
    batch, channels, height, width = y.size()
    flat = y.view(batch, channels, width * height)
    normalizer = channels * height * width
    return torch.bmm(flat, flat.transpose(1, 2)) / normalizer


def normalize_batch(batch):
    """Scale a 0-255 batch to 0-1 and standardize each of its 3 channels.

    The per-channel mean/std constants are the ones commonly used with
    ImageNet-pretrained torchvision models.

    Args:
        batch: tensor whose channel axis (length 3) is third from last,
            e.g. (N, 3, H, W). It is left unmodified.

    Returns:
        A new tensor ``(batch / 255 - mean) / std``.
    """
    mean = batch.new_tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
    std = batch.new_tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)
    # Out-of-place division: an in-place div_ would rescale the caller's
    # tensor as a hidden side effect and make repeated calls non-idempotent.
    scaled = batch / 255.0
    return (scaled - mean) / std


In [22]:

def stylize():
    """Stylize one image with a trained TransformerNet and save the result.

    Reads these keys from the module-level ``options`` dict:
        content_image: path of the image to stylize.
        style_num: number of styles the checkpoint was built with.
        model: path of the saved state_dict.
        style_id: index of the style to apply.
        name: output path prefix; the file written is
            ``name + str(style_id) + '.jpg'``.
    """
    # Fall back to the CPU so the notebook also runs without a GPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    content_image = load_image(options['content_image'], scale=None)
    content_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))  # back to a 0-255 scale
    ])
    content_image = content_transform(content_image)
    content_image = content_image.unsqueeze(0).to(device)

    style_model = TransformerNet(style_num=options['style_num'])
    # map_location lets a checkpoint saved on another device load here.
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    state_dict = torch.load(options['model'], map_location=device)
    style_model.load_state_dict(state_dict)
    style_model.to(device)
    style_model.eval()

    with torch.no_grad():
        output = style_model(content_image, style_id=[options['style_id']]).cpu()

    save_image(options['name'] + str(options['style_id']) + '.jpg', output[0])



In [23]:
# Stylize every file found in `dir2` with style 3.
# NOTE: `options` must already exist when this cell runs; in this notebook it
# is defined in a cell further down.
dir2 = 'images/Final1_jpg/'
options['style_id'] = 3
for filename in os.listdir(dir2):
    # stylize() appends the style id and '.jpg' to options['name'].
    options.update({
        'name': 'output_final/Final13/' + filename,
        'content_image': dir2 + filename,
    })
    stylize()

In [24]:
# Stylize a single image with style 3.
# NOTE: `options` must already exist when this cell runs; in this notebook it
# is defined in a cell further down.
file = 'Target1'
dir1 = 'images/{}.jpg'.format(file)
options.update(style_id=3, content_image=dir1)
# stylize() appends the style id and '.jpg' again to this prefix.
options['name'] = 'output_final/' + file + str(options['style_id'])
stylize()

In [38]:

def train():
    """Train a multi-style TransformerNet and save its state_dict.

    Reads these keys from the module-level ``options`` dict: ``seed``,
    ``image_size``, ``dataset``, ``batch_size``, ``style_image`` (directory of
    style images), ``style_size``, ``lr``, ``epochs``, ``content_weight``,
    ``style_weight``, ``name`` and ``model_dir``.

    Each file in the style directory becomes one style; its style id is its
    position in the sorted file list. Samples are assigned style ids round-robin
    in the order they are drawn. The loss is
    ``content_weight * MSE(relu2_2 features)`` plus ``style_weight`` times the
    sum, over the VGG feature maps, of the MSE between Gram matrices.

    Raises:
        ValueError: if the style directory contains no files.
    """
    # Fall back to the CPU so the notebook also runs without a GPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # load data
    np.random.seed(options["seed"])
    torch.manual_seed(options["seed"])

    transform = transforms.Compose([
        transforms.Resize(options["image_size"]),
        transforms.CenterCrop(options["image_size"]),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))  # back to a 0-255 scale
    ])
    train_dataset = datasets.ImageFolder(options["dataset"], transform)
    train_loader = DataLoader(train_dataset, batch_size=options["batch_size"], shuffle=True)

    # os.listdir returns entries in arbitrary order and includes directories
    # (e.g. .ipynb_checkpoints). Sort so the style-id <-> file mapping is
    # reproducible, and keep regular files only.
    style_dir = options["style_image"]
    style_image = sorted(
        f for f in os.listdir(style_dir)
        if os.path.isfile(os.path.join(style_dir, f))
    )
    style_num = len(style_image)
    if style_num == 0:
        raise ValueError("no style images found in " + repr(style_dir))

    # train
    transformer = TransformerNet(style_num=style_num).to(device)
    optimizer = Adam(transformer.parameters(), options["lr"])
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16(requires_grad=False).to(device)
    style_transform = transforms.Compose([
        transforms.Resize(options["style_size"]),
        transforms.CenterCrop(options["style_size"]),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])

    style_batch = []
    for style_file in style_image:
        style = load_image(os.path.join(style_dir, style_file), size=options["style_size"])
        style_batch.append(style_transform(style))

    style = torch.stack(style_batch).to(device)

    # One Gram matrix per style for every VGG feature map; the first axis of
    # each entry is indexed by style id.
    features_style = vgg(normalize_batch(style))
    gram_style = [gram_matrix(y) for y in features_style]

    for e in range(options["epochs"]):
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        for x, _ in train_loader:
            n_batch = len(x)

            # Only full batches are used; a short batch ends the epoch.
            if n_batch < options["batch_size"]:
                break

            count += n_batch
            optimizer.zero_grad()

            # Round-robin style assignment over the samples seen this epoch.
            batch_style_id = [i % style_num for i in range(count - n_batch, count)]
            x = x.to(device)
            y = transformer(x, style_id=batch_style_id)

            y = normalize_batch(y)
            x = normalize_batch(x)

            features_y = vgg(y)
            features_x = vgg(x)
            content_loss = options["content_weight"] * mse_loss(features_y.relu2_2, features_x.relu2_2)

            style_loss = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                gm_y = gram_matrix(ft_y)
                # Compare each sample with the Gram matrix of its own style.
                style_loss += mse_loss(gm_y, gm_s[batch_style_id, :, :])
            style_loss *= options["style_weight"]

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            agg_content_loss += content_loss.item()
            agg_style_loss += style_loss.item()

        # Report the per-batch averages that were accumulated above.
        if count:
            n_steps = count // options["batch_size"]
            print("Epoch {}: content loss {:.6f}, style loss {:.6f}".format(
                e + 1, agg_content_loss / n_steps, agg_style_loss / n_steps))

    # save model
    transformer.eval().cpu()
    save_model_filename = options["name"] + "_" + str(options["epochs"]) + ".model"
    # Make sure the target directory exists so the save cannot fail after
    # training has already finished.
    if options["model_dir"]:
        os.makedirs(options["model_dir"], exist_ok=True)
    save_model_path = os.path.join(options["model_dir"], save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)

In [40]:
# Shared configuration read by stylize() and train().
options = dict(
    # --- read by stylize() ---
    content_image='images/content_images_final/person_jumping.jpg',  # image to stylize
    model='pytorch_models/epoch_2_Thu_May_16_190329_2019_100000_10000000000.model',  # state_dict to load
    output_image='output',  # not read by any function visible in this notebook
    style_num=4,            # number of styles the loaded model was built with
    style_id=0,             # index of the style to apply
    # --- read by both: output prefix in stylize(), model file prefix in train() ---
    name='LSTM5',
    # --- read by train() ---
    epochs=3,
    batch_size=4,
    dataset='images/',                   # root passed to datasets.ImageFolder
    style_image='images/style_images/',  # directory listed for style images
    model_dir='pytorch_models/',         # where the trained state_dict is saved
    image_size=64,                       # resize/crop size of training images
    style_size=64,                       # resize/crop size of style images
    seed=32,
    content_weight=1e5,
    style_weight=1e10,
    lr=1e-3,
)

In [None]:
# Run training with the current `options`; the resulting state_dict is saved
# under options["model_dir"].
train()