In [1]:
from collections import namedtuple
from skimage import io
from PIL import Image
import time
import os
import re

import torch
from torchvision import models
from torchvision import transforms
from torch.optim import Adam
from torchvision import datasets
from torch.utils.data import DataLoader

In [2]:
class Vgg16(torch.nn.Module):
    """VGG-16 feature extractor returning four intermediate activations.

    Wraps the first 23 layers of torchvision's pretrained ``vgg16().features``
    split into four sequential slices; ``forward`` returns the output of each
    slice as a ``VggOutputs`` namedtuple (fields named after the layer each
    slice ends on).

    Args:
        requires_grad: when False (default) every parameter is frozen, so the
            network only serves as a fixed loss network.
    """

    # Built once here instead of re-creating the namedtuple class on every
    # forward() call (which was wasted work and gave each call's result a
    # distinct type).
    VggOutputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3'])

    # [start, stop) indices into vgg16().features for slice1..slice4.
    _SLICE_BOUNDS = ((0, 4), (4, 9), (9, 16), (16, 23))

    def __init__(self, requires_grad=False):
        super(Vgg16, self).__init__()
        vgg_features = models.vgg16(pretrained=True).features

        slices = []
        for start, stop in self._SLICE_BOUNDS:
            seq = torch.nn.Sequential()
            for idx in range(start, stop):
                # Keep the original layer index as the submodule name so the
                # state_dict keys (e.g. "slice1.0.weight") are unchanged.
                seq.add_module(str(idx), vgg_features[idx])
            slices.append(seq)
        # Assigned in order so module registration order matches slice1..slice4.
        self.slice1, self.slice2, self.slice3, self.slice4 = slices

        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        """Run ``X`` through the four slices and return every slice's output."""
        h_relu1_2 = self.slice1(X)
        h_relu2_2 = self.slice2(h_relu1_2)
        h_relu3_3 = self.slice3(h_relu2_2)
        h_relu4_3 = self.slice4(h_relu3_3)
        return self.VggOutputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3)

In [3]:
def load_image(filename, size=None, scale=None):
    """Open an image file as an RGB PIL image, optionally resized.

    Args:
        filename: path of the image file.
        size: if given, resize to a ``size`` x ``size`` square (aspect ratio is
            not preserved).
        scale: used only when ``size`` is None; shrink both sides by this
            factor (new side = int(old side / scale)).

    Returns:
        A PIL ``Image`` in mode "RGB".
    """
    # convert('RGB') guarantees 3 channels: grayscale, palette ("P") and RGBA
    # files would otherwise yield 1 or 4 channels after ToTensor() and break
    # the networks. The context manager closes the file once pixels are copied.
    with Image.open(filename) as src:
        img = src.convert('RGB')
    # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
    if size is not None:
        img = img.resize((size, size), Image.LANCZOS)
    elif scale is not None:
        img = img.resize((int(img.size[0] / scale), int(img.size[1] / scale)), Image.LANCZOS)
    return img

def save_image(filename, data):
    """Write an image tensor to ``filename`` through PIL.

    Args:
        filename: destination path; PIL picks the format from the extension.
        data: tensor laid out (C, H, W) (implied by the transpose below) with
            values on a 0-255 scale. Values are clamped to [0, 255] and
            truncated to uint8.

    The tensor may live on a GPU or be part of an autograd graph: it is
    detached and copied to the CPU first. The caller's tensor is not modified.
    """
    # .numpy() raises for CUDA tensors and for tensors that require grad.
    pixels = data.detach().cpu().clamp(0, 255).numpy()
    pixels = pixels.transpose(1, 2, 0).astype("uint8")
    Image.fromarray(pixels).save(filename)

def gram_matrix(y):
    """Return the batched Gram matrix of feature maps, normalised by ch*h*w.

    Args:
        y: 4-D tensor unpacked as (batch, channels, height, width).

    Returns:
        Tensor of shape (batch, channels, channels) whose [b, i, j] entry is
        the dot product of channels i and j of sample b, divided by
        channels * height * width.
    """
    batch, channels, height, width = y.size()
    flat = y.view(batch, channels, height * width)
    norm = channels * height * width
    return torch.bmm(flat, flat.transpose(1, 2)) / norm

def normalize_batch(batch):
    """Scale a 0-255 image tensor to [0, 1] and normalise with ImageNet stats.

    Args:
        batch: tensor whose channel axis is the third-from-last dimension with
            3 entries, e.g. (B, 3, H, W) or (3, H, W), holding 0-255 values.

    Returns:
        A new tensor ``(batch / 255 - mean) / std``. The input tensor is left
        unchanged.
    """
    # normalize using imagenet mean and std
    mean = batch.new_tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
    std = batch.new_tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)
    # Out-of-place division: the previous `div_` silently rescaled the
    # caller's tensor as a side effect.
    return (batch / 255.0 - mean) / std

