In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import scipy
import scipy.ndimage
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image

from src.image_transformations import add_noise, normalize_batch
from src.common import Common
from src.plotting_images import plot_losses

In [None]:
# Working sizes in PIL's (width, height) order: full resolution, and the
# coarse size used for the first optimisation pass.
IMG_DIM = 1200, 1600
LOW_IMG_DIM = 400, 534

In [None]:
# Content image
input_img = np.asarray(Image.open('data/mya_img.jpg').resize(IMG_DIM))

# Style image
style_img = np.asarray(Image.open('data/anime.jpg').resize(IMG_DIM))

print(input_img.shape)
print(style_img.shape)

In [None]:
# Content image at full resolution.
plt.imshow(input_img)
plt.title('Content image')
plt.axis('off')
plt.show()

In [None]:
# Style image at full resolution.
plt.imshow(style_img)
plt.title('Style image')
plt.axis('off')
plt.show()

In [None]:
# Coarse-to-fine stylization: optimise first at LOW_IMG_DIM, then upsample
# that result and refine it at IMG_DIM.
low_input_img, low_style_img = (
    np.asarray(Image.fromarray(full_res).resize(LOW_IMG_DIM))
    for full_res in (input_img, style_img)
)

In [None]:
# Content image at the coarse (first-pass) resolution.
plt.imshow(low_input_img)
plt.title('Content image (low resolution)')
plt.axis('off')
plt.show()

In [None]:
# Style image at the coarse (first-pass) resolution.
plt.imshow(low_style_img)
plt.title('Style image (low resolution)')
plt.axis('off')
plt.show()

In [None]:
mse = nn.MSELoss()  # shared mean-squared-error criterion for the content and style losses

In [None]:
class HookModule(nn.Module):
    '''
    Wraps `model` and captures the outputs of selected named sub-modules.

    On each forward pass, forward hooks are attached to every sub-module whose
    name (as reported by `model.named_modules()`) is listed in `style_layers`
    or `content_layers`. The captured outputs are returned in execution order,
    and the hooks are always removed afterwards, even if the model raises.
    '''
    def __init__(self, model):
        super(HookModule, self).__init__()
        self.model = model
        # Hook handles registered for the current forward pass, keyed by module name.
        self.style_hooks = {}
        self.content_hooks = {}
        # Module names whose outputs are captured; they must match model.named_modules().
        self.style_layers = ['vgg_19_conv1_conv1_1_Conv2D', 'vgg_19_conv2_conv2_1_Conv2D', 'vgg_19_conv3_conv3_1_Conv2D', 'vgg_19_conv4_conv4_1_Conv2D', 'vgg_19_conv5_conv5_1_Conv2D']
        self.content_layers = ['vgg_19_conv4_conv4_3_Conv2D']
        # Outputs captured during the most recent forward pass.
        self.content = []
        self.style = []

    def forward(self, x):
        '''
        Runs the wrapped model on `x` and returns the captured activations.

        Returns:
            (content, style): lists of the hooked modules' outputs, in the order
            the modules ran. The model's own return value is discarded.
        '''
        self.reinit_hooks()
        self.content = []
        self.style = []
        try:
            self.model(x)
        finally:
            # Without this, an exception inside the model would leave the hooks
            # attached and every later forward would capture duplicates.
            self._remove_hooks()

        return self.content, self.style

    def reinit_hooks(self):
        '''Registers fresh forward hooks on the configured style/content modules.'''
        # Drop any handles still registered so repeated calls cannot stack hooks.
        self._remove_hooks()
        for name, module in self.model.named_modules():
            if name in self.style_layers:
                self.style_hooks[name] = module.register_forward_hook(self.style_hook)

            if name in self.content_layers:
                self.content_hooks[name] = module.register_forward_hook(self.content_hook)

    def _remove_hooks(self):
        '''Detaches every registered hook and forgets its handle.'''
        for handle in list(self.style_hooks.values()) + list(self.content_hooks.values()):
            handle.remove()
        self.style_hooks.clear()
        self.content_hooks.clear()

    def style_hook(self, module, input, output):
        '''Forward hook: records a style-layer output.'''
        self.style.append(output)

    def content_hook(self, module, input, output):
        '''Forward hook: records a content-layer output.'''
        self.content.append(output)

