In [None]:
from __future__ import division
from torchvision import models
from torchvision import transforms
from PIL import Image
import argparse
import torch
import torchvision
import torch.nn as nn
import numpy as np

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
def load_image(image_path, transform=None, max_size=None, shape=None):
    """Load an image from disk, optionally resize it, and optionally turn it into a batch tensor.

    Args:
        image_path: Path of the image file to open.
        transform: Optional callable applied to the PIL image (e.g. a torchvision
            ``Compose`` ending in ``ToTensor``). Its result gets a leading batch
            dimension via ``unsqueeze(0)`` and is moved to ``device``.
        max_size: If given, the image is rescaled so that its longer side equals
            ``max_size`` (aspect ratio preserved; smaller images are upscaled).
        shape: If given, the image is resized to exactly this PIL size,
            i.e. ``(width, height)``.

    Returns:
        A tensor on ``device`` when ``transform`` is given; otherwise the PIL image.
    """
    # Convert to RGB so grayscale / RGBA / palette files yield 3 channels, matching a
    # 3-channel Normalize. convert() also loads the pixels, so the file can be closed.
    with Image.open(image_path) as img:
        image = img.convert('RGB')

    if max_size:
        scale = max_size / max(image.size)
        # PIL wants a tuple of Python ints. LANCZOS replaces ANTIALIAS, which was
        # an alias of LANCZOS and was removed in Pillow 10.
        new_size = tuple(int(dim * scale) for dim in image.size)
        image = image.resize(new_size, Image.LANCZOS)

    if shape:
        image = image.resize(tuple(shape), Image.LANCZOS)

    if transform:
        tensor = transform(image).unsqueeze(0)
        return tensor.to(device)

    # A PIL image has no .to(); return it unchanged when no transform was applied.
    return image

In [None]:
class VGGNet(nn.Module):
    """Frozen, pretrained VGG-19 feature extractor returning activations at selected layers.

    Args:
        select: Names (string indices into ``vgg19().features``) of the layers whose
            outputs are collected. Defaults to ``('0', '5', '10', '19', '28')``, which in
            torchvision's VGG-19 are conv1_1, conv2_1, conv3_1, conv4_1 and conv5_1.
    """

    DEFAULT_SELECT = ('0', '5', '10', '19', '28')

    def __init__(self, select=DEFAULT_SELECT):
        super(VGGNet, self).__init__()
        self.select = list(select)
        self.vgg = self._load_vgg19_features()
        # Only the input image is optimized; freezing the weights avoids computing and
        # accumulating gradients for network parameters nobody uses.
        for param in self.vgg.parameters():
            param.requires_grad_(False)

    @staticmethod
    def _load_vgg19_features():
        """Return the ImageNet-pretrained ``vgg19().features`` stack across torchvision versions."""
        weights = getattr(models, 'VGG19_Weights', None)
        if weights is not None:
            # torchvision >= 0.13 deprecates ``pretrained=`` in favor of ``weights=``.
            return models.vgg19(weights=weights.DEFAULT).features
        return models.vgg19(pretrained=True).features

    def forward(self, x):
        """Run ``x`` through every layer of the VGG stack and return the selected activations.

        Each layer is applied in order, and the output of every layer whose name is in
        ``self.select`` is appended to the returned list (in network order).
        NOTE(review): torchvision's VGG uses in-place ReLUs, so a collected conv output is
        rectified by the following layer; verify for your torchvision version.
        """
        features = []
        for name, layer in self.vgg.named_children():
            # Every layer must run (ReLU, pooling, intermediate convs); only the
            # outputs of the selected ones are collected.
            x = layer(x)
            if name in self.select:
                features.append(x)
        return features