def styleGram(style_image_path, style_transform, vgg, batch_size=4, device="cuda"):
    """Gram matrices of a single style image, tiled to a training batch.

    Args:
        style_image_path: path of the style image (loaded as a 512x512 square).
        style_transform: callable turning the PIL image into a tensor; assumed
            to return (C, H, W) on a 0-255 scale (TODO confirm for new callers).
        vgg: feature extractor returning an iterable of feature maps.
        batch_size: number of copies of each Gram matrix (default 4, as before).
        device: device the style image is moved to (default "cuda").

    Returns:
        List with one tensor per VGG output, each (batch_size, ch, ch).
    """
    style = load_image(style_image_path, size=512)
    style = style_transform(style).unsqueeze(0).to(device)
    # Run VGG once and tile the Gram matrices, instead of pushing batch_size
    # identical copies of the image through the network. (The leftover debug
    # print of vgg.training was removed.)
    with torch.no_grad():
        features_style = vgg(normalize_batch(style))
    return [gram_matrix(y).repeat(batch_size, 1, 1) for y in features_style]

def avgStyleGram(style_image_fold, style_transform, vgg, batch_size=4, device="cuda"):
    """Average the Gram matrices of every style image in a folder.

    Args:
        style_image_fold: root directory read with torchvision ``ImageFolder``
            (so images must sit inside at least one sub-directory).
        style_transform: transform applied to each image; assumed to produce a
            (C, H, W) tensor on a 0-255 scale (TODO confirm for new callers).
        vgg: feature extractor returning an iterable of feature maps.
        batch_size: number of copies of each averaged Gram matrix (default 4).
        device: device each style image is moved to (default "cuda").

    Returns:
        List with one tensor per VGG output, each (batch_size, ch, ch): the
        mean over all style images of that layer's Gram matrix.
    """
    # Bug fix: the folder argument used to be ignored in favour of a
    # hard-coded "d:/Styles/".
    style_dataset = datasets.ImageFolder(style_image_fold, style_transform)

    # Images are processed one at a time (they may differ in size) and only
    # the small ch x ch Gram matrices are kept, rather than holding every
    # image and every feature map on the GPU at once.
    per_image = []  # per_image[i][k]: Gram of layer k for image i, (1, ch, ch)
    with torch.no_grad():
        for style_img, _label in style_dataset:
            feats = vgg(normalize_batch(style_img.unsqueeze(0).to(device)))
            per_image.append([gram_matrix(f) for f in feats])

    gram_style = []  # one entry per VGG layer, tiled to batch_size
    for layer_grams in zip(*per_image):
        mean_gram = torch.cat(layer_grams).mean(0)
        gram_style.append(mean_gram.repeat(batch_size, 1, 1))
    return gram_style

In [4]:
class ConvLayer(torch.nn.Module):
    """Conv2d preceded by reflection padding of ``kernel_size // 2``.

    With stride 1 and an odd kernel the spatial size is preserved.
    Submodules are named ``reflection_pad`` and ``conv2d`` (these names form
    the state_dict keys).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super().__init__()
        self.reflection_pad = torch.nn.ReflectionPad2d(kernel_size // 2)
        self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride)

    def forward(self, x):
        return self.conv2d(self.reflection_pad(x))
    
class ResidualBlock(torch.nn.Module):
    """Two 3x3 ConvLayers with instance norm, plus an identity skip connection.

    Computes ``in2(conv2(relu(in1(conv1(x))))) + x``. The channel count and
    spatial size are unchanged, so the skip needs no projection.
    """

    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        # Construction order kept as-is: it fixes the random initialisation
        # sequence and the state_dict key order.
        self.conv1 = ConvLayer(channels, channels, kernel_size=3, stride=1)
        self.in1 = torch.nn.InstanceNorm2d(channels, affine=True)
        self.conv2 = ConvLayer(channels, channels, kernel_size=3, stride=1)
        self.in2 = torch.nn.InstanceNorm2d(channels, affine=True)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        hidden = self.in1(self.conv1(x))
        hidden = self.relu(hidden)
        hidden = self.conv2(hidden)
        return self.in2(hidden) + x
        
class UpsampleConvLayer(torch.nn.Module):
    """Optional nearest-neighbour upsampling, then a reflection-padded Conv2d.

    Args:
        in_channels, out_channels, kernel_size, stride: forwarded to Conv2d.
        upsample: scale factor for nearest-neighbour upsampling. When falsy
            (default None) no upsampling layer is created or applied.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, upsample=None):
        super().__init__()
        self.upsample = upsample
        if upsample:
            self.upsample_layer = torch.nn.Upsample(mode='nearest', scale_factor=upsample)
        self.reflection_pad = torch.nn.ReflectionPad2d(kernel_size // 2)
        self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride)

    def forward(self, x):
        if self.upsample:
            x = self.upsample_layer(x)
        return self.conv2d(self.reflection_pad(x))