In [None]:
def denormalize_batch(batch):
    '''
    Adds the per-channel VGG means back onto a mean-subtracted batch.

    Args:
        batch: tensor of shape (N, 3, H, W).

    Returns:
        New tensor in the default float dtype, on the same device as `batch`.
        Channel c has vgg_means[c] added. Gradients flow back to `batch`.

    Raises:
        ValueError: if `batch` does not have exactly 3 channels (extra channels
            previously came back silently zeroed).

    NOTE(review): the means are in Caffe's BGR order; whether this exactly
    inverts src's normalize_batch (e.g. any channel swap) cannot be confirmed
    from here.
    '''
    vgg_means = [103.939, 116.779, 123.68]
    if batch.dim() != 4 or batch.size(1) != len(vgg_means):
        raise ValueError('denormalize_batch expects an (N, 3, H, W) tensor, got {}'.format(tuple(batch.size())))
    # Broadcast the means over N, H and W instead of copying channel by channel.
    means = torch.tensor(vgg_means, device=batch.device).view(1, len(vgg_means), 1, 1)
    return batch.to(torch.get_default_dtype()) + means

In [None]:
def plot_img(ten):
    '''
    Displays the first image of an (N, C, H, W) tensor whose pixels are on a 0-255 scale.

    The tensor is detached and moved to the CPU first, so tensors that require
    grad or live on a GPU can be passed directly.
    '''
    img = ten.detach().cpu().numpy()[0].transpose(1, 2, 0)
    # Clip before the uint8 cast: out-of-range floats would otherwise wrap around.
    img = np.clip(img, 0, 255).astype('uint8')
    plt.figure(figsize=(10, 5))
    plt.imshow(img)
    plt.show()

In [None]:
def preprocess(img):
    '''
    Turns an (H, W, C) image array into a (1, C, H, W) float32 tensor.

    The pixel data is copied first, so the tensor never aliases `img`. Arrays
    obtained from PIL via np.asarray are read-only, which torch.from_numpy
    does not accept cleanly.
    '''
    channels_first = img.transpose(2, 0, 1)
    batched = channels_first[np.newaxis]
    owned = np.copy(batched)
    return torch.from_numpy(owned).float()

In [None]:
# Display the feature-extractor layers of the VGG model held by Common.
# The integer indices passed to Common.forward_vgg further down presumably
# index into this sequence (TODO confirm against src.common).
Common.forward_vgg.features

In [None]:
def compute_gram(matrix):
    '''
    Computes the normalized Gram matrix of a single feature map.

    Args:
        matrix: feature tensor of shape (1, C, H, W). Non-contiguous tensors
            are accepted.

    Returns:
        (C, C) tensor F @ F.T / (C * H * W), where F is the feature map
        flattened to (C, H * W).

    Raises:
        ValueError: if the batch dimension is not 1. The result has no batch
            axis, so several images would otherwise have to be mixed together.
    '''
    batches, channels, height, width = matrix.size()
    if batches != 1:
        raise ValueError('compute_gram expects a batch of size 1, got {}'.format(batches))
    # reshape (unlike view) also works when the feature map is not contiguous.
    features = matrix.reshape(channels, height * width)
    return (1 / (channels * height * width)) * torch.mm(features, features.t())

In [None]:
def content_cost(input, target, layers=(26,)):
    '''
    Content loss: MSE between VGG activations of `input` and `target`.

    Both images go through normalize_batch (from src; presumably the VGG input
    preprocessing, TODO confirm) and then Common.forward_vgg, which is asked for
    the activations at `layers`. The index 35 used by style_cost only exists
    in a VGG-19 feature stack, so the network is presumably VGG-19, not VGG-16.

    Args:
        input: image tensor being optimised.
        target: content image tensor; treated as a constant (no gradient flows into it).
        layers: feature indices passed to Common.forward_vgg. Defaults to (26,),
            the original hard-coded choice.

    Returns:
        Scalar tensor: the sum over the requested layers of the per-layer MSE.
    '''
    input_layers = Common.forward_vgg(normalize_batch(input), list(layers))
    # The target is fixed, so don't record an autograd graph for it on every step.
    with torch.no_grad():
        target_layers = Common.forward_vgg(normalize_batch(target), list(layers))

    accumulated_loss = 0
    for input_feat, target_feat in zip(input_layers, target_layers):
        # MSE averages over all elements, so no reshape is needed first.
        accumulated_loss = accumulated_loss + mse(input_feat, target_feat)

    return accumulated_loss

