In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torchvision.models import vgg19, VGG19_Weights
from PIL import Image
import copy

In [2]:
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

cpu


In [3]:
def load_image(img_path, max_size=400, shape=None):
    """Load an image from disk as a normalized, batched tensor on ``device``.

    Args:
        img_path: Path to the image file; it is converted to 3-channel RGB.
        max_size: Upper bound, in pixels, on the longer image edge. Larger
            images are downscaled with their aspect ratio preserved; smaller
            images keep their native size (they are not upscaled).
        shape: Optional explicit output size, passed unchanged to
            ``transforms.Resize`` -- e.g. an ``(H, W)`` pair such as
            ``content.shape[-2:]``. When given (truthy), ``max_size`` is ignored.

    Returns:
        A ``(1, 3, H, W)`` float tensor, normalized with the per-channel
        mean/std below (the standard ImageNet statistics), on ``device``.
    """
    # The context manager closes the file handle. convert() returns a new,
    # fully loaded image, so it stays usable after the file is closed.
    with Image.open(img_path) as img:
        image = img.convert('RGB')

    if shape:
        size = shape
    else:
        # transforms.Resize(int) matches the *shorter* edge to that int. So
        # passing max_size directly would let the longer edge exceed it, and
        # would upscale small images. Compute an explicit (H, W) instead, so
        # the longer edge is at most max_size.
        width, height = image.size
        scale = min(1.0, max_size / max(width, height))
        size = (max(1, round(height * scale)), max(1, round(width * scale)))

    transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    # Add a leading batch dimension: (3, H, W) -> (1, 3, H, W).
    tensor = transform(image).unsqueeze(0)
    return tensor.to(device)

In [4]:
# Content image, sized by load_image's default max_size.
content = load_image(img_path='./img/input_pic.jpg')
# Style image, resized to exactly the content tensor's last two dimensions
# (H, W) so both images have the same spatial size.
style = load_image(img_path='./img/style_pic.jpg',
                   shape=content.shape[-2:])

In [5]:
# Pretrained VGG19 convolutional stack (only the `.features` part), moved to
# `device`, put in eval mode, and frozen so no parameter receives gradients.
weightsVGG19 = VGG19_Weights.DEFAULT
vgg = (
    vgg19(weights=weightsVGG19)
    .features
    .to(device)
    .eval()
    .requires_grad_(False)  # same as setting requires_grad=False on each parameter
)

In [6]:
# Names of the feature layers to extract for the content image and for the
# style image (they must be names known to get_features' mapping).
content_layers = ['conv4_1']
style_layers = [f'conv{block}_1' for block in range(1, 6)]

In [23]:
def get_features(image, mode, layers, model=None):
    """Run ``image`` through a sequential feature extractor and collect activations.

    Args:
        image: Input tensor fed to the first child module (e.g. the batched
            tensor returned by ``load_image``).
        mode: Unused; kept only so existing positional calls keep working.
        layers: Iterable of layer names to collect (see ``mapping`` below for
            the recognized names). A single name may be passed as a string.
            Names not present in ``mapping`` are silently ignored.
        model: Module whose children are applied in order. Defaults to the
            module-level ``vgg`` feature stack.

    Returns:
        Dict mapping each requested, recognized layer name to the tensor
        produced by the corresponding child module.
    """
    if model is None:
        model = vgg
    if isinstance(layers, str):
        # A bare name is a one-element selection, not a string to run a
        # substring test against.
        layers = [layers]
    wanted = set(layers)

    # Child index within the VGG19 `features` Sequential -> layer name.
    mapping = {
        '0': 'conv1_1',
        '5': 'conv2_1',
        '10': 'conv3_1',
        '19': 'conv4_1',
        '21': 'conv4_2',
        '28': 'conv5_1',
    }

    features = {}
    x = image
    for name, layer in model.named_children():
        x = layer(x)
        layer_name = mapping.get(name)
        if layer_name is not None and layer_name in wanted:
            # NOTE(review): this stores the child's output tensor itself. If
            # the next child is an in-place op (torchvision's VGG builds its
            # ReLUs with inplace=True -- confirm for your version), the stored
            # tensor is overwritten with that op's result. The value is then
            # effectively the post-activation output.
            features[layer_name] = x
    return features
        

In [26]:
def gram_matrix(tensor):
    """Compute the normalized Gram matrix of a batch of feature maps.

    Args:
        tensor: Feature maps of shape ``(B, C, H, W)``.

    Returns:
        For ``B == 1``, a ``(C, C)`` tensor (the original single-image result).
        For larger batches, a ``(B, C, C)`` tensor with one Gram matrix per
        sample. Every entry is divided by ``C * H * W``.
    """
    b, c, h, w = tensor.size()
    # reshape (not view) so non-contiguous inputs, e.g. permuted or
    # channels_last tensors, work instead of raising.
    flat = tensor.reshape(b, c, h * w)
    gram = torch.bmm(flat, flat.transpose(1, 2)) / (c * h * w)
    return gram[0] if b == 1 else gram

In [None]:
# Optimization variable: a copy of the content image whose pixels are updated
# directly. `requires_grad` is a bool attribute, so the in-place setter
# `requires_grad_()` is needed. It is applied after `.to(device)` so the
# result is a leaf tensor: moving a grad-requiring tensor to another device
# would produce a non-leaf, which optimizers reject.
target = content.clone().to(device).requires_grad_(True)