In [5]:
class TransformerNet(torch.nn.Module):
    """Image transformation network used for fast style transfer.

    Structure:
        * encoder: 9x9 conv (3->32, stride 1), then 3x3 convs 32->64 and
          64->128 with stride 2, each followed by instance norm and ReLU;
        * five 128-channel residual blocks;
        * decoder: two upsample(x2)+3x3 conv layers (128->64->32) with
          instance norm and ReLU, and a final 9x9 conv back to 3 channels
          with no activation.

    Attribute names are the state_dict keys of saved checkpoints and must not
    change.
    """

    def __init__(self):
        super(TransformerNet, self).__init__()
        # Construction order kept as-is (it fixes random-init order).
        # Encoder
        self.conv1 = ConvLayer(3, 32, kernel_size=9, stride=1)
        self.in1 = torch.nn.InstanceNorm2d(32, affine=True)
        self.conv2 = ConvLayer(32, 64, kernel_size=3, stride=2)
        self.in2 = torch.nn.InstanceNorm2d(64, affine=True)
        self.conv3 = ConvLayer(64, 128, kernel_size=3, stride=2)
        self.in3 = torch.nn.InstanceNorm2d(128, affine=True)
        # Residual trunk
        self.res1 = ResidualBlock(128)
        self.res2 = ResidualBlock(128)
        self.res3 = ResidualBlock(128)
        self.res4 = ResidualBlock(128)
        self.res5 = ResidualBlock(128)
        # Decoder
        self.deconv1 = UpsampleConvLayer(128, 64, kernel_size=3, stride=1, upsample=2)
        self.in4 = torch.nn.InstanceNorm2d(64, affine=True)
        self.deconv2 = UpsampleConvLayer(64, 32, kernel_size=3, stride=1, upsample=2)
        self.in5 = torch.nn.InstanceNorm2d(32, affine=True)
        self.deconv3 = ConvLayer(32, 3, kernel_size=9, stride=1)
        # Non-linearities
        self.relu = torch.nn.ReLU()

    def forward(self, X):
        out = X
        encoder = ((self.conv1, self.in1), (self.conv2, self.in2), (self.conv3, self.in3))
        for conv, norm in encoder:
            out = self.relu(norm(conv(out)))
        for block in (self.res1, self.res2, self.res3, self.res4, self.res5):
            out = block(out)
        for deconv, norm in ((self.deconv1, self.in4), (self.deconv2, self.in5)):
            out = self.relu(norm(deconv(out)))
        return self.deconv3(out)


In [None]:
def _save_transformer(model, checkpoint_dir, filename):
    """Save ``model``'s state_dict to ``checkpoint_dir/filename``; return the path.

    The model is switched to eval mode and moved to the CPU first, so the file
    holds CPU tensors. The caller must move it back (``.to(device).train()``)
    if training continues. ``checkpoint_dir`` is created when missing, so a
    long run no longer crashes at its first checkpoint.
    """
    os.makedirs(checkpoint_dir, exist_ok=True)
    path = os.path.join(checkpoint_dir, filename)
    model.eval().cpu()
    torch.save(model.state_dict(), path)
    return path