In [None]:
def style_cost(input, target, layers=(3, 8, 17, 26, 35), layer_weights=(0.2, 0.2, 0.2, 0.5, 0.5)):
    '''
    Style loss: weighted MSE between Gram matrices of VGG activations.

    Both images go through normalize_batch (from src; presumably the VGG input
    preprocessing, TODO confirm) and Common.forward_vgg. Index 35 only exists in
    a VGG-19 feature stack, so the network is presumably VGG-19.

    Args:
        input: image tensor being optimised.
        target: style image tensor; treated as a constant (no gradient flows into it).
        layers: feature indices passed to Common.forward_vgg.
        layer_weights: one weight per entry of `layers`.

    Returns:
        Scalar tensor: sum_i layer_weights[i] * MSE(gram(input_i), gram(target_i)).

    Raises:
        ValueError: if `layers` and `layer_weights` differ in length.
    '''
    if len(layers) != len(layer_weights):
        raise ValueError('style_cost needs one weight per layer: {} layers, {} weights'.format(len(layers), len(layer_weights)))

    input_layers = Common.forward_vgg(normalize_batch(input), list(layers))
    # The style target is fixed: compute its Gram matrices without autograd tracking.
    with torch.no_grad():
        target_layers = Common.forward_vgg(normalize_batch(target), list(layers))
        target_grams = [compute_gram(feat) for feat in target_layers]

    accumulated_loss = 0
    for weight, input_feat, target_gram in zip(layer_weights, input_layers, target_grams):
        accumulated_loss = accumulated_loss + weight * mse(compute_gram(input_feat), target_gram)

    return accumulated_loss

In [None]:
def total_variation_cost(input):
    '''
    Anisotropic total-variation penalty for an (N, C, H, W) tensor.

    Returns the sum of absolute differences between horizontally adjacent
    pixels plus the same sum for vertically adjacent pixels.
    '''
    horizontal_diff = input[:, :, :, :-1] - input[:, :, :, 1:]
    vertical_diff = input[:, :, :-1, :] - input[:, :, 1:, :]
    return horizontal_diff.abs().sum() + vertical_diff.abs().sum()

In [None]:
def total_cost(input, targets, content_weight=1.0e5, style_weight=4e12, tv_weight=3e-5):
    '''
    Weighted sum of the content, style and total-variation losses.

    Each weighted term is also appended (detached) to Common.content_losses,
    Common.style_losses and Common.tv_losses, and printed.

    Args:
        input: image tensor being optimised.
        targets: pair (content_image, style_image).
        content_weight, style_weight, tv_weight: loss weights. The defaults are
            the values previously hard-coded here.

    Returns:
        Scalar tensor closs + sloss + tvloss, still attached to the graph for backward().
    '''
    # Extract content and style images
    content, style = targets

    # Get the content, style and total variation losses
    closs = content_cost(input, content) * content_weight
    sloss = style_cost(input, style) * style_weight
    tvloss = total_variation_cost(input) * tv_weight

    # Record detached copies. Appending the live tensors would keep each step's
    # autograd history reachable from these lists for the whole run.
    Common.content_losses.append(closs.detach())
    Common.style_losses.append(sloss.detach())
    Common.tv_losses.append(tvloss.detach())

    print('****************************')
    print('Content Loss: {}'.format(closs.item()))
    print('Style Loss: {}'.format(sloss.item()))
    print('Total Variation Loss: {}'.format(tvloss.item()))

    return closs + sloss + tvloss

In [None]:
def upsample(img, size=None, resample=0):
    '''
    Resizes the first image of an (N, C, H, W) numpy array and returns it as a
    (1, C, H', W') float tensor (via preprocess).

    Args:
        img: numpy array with pixels on a 0-255 scale; clipped to [0, 255]
            before the uint8 cast so out-of-range values cannot wrap around.
        size: target PIL size (width, height). Defaults to IMG_DIM, read at
            call time so edits to the config cell are picked up.
        resample: PIL resampling filter. The default 0 is nearest-neighbour
            (PIL's NEAREST), the original behaviour.
    '''
    if size is None:
        size = IMG_DIM
    hwc = np.clip(img[0].transpose(1, 2, 0), 0, 255).astype('uint8')
    resized = Image.fromarray(hwc).resize(size, resample=resample)
    return preprocess(np.asarray(resized))

In [None]:
def save_img(img, file_name, directory='generated_images'):
    '''
    Saves the first image of an (N, C, H, W) numpy array as a JPEG.

    Args:
        img: numpy array with pixels on a 0-255 scale; clipped to [0, 255]
            before the uint8 cast so out-of-range values cannot wrap around.
        file_name: file name without extension; '.jpg' is appended.
        directory: output folder, created if it does not exist yet.
    '''
    hwc = np.clip(img[0].transpose(1, 2, 0), 0, 255).astype('uint8')
    os.makedirs(directory, exist_ok=True)
    Image.fromarray(hwc).save(os.path.join(directory, file_name + '.jpg'))

