In [None]:
from typing import Optional

import torch
from torch import optim
from PIL import Image
from torchvision import transforms

# Image I/O


In [None]:
# ImageNet per-channel statistics, ordered (R, G, B). The pretrained VGG19
# weights expect inputs normalized as (pixel - mean) / std with these values.
VGG_MEAN = [0.485, 0.456, 0.406]  # per-channel mean
VGG_STD = [0.229, 0.224, 0.225]   # per-channel standard deviation

def load_image(path, max_size=512):
    """Load an image from disk as RGB, downscaling it if it is too large.

    The aspect ratio is preserved. Images whose longer side already fits
    within ``max_size`` are returned at their original resolution.

    Args:
        path: Path of the image file to open with ``PIL.Image.open``.
        max_size: Maximum allowed length, in pixels, of the longer side.

    Returns:
        A ``PIL.Image.Image`` in RGB mode.
    """
    # The context manager closes the underlying file handle. convert()
    # returns a new, fully loaded image that no longer depends on the file.
    with Image.open(path) as src:
        img = src.convert("RGB")

    longest = max(img.size)
    if longest > max_size:
        scale = max_size / longest
        # round() avoids losing a pixel to float truncation (e.g. 511 instead
        # of 512). max(1, ...) keeps very thin images from collapsing to 0 px.
        new_size = (
            max(1, round(img.width * scale)),
            max(1, round(img.height * scale)),
        )
        img = img.resize(new_size, Image.LANCZOS)

    return img

def preprocess(img: Image.Image, device):
    """Turn a PIL image into a normalized, batched VGG input tensor.

    Args:
        img: Image to convert (an RGB image, e.g. from ``load_image``).
        device: Device the returned tensor is moved to.

    Returns:
        Float tensor of shape [1, 3, H, W] for an RGB image, normalized with
        ``VGG_MEAN`` / ``VGG_STD``.
    """
    to_tensor = transforms.ToTensor()  # PIL -> [C, H, W] float in [0, 1]
    normalize = transforms.Normalize(mean=VGG_MEAN, std=VGG_STD)

    chw = normalize(to_tensor(img))
    batch = chw[None]  # prepend the batch axis -> [1, C, H, W]
    return batch.to(device)

def deprocess(x: torch.Tensor):
    """Undo VGG normalization and convert an image tensor to a PIL image.

    Expects a [1, 3, H, W] tensor in VGG-normalized space. Values are clamped
    to [0, 1] after de-normalization. The result is detached and moved to the
    CPU before conversion, so ``x`` may live on any device and may require grad.
    """
    # Per-channel statistics shaped to broadcast against [B, 3, H, W].
    shape = (1, 3, 1, 1)
    mean = torch.tensor(VGG_MEAN, device=x.device).view(shape)
    std = torch.tensor(VGG_STD, device=x.device).view(shape)

    pixels = (x * std + mean).clamp(0, 1)
    image_chw = pixels.squeeze(0).detach().cpu()
    return transforms.ToPILImage()(image_chw)

def save_image(x: torch.Tensor, out_path: str):
    """Convert the VGG-normalized tensor ``x`` with ``deprocess`` and write it to ``out_path``."""
    image = deprocess(x)
    image.save(out_path)

# VGG19 Model

In [None]:
import torch.nn as nn
from torchvision import models

# Positions, within torchvision's vgg19().features, of the layers we tap.
# In that layout each of these indices is a Conv2d, so the stored
# activation is the convolution output before its ReLU.
VGG_LAYERS = dict(
    conv1_1=0,
    conv2_1=5,
    conv3_1=10,
    conv4_1=19,
    conv4_2=21,  # content layer used in the original Gatys et al. paper
    conv5_1=28,
)

