In [None]:
import torch
from torchvision import models, transforms
import torch.optim as optimizers
from PIL import Image
import numpy
import matplotlib.pyplot as pyplot
from IPython import display
import numpy as np

In [None]:
# VGG-19 consists of two parts: (1) convolution + pooling layers, (2) the classifier head.
# `.features` restricts the model to the first part (convolutions and pooling only).
vgg_model = models.vgg19(weights='DEFAULT').features

# Freeze the network weights: only the target image is optimized, never the model.
for parameter in vgg_model.parameters():
    parameter.requires_grad_(False)

# Preferred GPU on the original machine. Fall back to device 0 when this index
# does not exist, and to the CPU when CUDA is unavailable, instead of failing
# later in `.to()` with an invalid device ordinal.
PREFERRED_CUDA_INDEX = 5
if torch.cuda.is_available():
    if PREFERRED_CUDA_INDEX < torch.cuda.device_count():
        cuda_index = PREFERRED_CUDA_INDEX
    else:
        cuda_index = 0
    torch_device = torch.device(f"cuda:{cuda_index}")
else:
    torch_device = torch.device("cpu")

vgg_model.to(torch_device)

def load_image(path_to_image, shape=None, max_size=500):
    """Load an image file as a normalized tensor with a leading batch axis.

    Args:
        path_to_image: Path (or file object) accepted by ``PIL.Image.open``.
        shape: Optional target size handed to ``transforms.Resize`` unchanged
            (an int or a ``(height, width)`` sequence). When given it
            overrides the automatic sizing.
        max_size: Upper bound for the longer image side when ``shape`` is not
            given. Bigger pictures slow down the optimization.

    Returns:
        A float tensor of shape ``(1, 3, height, width)``.
    """
    # Image.open keeps the file handle open; convert() loads the pixel data
    # into a new image, so the handle can be released right away.
    with Image.open(path_to_image) as image_file:
        loaded_image = image_file.convert("RGB")

    if shape is not None:
        size = shape
    else:
        # PIL reports (width, height) while Resize expects (height, width).
        width, height = loaded_image.size
        longest_side = max(width, height)
        if longest_side > max_size:
            # Shrink while keeping the aspect ratio instead of squashing the
            # picture into a square.
            scale = max_size / longest_side
            height = max(1, round(height * scale))
            width = max(1, round(width * scale))
        size = (height, width)

    # NOTE(review): these per-channel statistics differ slightly from the
    # usual ImageNet values (0.485, 0.456, 0.406) / (0.229, 0.224, 0.225);
    # they are kept because display_image undoes exactly these numbers.
    image_transformations = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize((0.495, 0.455, 0.405),
                             (0.255, 0.220, 0.230))
    ])

    # convert("RGB") already guarantees three channels; only the batch axis
    # has to be added.
    return image_transformations(loaded_image).unsqueeze(0)

In [None]:
# Load the content image and send it to the compute device.
original_image = load_image('uploads/contentimage.jpg').to(torch_device)

# The style image is resized to exactly the same height/width as the content image.
content_height, content_width = original_image.shape[-2:]
style_image = load_image(
    "uploads/styleimage.jpg",
    shape=(content_height, content_width),
).to(torch_device)

# Visualize an image tensor after the loading transformation.
def display_image(tensor_image):
    """Convert a normalized image tensor into an array that imshow can draw.

    Args:
        tensor_image: Tensor of shape ``(1, 3, height, width)`` (a tensor
            without the batch axis, ``(3, height, width)``, also works). It
            may live on any device and may require grad.

    Returns:
        A numpy array of shape ``(height, width, 3)`` with values clipped to
        the range [0, 1].
    """
    # NOTE(review): must mirror the statistics used by load_image's Normalize.
    channel_std = numpy.array((0.255, 0.220, 0.230))
    channel_mean = numpy.array((0.495, 0.455, 0.405))

    # Drop only the batch axis. A bare squeeze() would also remove a height
    # or width of 1 and break the transpose below.
    channels_first = tensor_image.detach().to("cpu").squeeze(0).numpy()
    channels_last = channels_first.transpose(1, 2, 0)

    # Undo the normalization, then clamp to the displayable range.
    # The arithmetic allocates a new array, so the tensor is never modified.
    return (channels_last * channel_std + channel_mean).clip(0, 1)

In [None]:
# Extract the key feature maps of an image.
def extract_features(image, model, layers = None):
    """Run ``image`` through ``model`` and collect selected intermediate outputs.

    Args:
        image: Input tensor fed to the first child module of ``model``.
        model: Module whose direct children are applied one after another in
            registration order (e.g. the ``.features`` stack of VGG-19).
        layers: Optional mapping from child-module name (for a sequential
            stack: the index as a string) to the key under which that
            module's output is stored. Defaults to the VGG-19 convolution
            layers used for style transfer.

    Returns:
        dict mapping each requested key to the output tensor of its layer.
    """
    if layers is None:
        # Positions inside torchvision's VGG-19 `.features`; conv4_2 sits at
        # index 21 (index 25 is conv4_4).
        layers = {
            "0": "conv1_1",
            "5": "conv2_1",
            "10": "conv3_1",
            "19": "conv4_1",
            "21": "conv4_2",
            "28": "conv5_1"
        }

    activation = image
    features = {}

    # _modules is iterated directly (rather than named_children()) so that a
    # module instance registered more than once is still applied every time.
    for name, layer in model._modules.items():
        activation = layer(activation)

        if name in layers:
            features[layers[name]] = activation

        # Nothing after the last requested layer is needed; skip the rest.
        if len(features) == len(layers):
            break

    return features