def train(style_image_path, dataset_path, style_name="Caligraphy",
          checkpoint_dir="d:/FastStyle/checkpts/", epochs=1, batch_size=4,
          lr=1e-3, content_weight=1e5, style_weight=1e10,
          log_interval=1000, checkpoint_interval=3000):
    """Train a TransformerNet for one style and save its weights.

    Args:
        style_image_path: style source handed to ``avgStyleGram``. It is also
            used to derive ``style_name`` when that is empty.
        dataset_path: root of a torchvision ``ImageFolder`` of content images
            (resized and center-cropped to 256x256).
        style_name: prefix for checkpoint file names.
        checkpoint_dir: directory checkpoints are written to (created if missing).
        epochs: number of passes over the content dataset.
        batch_size: DataLoader batch size.
        lr: Adam learning rate.
        content_weight: multiplier on the relu2_2 content MSE.
        style_weight: multiplier on the summed Gram-matrix MSE.
        log_interval: print running averages every this many batches (per epoch).
        checkpoint_interval: save ``<style_name>_<step>.pth`` every this many
            batches. ``step`` counts across epochs, so names stay unique; with
            one epoch it equals the in-epoch batch number, as before.

    Returns:
        Path of the final ``<style_name>_Final.pth`` checkpoint.
    """
    device = torch.device("cuda")

    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(256),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])

    train_dataset = datasets.ImageFolder(dataset_path, transform)
    train_loader = DataLoader(train_dataset, batch_size=batch_size)

    transformer = TransformerNet().to(device)
    optimizer = Adam(transformer.parameters(), lr)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16(requires_grad=False).to(device)
    style_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])

    gram_style = avgStyleGram(style_image_path, style_transform, vgg)  # one tiled Gram per VGG layer

    if not style_name:
        # normpath strips a trailing separator so folder paths such as
        # "d:/Styles/" give "Styles" instead of an empty name.
        base = os.path.basename(os.path.normpath(style_image_path))
        style_name = os.path.splitext(base)[0]
    print("Training for", style_name)

    step = 0  # batches processed across all epochs
    for e in range(epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0

        for batch_id, (x, _) in enumerate(train_loader):
            step += 1
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()

            x = x.to(device)
            y = transformer(x)

            y = normalize_batch(y)
            x = normalize_batch(x)

            features_y = vgg(y)
            features_x = vgg(x)

            content_loss = content_weight * mse_loss(features_y.relu2_2, features_x.relu2_2)

            style_loss = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                gm_y = gram_matrix(ft_y)
                # The style Gram is tiled with identical rows; broadcasting its
                # first row works for any batch size, including partial final
                # batches and batch_size > the tiling size.
                style_loss += mse_loss(gm_y, gm_s[:1].expand_as(gm_y))
            style_loss *= style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            agg_content_loss += content_loss.item()
            agg_style_loss += style_loss.item()

            if (batch_id + 1) % log_interval == 0:
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    time.ctime(), e + 1, count, len(train_dataset),
                    agg_content_loss / (batch_id + 1),
                    agg_style_loss / (batch_id + 1),
                    (agg_content_loss + agg_style_loss) / (batch_id + 1)
                )
                print(mesg)

            if step % checkpoint_interval == 0:
                ckpt_model_filename = style_name + "_" + str(step) + ".pth"
                _save_transformer(transformer, checkpoint_dir, ckpt_model_filename)
                transformer.to(device).train()

    save_model_path = _save_transformer(transformer, checkpoint_dir, style_name + "_Final.pth")

    print("\nDone, trained model saved at", save_model_path)
    return save_model_path

In [8]:
def stylize(content_image, output_image, model, lol=None, device="cuda"):
    """Apply a trained TransformerNet checkpoint to one image and save the result.

    Args:
        content_image: path of the input image.
        output_image: path the stylized image is written to.
        model: path of a TransformerNet state_dict (e.g. one saved by ``train``).
        lol: unused; kept only so existing positional calls keep working.
        device: device the network runs on (default "cuda", as before).
    """
    device = torch.device(device)

    # convert('RGB') yields 3 channels for grayscale, palette and RGBA files.
    # The old 1-channel repeat mis-handled palette images (it repeated palette
    # indices, not colours) and RGBA images crashed.
    content = load_image(content_image).convert('RGB')
    content_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    content = content_transform(content).unsqueeze(0).to(device)

    # NOTE(security): torch.load unpickles the file, so only load trusted
    # checkpoints. map_location="cpu" lets GPU-saved files load anywhere;
    # the model is moved to `device` below.
    state_dict = torch.load(model, map_location="cpu")
    # remove saved deprecated running_* keys in InstanceNorm from the checkpoint
    for k in list(state_dict.keys()):
        if re.search(r'in\d+\.running_(mean|var)$', k):
            del state_dict[k]

    style_model = TransformerNet()
    style_model.load_state_dict(state_dict)
    style_model.to(device).eval()
    with torch.no_grad():
        output = style_model(content).cpu()

    save_image(output_image, output[0])

In [None]:
# Train on the content images under d:/COCO17/ with the style images under d:/Styles/.
train(style_image_path="d:/Styles/", dataset_path="d:/COCO17/")

In [11]:
# Stylize one image with the 27000-batch checkpoint of the Caligraphy model.
stylize(
    content_image="d:/Images/dancing.jpg",
    output_image="d:/out.jpg",
    model="d:/FastStyle/checkpts/Caligraphy_27000.pth",
)

