In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torchvision.models import vgg19, VGG19_Weights
from PIL import Image
import copy

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [3]:
def load_image(img_path, max_size=400, shape=None):
    image = Image.open(img_path).convert('RGB')

    size = max_size if max(image.size) > max_size else max(image.size)
    if shape:
        size = shape

    transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
    ])

    image = transform(image).unsqueeze(0)
    return image.to(device)

In [7]:
def get_features(image, model, layers):
    features = {}
    x = image
    mapping = {
        '0': 'conv1_1',
        '5': 'conv2_1',
        '10': 'conv3_1',
        '19': 'conv4_1',
        '21': 'conv4_2',
        '28': 'conv5_1'
    }
    for name, layer in model._modules.items():
        x = layer(x)
        if name in mapping:
            layer_name = mapping[name]
            if layer_name in layers:
                features[layer_name] = x
    return features


In [8]:
def gram_matrix(tensor):
    _, c, h, w = tensor.size()
    tensor = tensor.view(c, h*w)
    gram = torch.mm(tensor, tensor.t())
    return gram / (c*h*w)

In [13]:
def img_convert(tensor):
    image = tensor.to('cpu').clone().detach()
    image = image.squeeze(0)
    image = image * torch.tensor([0.229, 0.224, 0.225]).view(3,1,1)
    image = image + torch.tensor([0.485, 0.456, 0.406]).view(3,1,1)
    image = image.clamp(0,1)
    return transforms.ToPILImage()(image)

In [31]:
content = load_image('./img/input_pic.jpg', max_size=800) # исходное изображение
style = load_image('./img/style_pic.jpg', shape=content.shape[-2:]) # стиль, который хотим взять

In [5]:
weightsVGG19 = VGG19_Weights.DEFAULT
vgg = vgg19(weights=weightsVGG19).features.to(device).eval()
for param in vgg.parameters():
    param.requires_grad_(False)

Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth


100%|██████████| 548M/548M [00:05<00:00, 115MB/s]


In [6]:
content_layers = ['conv4_1']
style_layers = ['conv1_1', 'conv2_1', 'conv3_1', 'conv4_1', 'conv5_1']

In [35]:
target = content.clone().requires_grad_(True).to(device) # изображение копия контента, которое оптимизируем

In [36]:
aplha = 1 # вес исходника
beta = 1e6 # вес стиля
optimizer = optim.Adam([target], lr=0.003)
epochs = 800

In [37]:
content_features = get_features(content, vgg, content_layers)
style_features = get_features(style, vgg, style_layers)
style_grams = {layer: gram_matrix(style_features[layer]) for layer in style_features}

In [38]:
for epoch in range(1, epochs+1):
    target_features = get_features(target, vgg, content_layers+style_layers)

    content_loss = 0
    for layer in content_layers:
        content_loss += torch.mean((target_features[layer] - content_features[layer])**2)

    style_loss = 0
    for layer in style_layers:
        target_gram = gram_matrix(target_features[layer])
        style_loss += torch.mean((target_gram - style_grams[layer])**2)

    total_loss = aplha * content_loss + beta * style_loss

    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    if epoch % 50 == 0:
        print(f'Epoch {epoch}/{epochs} | Loss: {total_loss.item():.4f}')


Epoch 50/800 | Loss: 109.0128
Epoch 100/800 | Loss: 63.3153
Epoch 150/800 | Loss: 51.3148
Epoch 200/800 | Loss: 45.5524
Epoch 250/800 | Loss: 41.7797
Epoch 300/800 | Loss: 38.8992
Epoch 350/800 | Loss: 36.4919
Epoch 400/800 | Loss: 34.3602
Epoch 450/800 | Loss: 32.3933
Epoch 500/800 | Loss: 30.5212
Epoch 550/800 | Loss: 28.7081
Epoch 600/800 | Loss: 26.9373
Epoch 650/800 | Loss: 25.2036
Epoch 700/800 | Loss: 23.5117
Epoch 750/800 | Loss: 21.8741
Epoch 800/800 | Loss: 20.3071


In [39]:
result = img_convert(target)
result.save('./img/result.jpg')
result.show()