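# Neural Style Transfer

This notebook optimises the pixels of a generated image so that its feature maps (taken from a VGG-based network) stay close to those of a content image, while the Gram matrices of those feature maps match the ones computed from a style image.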

In [11]:
from typing import Sequence, Tuple

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.models as models
from torchvision.utils import save_image
from PIL import Image

import vgg_neural_style_transfer as nst

#### Setting up device

In [12]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

#### Hyperparameters

#### Inspecting VGG19

In [13]:
# Load torchvision's VGG19 with pretrained weights and keep only its `features`
# sub-module, so the layer indices can be inspected in the printout below.
# NOTE(review): `model` is rebound to `nst.VGG_NST(...)` in the
# "Model and hyperparameters" cell, so this instance is only used for printing.
# NOTE(review): newer torchvision releases prefer a `weights=` argument over
# `pretrained=` — confirm against the installed version before changing it.
model = models.vgg19(pretrained=True).features

# The index shown in parentheses for each layer is what `feature_layers` refers to later.
print(model)

Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU(inplace=True)
  (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (3): ReLU(inplace=True)
  (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (6): ReLU(inplace=True)
  (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (8): ReLU(inplace=True)
  (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (11): ReLU(inplace=True)
  (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (13): ReLU(inplace=True)
  (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (15): ReLU(inplace=True)
  (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (17): ReLU(inplace=True)
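  (18): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  ... (remaining layers truncated in this export)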

#### Loss computation and training loop

In [14]:
def calculate_losses(
    original_features: Sequence[torch.Tensor],
    style_features: Sequence[torch.Tensor],
    generated_features: Sequence[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Accumulate the content and style losses over all feature layers.

    For every layer the content term is the mean squared error between the
    generated and the original feature map, and the style term is the mean
    squared error between the (unnormalised) Gram matrices of the generated
    and the style feature map. Gram matrices are computed per sample, so any
    batch size is supported.

    Args:
        original_features: per-layer feature maps of the content image, each
            shaped (batch, channels, height, width).
        style_features: per-layer feature maps of the style image; each must
            hold the same number of elements as the matching generated map.
        generated_features: per-layer feature maps of the image being optimised.

    Returns:
        A tuple ``(original_loss, style_loss)`` with both terms summed over
        the layers. If no layers are given, both values are the integer 0.
    """
    original_loss = 0
    style_loss = 0

    # iterating through each feature layer
    for original_feature, style_feature, generated_feature in zip(
        original_features, style_features, generated_features
    ):
        # shape of the feature map for this layer
        batch, channel, height, width = generated_feature.shape

        # flatten the spatial dimensions but keep the batch dimension, so that
        # one Gram matrix is built per sample instead of assuming batch == 1
        generated_flat = generated_feature.reshape(batch, channel, height * width)
        style_flat = style_feature.reshape(batch, channel, height * width)

        # gram matrices for generated and style: (batch, channel, channel)
        generated_gram_matrix = torch.bmm(generated_flat, generated_flat.transpose(1, 2))
        style_gram_matrix = torch.bmm(style_flat, style_flat.transpose(1, 2))

        # mse (out-of-place accumulation keeps the autograd graph free of in-place ops)
        original_loss = original_loss + torch.mean((generated_feature - original_feature) ** 2)
        style_loss = style_loss + torch.mean((generated_gram_matrix - style_gram_matrix) ** 2)

    return original_loss, style_loss


def train(model: nn.Module, original_image: torch.Tensor, style_image: torch.Tensor, generated_image: torch.Tensor, epochs: int, learning_rate: float, alpha: float, beta: float):
    """
    Optimise `generated_image` in place so that it combines the content of
    `original_image` with the style of `style_image`.

    Args:
        model: module that maps an image to the collection of per-layer
            feature maps consumed by `calculate_losses`.
        original_image: content image; never modified.
        style_image: style image; never modified.
        generated_image: tensor to optimise; must require gradients. It is
            updated in place by the optimizer.
        epochs: number of optimisation steps.
        learning_rate: learning rate for the Adam optimizer.
        alpha: weight of the content (original) loss.
        beta: weight of the style loss.

    Side effects:
        Every 500 steps the current loss is printed and the generated image is
        written to ``images/generated_step{step}.jpg``.
    """

    # setting up optimizer to optimize generated_image (the only trainable tensor)
    optimizer = optim.Adam([generated_image], lr=learning_rate)

    # The content and style images never change, so their features are computed
    # once, without autograd tracking, instead of on every step.
    # NOTE(review): assumes the model's forward pass is deterministic
    # (no dropout / batch-norm statistics updates) — confirm for the model used.
    with torch.no_grad():
        original_features = model(original_image)
        style_features = model(style_image)

    for step in range(epochs):
        # only the generated image's features depend on the current step
        generated_features = model(generated_image)

        # calculate losses
        original_loss, style_loss = calculate_losses(original_features, style_features, generated_features)

        # weighted total loss
        total_loss = alpha * original_loss + beta * style_loss

        # backpropagation and gradient descent
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        # printing loss and saving image
        if step % 500 == 0:
            # .item() prints the plain number rather than the tensor repr with grad_fn
            print(f"Step: {step} Loss: {total_loss.item()}")
            save_image(generated_image, f"images/generated_step{step}.jpg")

#### Model and hyperparameters

In [15]:
# Optimisation settings
image_size = 512        # images are resized to image_size x image_size
epochs = 6000           # number of optimisation steps
learning_rate = 0.001   # Adam learning rate
alpha, beta = 1, 0.05   # weights of the content loss and the style loss

# Indices (as strings) of the layers whose outputs are used as features.
# NOTE(review): presumably these index the VGG19 `features` Sequential printed
# above — confirm against vgg_neural_style_transfer.VGG_NST.
feature_layers = ['0', '5', '10', '19', '28']

# Feature extractor, placed on the selected device
model = nst.VGG_NST(feature_layers).to(device)

#### Image loading function

In [16]:
# Preprocessing shared by every image: resize to a square of side `image_size`,
# then convert to a tensor.
# NOTE(review): normalisation is left disabled, as in the original cell:
#   transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
image_transforms = transforms.Compose(
    [
        transforms.Resize(size=(image_size, image_size)),
        transforms.ToTensor(),
    ]
)

def load_image(image_name, image_transform):
    """
    Load an image file as a batched tensor on the active device.

    Args:
        image_name: path (or file object) of the image to open.
        image_transform: callable applied to the PIL image; it must return a
            tensor, to which a leading batch dimension is then added.

    Returns:
        The transformed image with a batch dimension of size 1, moved to `device`.
    """
    # The context manager closes the underlying file once the pixels are read.
    # convert("RGB") forces three channels, so grayscale, palette or RGBA
    # files do not produce 1- or 4-channel tensors.
    with Image.open(image_name) as image_file:
        image = image_file.convert("RGB")

    image = image_transform(image).unsqueeze(0)

    return image.to(device)

#### Loading images

In [17]:
# Content and style images go through the same preprocessing pipeline.
original_image, style_image = (
    load_image(image_path, image_transforms)
    for image_path in ("images/original.jpg", "images/style_image.jpg")
)

# The optimisation starts from a copy of the content image; the copy is the
# tensor that receives gradients.
generated_image = original_image.clone()
generated_image.requires_grad_(True)

#### Training

In [None]:
# Run the optimisation; progress is printed and snapshots are saved by `train`.
train(
    model=model,
    original_image=original_image,
    style_image=style_image,
    generated_image=generated_image,
    epochs=epochs,
    learning_rate=learning_rate,
    alpha=alpha,
    beta=beta,
)