class VGG19FeatureExtractor(nn.Module):
    """Frozen VGG19 convolutional trunk that returns selected intermediate activations.

    Differences from torchvision's stock ``vgg19().features``:
      * every ReLU is non-inplace, so an activation stored from a conv layer
        is not overwritten by the ReLU that follows it;
      * every max-pool is replaced by an average-pool with the same geometry,
        as recommended by Gatys et al. for style transfer.

    Args:
        layer_map: Mapping ``name -> index into self.features``. Defaults to
            ``VGG_LAYERS``. The mapping is copied, so later changes to the
            passed dict (or to ``VGG_LAYERS``) do not affect this instance.
    """

    def __init__(self, layer_map=None):
        super().__init__()

        # ImageNet-pretrained VGG19; `.features` drops the fully connected classifier.
        vgg = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features

        # Iterate over a snapshot because entries are replaced while looping.
        for i, m in enumerate(list(vgg)):
            if isinstance(m, nn.ReLU):
                vgg[i] = nn.ReLU(inplace=False)
            elif isinstance(m, nn.MaxPool2d):
                vgg[i] = nn.AvgPool2d(
                    kernel_size=m.kernel_size,
                    stride=m.stride,
                    padding=m.padding,
                    ceil_mode=m.ceil_mode,
                )

        self.features = vgg.eval()

        # The network weights stay fixed; only the input image is optimized.
        for p in self.features.parameters():
            p.requires_grad = False

        # Copy so the module-level default is never shared or mutated.
        self.layer_map = dict(VGG_LAYERS if layer_map is None else layer_map)

    def device(self):
        """Return the device the network's parameters live on (cpu or gpu)."""
        return next(self.features.parameters()).device

    def forward(self, x: torch.Tensor):
        """Run ``x`` through the trunk and collect the activations named in ``layer_map``.

        Args:
            x: [B, 3, H, W] image batch in normalized VGG space.

        Returns:
            dict mapping each name in ``layer_map`` to the output of the layer
            at its index. Entries are ordered by layer index, then by
            ``layer_map`` order for names that share an index. Names whose
            index matches no layer are absent.
        """
        # Invert the map once: layer index -> names tapped at that index.
        names_at = {}
        for name, idx in self.layer_map.items():
            names_at.setdefault(idx, []).append(name)

        feats = {}
        if not names_at:
            return feats
        last = max(names_at)

        t = x
        for i, layer in enumerate(self.features):
            if i > last:
                break  # every requested activation is collected; skip deeper layers
            t = layer(t)
            for name in names_at.get(i, ()):
                feats[name] = t
        return feats

# Gram Matrix

In [None]:
def gram_matrix(feat: torch.Tensor):
    """Return the (unnormalized) Gram matrix of a batch of feature maps.

    Args:
        feat: Activations of shape [B, C, H, W].

    Returns:
        Tensor of shape [B, C, C]. Entry (b, i, j) is the inner product of
        channels i and j of sample b over all H*W spatial positions.
    """
    B, C, H, W = feat.shape
    # reshape (not view) also accepts non-contiguous activations; it still
    # returns a view whenever the memory layout allows one.
    flat = feat.reshape(B, C, H * W)
    return flat @ flat.transpose(1, 2)

# Loss Function

In [None]:
import torch.nn.functional as F

def content_loss(gen_feats: dict, content_feats: dict, layer: str="conv4_2"):
    """Content loss of Gatys et al.: half the summed squared activation error.

    Args:
        gen_feats: layer name -> activation of the generated image.
        content_feats: layer name -> activation of the content image.
        layer: which layer to compare (``"conv4_2"`` by default).

    Returns:
        Scalar tensor ``0.5 * sum((F - P) ** 2)`` at ``layer``.
    """
    generated_act, target_act = gen_feats[layer], content_feats[layer]
    squared_error = F.mse_loss(generated_act, target_act, reduction="sum")
    return squared_error * 0.5