In [None]:
class PrecomputedStyle(torch.nn.Module):
    '''
    Caches the Gram matrices of a fixed style image so they are computed only once.

    Args:
        style: style image tensor, passed through normalize_batch and Common.forward_vgg.
        layers: feature indices requested from Common.forward_vgg. The default
            (3, 8, 15, 22) is the original choice.
            NOTE(review): this differs from the indices style_cost uses
            (3, 8, 17, 26, 35); confirm which set is intended.
    '''
    def __init__(self, style, layers=(3, 8, 15, 22)):
        super(PrecomputedStyle, self).__init__()
        # The style image is a constant target. Skipping autograd tracking avoids
        # keeping a graph alive through self.vgg and the cached Gram matrices.
        with torch.no_grad():
            style = normalize_batch(style)
            self.vgg = Common.forward_vgg(style, list(layers))
            self.precomputed = [compute_gram(x) for x in self.vgg]

    def forward(self):
        '''Returns fresh copies of the cached Gram matrices, one per layer.'''
        return [torch.clone(gram) for gram in self.precomputed]

In [None]:
# Reshape the images
low_style_img_ten = preprocess(low_style_img)
low_input_img_ten = preprocess(low_input_img)
low_content_img_ten = preprocess(low_input_img)

In [None]:
# Reshape the images
style_img_ten = preprocess(style_img)
input_img_ten = preprocess(input_img)
content_img_ten = preprocess(input_img)

In [None]:
# Make sure the sizes are right
print(style_img_ten.size())
print(content_img_ten.size())
print(input_img_ten.size())

In [None]:
# Disabled alternative: start the low-resolution optimisation from a uniform
# image (every pixel 130, shape (1, 3, 534, 400)) instead of the content photo.
#low_input_img_ten = torch.ones(3, 534, 400).mul(130).unsqueeze(0)

In [None]:
# Make sure the sizes are right
print(low_style_img_ten.size())
print(low_content_img_ten.size())
print(low_input_img_ten.size())

In [None]:
low_input_img_ten.requires_grad_(True)  # the pixels themselves are the parameters being optimised

In [None]:
def train(init_img, content_img, style_img, opt, num_epochs=5, batches_per_epoch=100, plot_every=20):
    '''
    Optimises the pixels of `init_img` in place to minimise total_cost.

    Args:
        init_img: leaf tensor with requires_grad=True that `opt` updates,
            e.g. opt = optim.Adam([init_img], ...).
        content_img: content target tensor passed to total_cost.
        style_img: style target tensor passed to total_cost.
        opt: optimiser over init_img.
        num_epochs: number of outer loops (default 5).
        batches_per_epoch: optimiser steps per epoch (default 100).
        plot_every: show the current image every this many steps within an
            epoch (default 20); 0 or None disables plotting.
    '''
    for epoch in range(num_epochs):
        for batch in range(batches_per_epoch):
            opt.zero_grad()
            loss = total_cost(init_img, [content_img, style_img])
            loss.backward()
            opt.step()

            # Keep the pixels on the 0-255 scale. The clamp runs under no_grad so
            # it modifies the leaf without being recorded by autograd.
            with torch.no_grad():
                init_img.clamp_(0, 255)

            if plot_every and (batch + 1) % plot_every == 0:
                plot_img(init_img)

            print("Epoch: {} Training Batch: {}".format(epoch + 1, batch + 1), "Loss: {:f}".format(loss.item()))
            print('****************************')

In [None]:
# First (coarse) pass: optimise the low-resolution image directly.
low_res_opt = optim.Adam(params=[low_input_img_ten], lr=2.0)
train(low_input_img_ten, low_content_img_ten, low_style_img_ten, low_res_opt)

In [None]:
save_img(file_name='mya_anime_low_res', img=low_input_img_ten.detach().numpy())  # keep the coarse result

In [None]:
# Seed the fine pass with the coarse result upsampled to IMG_DIM
# (upsample already returns a float tensor).
init_img = upsample(low_input_img_ten.detach().cpu().numpy())

In [None]:
init_img.requires_grad_(True)  # optimise the upsampled pixels directly

In [None]:
# Second (fine) pass at full resolution, starting from the upsampled coarse result.
high_res_opt = optim.Adam(params=[init_img], lr=2.0)
train(init_img, content_img_ten, style_img_ten, high_res_opt)

In [None]:
save_img(file_name='mya_anime_high_res', img=init_img.detach().numpy())  # final full-resolution result

In [None]:
plot_img(ten=low_style_img_ten)  # low-resolution style reference, for comparison with the results