In [None]:
%matplotlib inline
import warnings
warnings.filterwarnings('ignore', message="The parameter 'pretrained' is deprecated")
warnings.filterwarnings('ignore', message="Arguments other than a weight enum or `None` for 'weights' are deprecated")

import os

import matplotlib.pyplot as plt
import scipy.misc
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from torch.autograd import Variable
from torchvision.utils import save_image


In [None]:
# Target length in pixels handed to transforms.Resize: the shorter side of every
# loaded image is scaled to this value (aspect ratio is kept).
imsize: int = 512

In [None]:
# PIL image -> float tensor (C, H, W) in [0, 1]; the shorter side is resized to
# `imsize` first, so non-square inputs stay non-square.
loader = transforms.Compose(
    [
        transforms.Resize(imsize),
        transforms.ToTensor(),
    ]
)

# Inverse direction: (C, H, W) tensor -> PIL image.
unloader = transforms.ToPILImage()

In [None]:
def loadImage(name):
    """Load an image file as a (1, 3, H, W) float tensor for the network.

    The image is converted to RGB first, so grayscale, palette and RGBA files
    also produce exactly 3 channels. VGG19's first convolution requires 3
    channels. It is then resized and converted by ``loader``, and a leading
    batch axis is added.

    Args:
        name: path of the image file.

    Returns:
        Tensor of shape (1, 3, H, W) with values in [0, 1].
    """
    # ``with`` closes the underlying file handle once the pixels are decoded.
    with Image.open(name) as img:
        image = loader(img.convert('RGB'))
    # ``Variable`` is a no-op wrapper in modern PyTorch; a plain tensor is returned.
    return image.unsqueeze(0)

def saveImage(input, path):
    """Write an image tensor to ``path``.

    Args:
        input: image tensor shaped (1, C, H, W) or (C, H, W), on any device and
            possibly requiring grad. Values are clamped to [0, 1] before the
            8-bit conversion, so out-of-range pixels do not wrap around.
            Only the first image of a batch is saved.
        path: destination file. PIL picks the format from the extension.
    """
    image = input.detach().cpu().clone()
    if image.dim() == 4:
        # ToPILImage only accepts (C, H, W) or (H, W), so drop the batch axis.
        # The previous ``view(batch, 3, imsize, imsize)`` also assumed a square
        # image, but ``Resize(imsize)`` keeps the aspect ratio.
        image = image[0]
    image = unloader(image.clamp(0, 1))
    # scipy.misc.imsave was removed from SciPy; the PIL image saves itself.
    image.save(path)
    
def im_convert(tensor):
    '''Convert an image tensor into a NumPy array that matplotlib can display.

    Args:
        tensor: image tensor shaped (1, C, H, W) or (C, H, W), on any device and
            possibly requiring grad. Only the first image of a batch is used.

    Returns:
        NumPy array of shape (H, W, C), or (H, W) when C == 1, with values
        clipped to [0, 1].
    '''
    image = tensor.to('cpu').clone().detach()
    if image.dim() == 4:
        # Drop only the batch axis. A bare ``squeeze()`` would also remove a
        # channel or spatial axis of size 1 and break the transpose below.
        image = image[0]
    image = image.numpy().transpose(1, 2, 0)
    if image.shape[2] == 1:
        # matplotlib's imshow accepts (H, W) but not (H, W, 1).
        image = image[:, :, 0]
    # ``loader`` applies no ImageNet normalisation, so nothing needs undoing
    # here; just keep the values in the displayable range.
    return image.clip(0, 1)

def showImages(content, style, output):
    """Display the content, style and result images side by side.

    Each tensor is converted with ``im_convert`` and drawn on its own titled,
    axis-free panel of a single 1x3 figure, which is then shown.
    """
    panels = (
        (content, 'Content Image'),
        (style, 'Style Image'),
        (output, 'Merged Image'),
    )
    fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(20, 10))
    for axis, (picture, caption) in zip(axes, panels):
        axis.imshow(im_convert(picture))
        axis.axis('off')
        axis.set_title(caption)
    plt.show()

