In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

from torchvision import models
from torchvision.transforms import v2

import numpy as np
from PIL import Image

In [24]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

### Images preparation stage

In [216]:
image1 = Image.open("image.jpeg").convert('RGB') # A puppy picture # RGB as we send 3 channels to CNN model
image2 = Image.open("vahn_gogh.png").convert('RGB') # one of the great Vahn Gogh's pictures

v2_transforms = v2.Compose([v2.Resize(256),
                            v2.CenterCrop(224),
                            v2.ToImage(), 
                            v2.ToDtype(torch.float32, scale=True),
                            v2.Lambda(lambda img: img.unsqueeze(0))])

image1, image2 = [v2_transforms(img).to(device) for img in (image1, image2)] # tranforming both images into tensors shaped like 1-224-224 where 1 is for batch and next two are for sizes

image1_clone = image1.clone().requires_grad_(True).to(device) # creating image1 clone for following styling

### Custom VGG19 model preparation stage

In [219]:
class MyCustomVGG(nn.Module):
    def __init__(self, layers: list):
        super().__init__()
        self.vgg19 = models.vgg19(weights=models.VGG19_Weights.DEFAULT)
        self.features = self.vgg19.features
        self.features.requires_grad_(False)
        self.requires_grad_(False)
        self.features.eval()  # так как в классе есть dropout, надо включать модель
        self.layers = layers
        
    def forward(self, x):
        results = []
        for ind, layer in enumerate(self.features):
            x = layer(x)
            if ind in self.layers:
                results.append(x.squeeze())
        return results

model = MyCustomVGG([0, 5, 10, 19, 28, 34]).to(device)

In [220]:
image1.size()

torch.Size([1, 3, 224, 224])

### Supporting functions

In [224]:
def content_loss(orig_img, orig_img_clone):
    return torch.mean((orig_img - orig_img_clone) ** 2)

# def gram_matrix(tensor: torch.Tensor):
#     if tensor.dim() == 4:
#         tensor = tensor.squeeze(0)  # если есть размер батча

#     C, H, W = tensor.size()
#     features = tensor.view(C, H * W)
#     gram = torch.mm(features, features.t())

#     return gram / (C * H * W)

def gram_matrix(tensor):
    filters = tensor.size(dim=0)
    g = tensor.view(filters, -1)
    gram = torch.mm(g, g.mT) / g.size(dim=1)

    return gram

def style_loss(clone_style_tensors, image2_grams):
    weights = [1, 0.8, 0.5, 0.3, 0.1]

    loss = 0
    i = 0
    for base, target in zip(clone_style_tensors, image2_grams):
        gram = gram_matrix(base)
        loss += weights[i] * torch.mean((gram - target) ** 2)
        i += 1
    return loss
    

In [226]:
image1_vggcustomed = model(image1) # image1 and image2 processed by the model
image2_vggcustomed = model(image2)

image2_gram = [gram_matrix(x) for x in image2_vggcustomed[:5]]

content_weight = 1
style_weight = 1000
best_loss = -1
epochs = 100
best_img = image1_clone.clone()

optimizer = optim.Adam(params=[image1_clone], lr=0.1)

In [228]:
for epoch in range(epochs):
    results_image1_clone = model(image1_clone)

    cl = content_loss(results_image1_clone[-1], image1_vggcustomed[-1])
    sl = style_loss(results_image1_clone, image2_gram)
    loss = (content_weight * cl) + (style_weight * sl)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    image1_clone.data.clamp_(0, 1)

    if loss < best_loss:
        best_loss = loss
        best_img = image1_clone.clone()

    print(f"Iteration: {epoch}, loss: {loss.item():.4f}")

Iteration: 0, loss: 578.6679
Iteration: 1, loss: 544.0251
Iteration: 2, loss: 378.4812
Iteration: 3, loss: 585.8953
Iteration: 4, loss: 679.1354
Iteration: 5, loss: 533.3723
Iteration: 6, loss: 558.8918
Iteration: 7, loss: 493.2389
Iteration: 8, loss: 458.6534
Iteration: 9, loss: 424.7810
Iteration: 10, loss: 388.7843
Iteration: 11, loss: 361.5681
Iteration: 12, loss: 329.0883
Iteration: 13, loss: 304.3185
Iteration: 14, loss: 280.7676
Iteration: 15, loss: 258.5783
Iteration: 16, loss: 239.1976
Iteration: 17, loss: 220.5872
Iteration: 18, loss: 202.9990
Iteration: 19, loss: 187.7311
Iteration: 20, loss: 174.0888
Iteration: 21, loss: 160.2138
Iteration: 22, loss: 148.2854
Iteration: 23, loss: 137.0153
Iteration: 24, loss: 126.6022
Iteration: 25, loss: 116.9087
Iteration: 26, loss: 108.0686
Iteration: 27, loss: 101.4310
Iteration: 28, loss: 101.4318
Iteration: 29, loss: 132.9100
Iteration: 30, loss: 85.3831
Iteration: 31, loss: 97.3374
Iteration: 32, loss: 112.0142
Iteration: 33, loss: 8

In [230]:
x = best_img.detach().squeeze()
low, hi = torch.amin(x), torch.amax(x)
x = (x - low) / (hi - low) * 255.0
x = x.permute(1, 2, 0)
x = x.cpu().numpy()
x = np.clip(x, 0, 255).astype('uint8')

In [232]:
image = Image.fromarray(x, 'RGB')
image.save('result.jpg')