In [1]:
import torch
import torch.nn as nn

from torchvision import transforms
from torchvision import models

from PIL import Image
import numpy as np

# **Style Transfer**

- **Task**: combines the content of one image with the style of another image using CNN.
- **Method Overview**: Given a content image and a style image, the goal is to generate a target image that minimizes the content difference with the content image and the stye difference with the style image 
- *Content Loss*: minimize the MSE between feature maps from the content image and style image
- *Style Loss*: forward propagate the style image and the target image to the VGGNet, extract the feature maps; To generate a texture that matches the style of the style image, we update the target image by minimizing the MSE between the `Gram` matrxi of the style image and the target image

### **Device**

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

### **Image Preprocessing**

In [7]:
# VGGNet was trained on ImageNet where images are normalized by mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225].
# We use the same normalization statistics here.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406),
                         std=(0.229, 0.224, 0.225))
])

### **Load Content & Style Image**

In [5]:
def load_image(image_path, transforms=None, max_size=None, shape=None):
    image = Image.open(image_path)
    if max_size:
        scale = max_size / max(image.size)
        size = np.array(image.size) * scale
        image = image.reszie(size.astype(int), Image.ANTIALIAS)
        
    if shape:
        image = image.resize(shape, Image.LANCZOS)
        
    if transform:
        image = transforms(image).unsqueeze(0)
        
    return image.to(device)

### **Model**

In [6]:
class VGGNet(nn.Module):
    def __init__(self):
        super(VGGNet, self).__init__()
        self.select = ['0', '5', '10', '19', '28']
        self.vgg = models.vgg19(pretrained=True).features
        
    def forward(self, x):
        features = []
        for name, layer in self.vgg._modules.items():
            x = layer(x)
            if name in self.select:
                features.append(x)
        return features

In [None]:
vgg = VGGNet().to(device).eval()

### **Hyper-parameter**

In [None]:
content = '' # content image
style = '' # style image
max_size = 400 #
total_step = 2000
log_step = 10
sample_step = 500
style_weight = 100.0
lr = 0.003

### **Optimizer**

In [None]:
optimizer = torch.optim.Adam([target], lr=config.lr. betas=[0.5, 0.999])

### **Training**

In [8]:
for step in range(config.total_step):
    # Extract multiple conv feature vectors
    target_features = vgg(target)
    content_features = vgg(content)
    style_features = vgg(style)
    
    content_loss = 0
    style_loss = 0
    for f1, f2, f3 in zip(target_features, content_features, style_features):
        # Compute content loss with target and content images
        content_loss += torch.mean((f1 - f2)**2) # MSE
        
        # Reshape conv feature maps
        _, c, h, w = f1.size()
        f1 = f1.view(c, h * w)
        f3 = f3.view(c, h * w)
        
        # Compute style loss with target and style images
        style_loss += torch.mean((f1 - f3)**2) / (c * h * w)
        
    # Compute total loss
    loss = content_loss + config.style_weight * style_loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.backward()
    optimizer.step()
    
    if (step + 1) % config.log_step == 0:
        print ('Step [{}/{}], Content Loss: {:.4f}, Style Loss: {:.4f}' 
                   .format(step+1, config.total_step, content_loss.item(), style_loss.item()))
    
    if (step+1) % config.sample_step == 0:
            # Save the generated image
            denorm = transforms.Normalize((-2.12, -2.04, -1.80), (4.37, 4.46, 4.44))
            img = target.clone().squeeze()
            img = denorm(img).clamp_(0, 1)
            torchvision.utils.save_image(img, 'output-{}.png'.format(step+1))

NameError: name 'config' is not defined

In [None]:
parser = argparse.ArgumentParser()
parser.add_argument('--content', type=str, default='png/content.png')
parser.add_argument('--style', type=str, default='png/style.png')
parser.add_argument('--max_size', type=int, default=400)
parser.add_argument('--total_step', type=int, default=2000)
parser.add_argument('--log_step', type=int, default=10)
parser.add_argument('--sample_step', type=int, default=500)
parser.add_argument('--style_weight', type=float, default=100)
parser.add_argument('--lr', type=float, default=0.003)
config = parser.parse_args()
print(config)
main(config)