In [None]:
class GramMatrix(nn.Module):
    """Normalised Gram matrix of a 4-D feature map.

    The input is unpacked as (N, C, H, W), presumably batch, channels, height
    and width. It is flattened to (N*C, H*W), multiplied by its own transpose,
    and divided by the total element count N*C*H*W.
    """

    def forward(self, input):
        batch, channels, height, width = input.size()
        flat = input.view(batch * channels, height * width)
        gram = flat @ flat.t()
        return gram / (batch * channels * height * width)

In [None]:
class CNN(object):
    """Gatys-style neural style transfer driven by a pretrained VGG19.

    The pixels of ``pastiche`` are optimised with L-BFGS so that its VGG19
    conv activations match ``content`` at ``content_layers``, and its Gram
    matrices match ``style`` at ``style_layers``.

    Args:
        style: style image tensor. Presumably (1, 3, H, W) as produced by
            ``loadImage``; TODO confirm for other callers.
        content: content image tensor with the same layout as ``style``.
        pastiche: starting image for the optimisation. When it is already on
            the target device, ``self.pastiche`` shares its storage.
    """

    def __init__(self, style, content, pastiche):
        super(CNN, self).__init__()

        self.use_cuda = torch.cuda.is_available()
        self.device = torch.device('cuda' if self.use_cuda else 'cpu')

        # Every image must sit on the same device as the network. Previously
        # only ``pastiche`` was moved (by the caller), so on a GPU the
        # content/style forward passes failed with a device mismatch.
        self.style = style.detach().to(self.device)
        self.content = content.detach().to(self.device)
        self.pastiche = nn.Parameter(pastiche.detach().to(self.device))

        self.content_layers = ['conv_4']
        self.style_layers = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']
        # Both weights scale the two MSE operands, so their effect on the loss
        # is quadratic: styleWeight=1000 multiplies the style term by 1e6.
        self.contentWeight = 1
        self.styleWeight = 1000

        # NOTE(review): images reach VGG19 without the ImageNet mean/std
        # normalisation that torchvision's pretrained weights expect. Adding it
        # would likely require retuning the weights above.
        self.loss_network = models.vgg19(pretrained=True).to(self.device).eval()
        # VGG19 is only a fixed feature extractor. Freezing it stops every
        # closure call from accumulating unused gradients into its weights.
        for param in self.loss_network.parameters():
            param.requires_grad_(False)

        self.gram = GramMatrix().to(self.device)
        self.loss = nn.MSELoss()
        self.optimizer = optim.LBFGS([self.pastiche])

    def _build_layers(self):
        """Return ``(name, layer)`` pairs of VGG19 ``features`` up to the deepest loss layer.

        Conv layers are named ``conv_<k>``, where ``k`` is 1 plus the number of
        ReLUs seen before them; every other layer gets ``None``. ReLUs are
        replaced by out-of-place copies so that activations used by the losses
        are not overwritten in place.
        """
        pending = set(self.content_layers) | set(self.style_layers)
        layers = []
        index = 1
        for layer in self.loss_network.features:
            name = None
            if isinstance(layer, nn.ReLU):
                layer = nn.ReLU(inplace=False)
                index += 1
            elif isinstance(layer, nn.Conv2d):
                name = 'conv_' + str(index)
            layers.append((name, layer))
            pending.discard(name)
            if not pending:
                # Nothing deeper contributes to the loss; skip the rest of VGG19.
                break
        return layers

    def _compute_targets(self, layers):
        """Return the constant loss targets as two dicts keyed by layer name.

        The first dict holds content activations and the second holds style
        Gram matrices.
        """
        content_targets, style_targets = {}, {}
        content, style = self.content, self.style
        with torch.no_grad():
            for name, layer in layers:
                content, style = layer(content), layer(style)
                if name in self.content_layers:
                    content_targets[name] = content
                if name in self.style_layers:
                    style_targets[name] = self.gram(style)
        return content_targets, style_targets

    def train(self):
        """Run one ``LBFGS.step`` on the pastiche and return ``self.pastiche``.

        ``LBFGS.step`` evaluates ``closure`` repeatedly (up to its ``max_iter``
        setting), so a single call performs several optimisation iterations.
        Layers and targets are rebuilt per call, so changes to the layer lists
        or images made after construction are still honoured.
        """
        layers = self._build_layers()
        content_targets, style_targets = self._compute_targets(layers)

        def closure():
            self.optimizer.zero_grad()

            # Clamp a copy through ``.data``. The forward pass sees a valid
            # [0, 1] image, while autograd treats the clamp as identity, so
            # gradients reach ``self.pastiche`` unchanged.
            pastiche = self.pastiche.clone()
            pastiche.data.clamp_(0, 1)

            content_loss = 0
            style_loss = 0
            for name, layer in layers:
                pastiche = layer(pastiche)

                if name in self.content_layers:
                    content_loss += self.loss(pastiche * self.contentWeight,
                                              content_targets[name] * self.contentWeight)

                if name in self.style_layers:
                    style_loss += self.loss(self.gram(pastiche) * self.styleWeight,
                                            style_targets[name] * self.styleWeight)

            total_loss = content_loss + style_loss
            total_loss.backward()
            return total_loss

        self.optimizer.step(closure)
        return self.pastiche