In [None]:
def main(content, max_size, style, lr, total_step, content_weight, style_weight, log_step, sample_step):
    """Run neural style transfer by optimizing the pixels of a target image.

    The target starts as a copy of the content image and is updated with Adam so that
    its VGG activations stay close to the content image's (content loss) while the Gram
    matrices of its activations approach the style image's (style loss).

    Args:
        content: Path to the content image.
        max_size: Length, in pixels, the content image's longer side is rescaled to.
        style: Path to the style image; it is resized to the content image's size.
        lr: Adam learning rate.
        total_step: Number of optimization steps.
        content_weight: Weight of the content loss in the total loss.
        style_weight: Weight of the style loss in the total loss.
        log_step: Print the losses every ``log_step`` steps (0 disables logging).
        sample_step: Save ``output-XXXX.png`` every ``sample_step`` steps (0 disables saving).
    """
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize(mean=mean, std=std)])
    # Exact inverse of the Normalize above: (x - mean) / std  ->  x * std + mean.
    denorm = transforms.Normalize(mean=[-m / s for m, s in zip(mean, std)],
                                  std=[1 / s for s in std])
    print("Transformer loaded...")

    content_img = load_image(content, transform, max_size=max_size)
    # PIL sizes are (width, height) while the tensor is (N, C, H, W): swap the axes so
    # the style image gets exactly the content image's dimensions.
    style_img = load_image(style, transform,
                           shape=(content_img.size(3), content_img.size(2)))
    print("Image loaded...")

    target = content_img.clone().requires_grad_(True)
    optimizer = torch.optim.Adam([target], lr=lr, betas=(0.5, 0.999))

    vgg = VGGNet().to(device).eval()
    print("Net loaded...")

    # The content and style images never change, so their features (and the style
    # Gram matrices) are computed once, outside the optimization loop.
    with torch.no_grad():
        content_features = vgg(content_img)
        style_grams = [_gram_matrix(f) for f in vgg(style_img)]

    for step in range(total_step):
        target_features = vgg(target)

        content_loss = 0
        style_loss = 0
        for f_target, f_content, g_style in zip(target_features, content_features, style_grams):
            content_loss += torch.mean((f_target - f_content) ** 2)

            _, c, h, w = f_target.size()
            g_target = _gram_matrix(f_target)
            # Normalize by the layer size so every layer contributes on a similar scale.
            style_loss += torch.mean((g_target - g_style) ** 2) / (c * h * w)

        loss = content_weight * content_loss + style_weight * style_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if log_step and (step + 1) % log_step == 0:
            print('Step [{}/{}], Content Loss: {:.4f}, Style Loss: {:.4f}'
                  .format(step + 1, total_step, content_loss.item(), style_loss.item()))

        if sample_step and (step + 1) % sample_step == 0:
            img = target.detach().clone().squeeze(0)
            img = denorm(img).clamp_(0, 1)
            torchvision.utils.save_image(img, 'output-{:04d}.png'.format(step + 1))


def _gram_matrix(feature):
    """Return the (c, c) Gram matrix of a single-image feature map of shape (1, c, h, w)."""
    _, c, h, w = feature.size()
    flat = feature.view(c, h * w)
    return torch.mm(flat, flat.t())

In [None]:
# Render stata.jpg in the style of udnie.jpg at 256 px, logging every 20 steps
# and saving an intermediate image every 100 steps.
main(
    content='stata.jpg',
    style='udnie.jpg',
    max_size=256,
    lr=1e-3,
    total_step=200,
    content_weight=1,
    style_weight=1e5,
    log_step=20,
    sample_step=100,
)

In [1]:
import imageio, os

# Assemble every .png frame in ./rex_shout (sorted by file name) into rex_shout/res.gif.

# imageio >= 2.16 deprecates the top-level imread in favor of imageio.v2.imread;
# fall back to the top-level module on versions that lack ``v2``.
_imread = getattr(imageio, 'v2', imageio).imread

root = os.getcwd()
pathr = os.path.join(root, "rex_shout")
files = os.listdir(pathr)
# Keep only regular files with a .png extension.
filelist = [f for f in files
            if os.path.isfile(os.path.join(pathr, f)) and os.path.splitext(f)[1] == ".png"]
# Frame names are presumably zero-padded (like main's output-{:04d}.png) so that
# lexicographic order matches frame order.
ff = sorted(filelist)
if not ff:
    raise FileNotFoundError("no .png frames found in {}".format(pathr))
images = [_imread(os.path.join(pathr, f)) for f in ff]
# NOTE(review): duration=0.2 means seconds per frame in imageio's legacy GIF writer;
# newer Pillow-backed writers may interpret it differently - confirm for the installed version.
imageio.mimsave(os.path.join(pathr, 'res.gif'), images, duration=0.2)