def style_loss(
    gen_feats: dict,
    style_feats: dict,
    style_layers: list[str],
    layer_weights: Optional[dict] = None,
):
    """Style loss of Gatys et al.: weighted sum of per-layer Gram-matrix mismatches.

    For each layer l with N_l channels, every Gram matrix is divided by the
    spatial size (H*W) of the activation it came from before comparing:

        E_l = sum((G_l / M_gen - A_l / M_style) ** 2) / (4 * N_l**2)

    When both images have the same spatial size M this equals the paper's
    sum((G_l - A_l) ** 2) / (4 * N_l**2 * M**2). Unlike that form, it stays
    meaningful when the style image has a different resolution from the
    generated one.

    Args:
        gen_feats: layer name -> activation of the generated image ([B, C, H, W]).
        style_feats: layer name -> activation of the style image ([B, C, H', W']).
        style_layers: names of the layers that contribute to the loss.
        layer_weights: optional layer name -> weight w_l. Omitted or None
            means every layer gets weight 1.0. Layers missing from the
            mapping also default to 1.0.

    Returns:
        ``sum_l w_l * E_l`` as a tensor, or ``0.0`` if ``style_layers`` is empty.
    """
    if layer_weights is None:
        layer_weights = {}  # .get() below falls back to 1.0 per layer

    total = 0.0
    for layer in style_layers:
        gen_act = gen_feats[layer]
        style_act = style_feats[layer]
        _, C, H, W = gen_act.shape
        N_l = C                                       # number of feature maps
        M_gen = H * W                                 # positions per generated map
        M_style = style_act.shape[2] * style_act.shape[3]

        # Per-position Gram statistics, comparable across image sizes.
        G_l = gram_matrix(gen_act) / M_gen
        A_l = gram_matrix(style_act) / M_style

        E_l = ((G_l - A_l) ** 2).sum() / (4 * N_l**2)
        total = total + layer_weights.get(layer, 1.0) * E_l
    return total

# Main Loop

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

content_path = "content_image.jpg"
style_path = "style_image.jpg"

# Style is matched at the first conv of each VGG block, all equally weighted.
style_layers = ["conv1_1", "conv2_1", "conv3_1", "conv4_1", "conv5_1"]
layer_weights = {l: 0.2 for l in style_layers}

alpha = 1.0  # content weight
beta = 1e5   # style weight

content_img = load_image(content_path)
style_img = load_image(style_path)

content = preprocess(content_img, device=device)  # [1, 3, H, W]
style = preprocess(style_img, device=device)

vgg = VGG19FeatureExtractor().to(device)
vgg.eval()

# Target activations never change, so compute them once without a graph.
with torch.no_grad():
    content_feats = vgg(content)
    style_feats = vgg(style)

# Optimize the pixels directly, starting from the content image.
generated = content.clone().requires_grad_(True)

# Each optimizer.step() runs up to `lbfgs_max_iter` L-BFGS iterations
# (20 is PyTorch's default), so the total iteration budget is
# num_steps * lbfgs_max_iter.
lbfgs_max_iter = 20
optimizer = optim.LBFGS([generated], max_iter=lbfgs_max_iter)
num_steps = 1000

def closure():
    """Recompute the weighted NST loss and its gradient; called by L-BFGS."""
    optimizer.zero_grad()

    gen_feats = vgg(generated)
    c_loss = content_loss(gen_feats, content_feats, layer="conv4_2")
    s_loss = style_loss(gen_feats, style_feats, style_layers, layer_weights)

    loss = alpha * c_loss + beta * s_loss
    loss.backward()
    return loss

print("Optimizing...")
for step in range(num_steps):
    loss = optimizer.step(closure)
    # A diverged run will not recover; don't spend the remaining steps on it.
    if not torch.isfinite(loss):
        print(f"Non-finite loss at step {step+1}; stopping early.")
        break
    if (step + 1) % 50 == 0:
        print(f"Step {step+1}/{num_steps}, loss = {loss.item():.4f}")

out_img = deprocess(generated)
out_img.save("nst_output.png")
display(out_img)