In [None]:
# Legacy tensor type passed to ``.type(...)``: selects the GPU float tensor
# type when CUDA is available, otherwise the CPU one.
if torch.cuda.is_available():
    dtype = torch.cuda.FloatTensor
else:
    dtype = torch.FloatTensor

In [None]:
STYLE_PATH = '../data/style/composition.jpg'
CONTENT_PATH = '../data/content/nalax.jpg'
# Start the optimisation from random pixels instead of the content image.
INIT_FROM_NOISE = False

style = loadImage(STYLE_PATH)
content = loadImage(CONTENT_PATH)

# Reuse the decoded content image instead of reading the file a second time.
# clone() comes first because ``.type(dtype)`` returns the same tensor when no
# conversion is needed (CPU). The optimiser may update the pastiche's storage
# in place, which would otherwise silently overwrite ``content``.
pastiche = content.clone().type(dtype)
if INIT_FROM_NOISE:
    # Uniform noise in [0, 1], so the starting point is already a valid image.
    pastiche = torch.rand_like(pastiche)

# Number of outer optimisation steps. Each one is a single ``CNN.train()`` call.
epochs = 30

In [None]:
fig, (ax1, ax2, ax3) = plt.subplots(nrows=1, ncols=3, figsize=(20, 10))
# Label each panel so the preview stands on its own when skimmed.
for ax, image, title in (
    (ax1, content, 'Content Image'),
    (ax2, style, 'Style Image'),
    (ax3, pastiche, 'Initial Pastiche'),
):
    ax.imshow(im_convert(image))
    ax.axis('off')
    ax.set_title(title)
# Render explicitly instead of leaving the last AxesImage repr as cell output.
plt.show()

In [None]:
output_dir = '../data/out/output'
# save_image does not create missing directories. Ensure the target exists so
# a long run does not fail at its very last step.
os.makedirs(output_dir, exist_ok=True)

cnn = CNN(style, content, pastiche)
for i in range(epochs):
    # One outer step is one ``optimizer.step`` call inside ``CNN.train``.
    pastiche = cnn.train()

    if i % 10 == 0:
        print("Iteration: %d" % (i))
        # Keep the parameter a valid image. NOTE: this mutates the optimised
        # tensor, so it influences the optimisation, not just the preview.
        with torch.no_grad():
            pastiche.clamp_(0, 1)
        plt.imshow(im_convert(pastiche))
        plt.show()

    # ``range(epochs)`` stops at ``epochs - 1``. The old ``i == epochs`` test
    # never fired, so the result was never written (and ``path`` was stale).
    if i == epochs - 1:
        path = os.path.join(output_dir, '%d.png' % (i))
        save_image(pastiche.detach().clamp(0, 1), path)

In [None]:
# Final side-by-side comparison of the inputs and the optimised pastiche.
showImages(content=content, style=style, output=pastiche)