def calculate_gram_matrix(tensor_image):
    """Compute the Gram matrix of a single feature map.

    Args:
        tensor_image: 4-D tensor of shape ``(1, depth, height, width)``.

    Returns:
        Tensor of shape ``(depth, depth)``: the product of the flattened
        feature map (one row per channel) with its own transpose.

    Raises:
        ValueError: if the batch dimension is not 1. Flattening a larger
            batch to ``(depth, -1)`` would mix samples and channels and
            silently produce a meaningless matrix.
    """
    batch_size, depth, _, _ = tensor_image.size()
    if batch_size != 1:
        raise ValueError(
            "calculate_gram_matrix expects a batch of exactly one feature map, "
            f"got batch size {batch_size}"
        )

    # reshape (unlike view) also accepts non-contiguous inputs.
    flattened = tensor_image.reshape(depth, -1)

    return torch.mm(flattened, flattened.t())

def compute_total_variation_loss(Y_hat):
    """Return the total-variation loss of a 4-D tensor.

    The loss is the average of two terms: the mean absolute difference
    between neighbours along dim 2 and the mean absolute difference between
    neighbours along dim 3 (height and width for an NCHW image).
    """
    step_along_dim2 = Y_hat[:, :, 1:, :] - Y_hat[:, :, :-1, :]
    step_along_dim3 = Y_hat[:, :, :, 1:] - Y_hat[:, :, :, :-1]

    mean_abs_dim2 = torch.abs(step_along_dim2).mean()
    mean_abs_dim3 = torch.abs(step_along_dim3).mean()

    return 0.5 * (mean_abs_dim2 + mean_abs_dim3)

In [None]:
# Feature maps of the content and style images. Neither image changes during
# the optimization, so they are computed once up front.
original_features = extract_features(original_image, vgg_model)
style_features = extract_features(style_image, vgg_model)

# One Gram matrix per extracted style layer.
style_gram_matrices = {
    layer_name: calculate_gram_matrix(feature_map)
    for layer_name, feature_map in style_features.items()
}

# The image being optimized starts as a copy of the content image. Gradient
# tracking is enabled last, after the device move, so the result is always a
# leaf tensor; the optimizer rejects non-leaf tensors.
altered_image = original_image.clone().detach().to(torch_device).requires_grad_(True)

In [None]:
# Contribution of each style layer to the style loss; the weights decrease
# from the first listed layer (1.0) to the last (0.1).
style_weights = dict(
    conv1_1=1.,
    conv2_1=0.7,
    conv3_1=0.4,
    conv4_1=0.2,
    conv5_1=0.1,
)

# Adam updates the pixels of the target image directly (default settings).
optimizer = optimizers.Adam([altered_image])

# Relative weights of the individual loss terms.
original_image_weight = 1
style_image_weight = 10000
total_variation_weight = 10

In [None]:
NUMBER_OF_EPOCHS = 10000 #can change this

# Per-epoch loss history for later plotting.
total_loss_for_plt = []    # weighted sum of all loss terms
style_loss_for_plt = []    # unweighted style loss
content_loss_for_plt = []  # unweighted content loss

for index in range(1, NUMBER_OF_EPOCHS+1):
    # The target image changes every step, so its features are re-extracted each epoch.
    altered_features = extract_features(altered_image, vgg_model)

    # Content loss: mean squared error between the conv4_2 feature maps.
    original_image_loss = torch.mean((altered_features["conv4_2"] - original_features["conv4_2"]) ** 2)

    # Style loss: weighted mean squared error between Gram matrices, per layer.
    style_loss = 0
    for layer in style_weights:
        altered_feature = altered_features[layer]

        altered_gram_matrix = calculate_gram_matrix(altered_feature)

        style_matrix = style_gram_matrices[layer]

        layer_style_loss = torch.mean((altered_gram_matrix - style_matrix) ** 2) * style_weights[layer]

        _, depth, height, width = altered_feature.shape

        # Normalize by the feature-map size before summing over layers.
        style_loss += layer_style_loss / (depth * height * width)

    # Total-variation loss is computed on the generated image itself, once
    # per epoch, and must be part of the objective to have any effect.
    total_variation_loss = compute_total_variation_loss(altered_image) * total_variation_weight

    total_loss = (original_image_loss * original_image_weight
                  + style_loss * style_image_weight
                  + total_variation_loss)

    optimizer.zero_grad()

    total_loss.backward()

    optimizer.step()

    # float() works both for 0-dim tensors and for the plain 0 that
    # style_loss stays at when style_weights is empty.
    total_loss_for_plt.append(float(total_loss))
    style_loss_for_plt.append(float(style_loss))
    content_loss_for_plt.append(float(original_image_loss))

    if index % 1000 == 0:
        pyplot.imshow(display_image(altered_image))
        # The final picture is saved without a title.
        if index != NUMBER_OF_EPOCHS:
            pyplot.title("Epoche: " + str(index))
        pyplot.axis('off')
        pyplot.savefig("Epochenbild.png")
        if index == NUMBER_OF_EPOCHS:
            pyplot.savefig("Ausgabebild.png")
        pyplot.show()
        # Release the figure so that snapshots do not pile up in memory.
        pyplot.close()