# WK14

## Style Transfer

### Classification + Classification

- Train one CNN for style
- Train another CNN for content
- Drop the classification layers and use the last remaining layer of each CNN to get a dense feature representation of an image
- Given $2$ images (`style` and `content`), get their dense feature representations (`style-f` and `content-f`) by running them through the corresponding CNN
- Iteratively change the pixels of (a copy of) the `content` image so that its own feature representations move closer to both `style-f` and `content-f`

#### Code:
- https://pytorch.org/tutorials/advanced/neural_style_tutorial.html

#### Explanation:
- https://github.com/Adi-iitd/AI-Art/

In [None]:
!wget -q https://github.com/DM-GY-9103-2024F-H/9103-utils/raw/main/src/image_utils.py
!wget -q https://github.com/DM-GY-9103-2024F-H/WK14/raw/main/WK14_utils.py

!wget -P ./data/image -q https://pytorch.org/tutorials/_static/img/neural-style/picasso.jpg
!wget -P ./data/image -q https://pytorch.org/tutorials/_static/img/neural-style/dancing.jpg

In [None]:
import torch

from torch import nn, Tensor
from torch.nn import functional as F

from torchvision.models import resnet34, ResNet34_Weights
from torchvision.models import vgg19, VGG19_Weights
from torchvision.transforms import v2

from image_utils import open_image

from WK14_utils import count_parameters

In [None]:
# load the two inputs: the picture whose content is kept
# and the picture whose style is borrowed
content_img = open_image("./data/image/dancing.jpg")
style_img = open_image("./data/image/picasso.jpg")

# preview both, content first
for img in (content_img, style_img):
  display(img)

In [None]:
# Preprocessing shared by both images: resize, convert to a tensor image,
# then convert to float32 with values scaled to [0, 1]
loader_transform = v2.Compose([
  v2.Resize(512),
  v2.ToImage(),
  # ToDtype(..., scale=True) is the replacement for the deprecated ConvertImageDtype
  v2.ToDtype(torch.float32, scale=True)
])

# add a leading batch dimension of size 1 in front of each image tensor
content_t = loader_transform(content_img).unsqueeze(0)
style_t = loader_transform(style_img).unsqueeze(0)

print(content_t.shape)
print(style_t.shape)

# squeeze(0) removes only the batch dimension added above; a bare squeeze()
# would also drop any other dimension that happens to have size 1
display(v2.ToPILImage()(content_t.squeeze(0)))
display(v2.ToPILImage()(style_t.squeeze(0)))

In [None]:
mdevice = "cuda" if torch.cuda.is_available() else "cpu"

# default per-channel mean / standard deviation used to normalize input images
# (presumably the ImageNet statistics the pretrained torchvision models expect)
image_net_mean = [0.485, 0.456, 0.406]
image_net_std = [0.229, 0.224, 0.225]

class Normalization(nn.Module):
  """Normalize an image tensor channel-wise: `(input - mean) / std`.

  Args:
    mean: per-channel means (one value per channel).
    std: per-channel standard deviations (one value per channel).

  `mean` and `std` are stored as non-persistent buffers. They are created on
  `mdevice`, follow later `module.to(...)` calls, and are not written to the
  module's `state_dict`.
  """

  def __init__(self, mean=image_net_mean, std=image_net_std):
    super().__init__()
    # reshape to (C, 1, 1) so the statistics broadcast over (..., C, H, W) inputs
    mean_t = torch.as_tensor(mean, dtype=torch.float32).reshape(-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=torch.float32).reshape(-1, 1, 1)
    # buffers (rather than plain attributes) are moved/cast together with the module
    self.register_buffer("mean", mean_t.to(mdevice), persistent=False)
    self.register_buffer("std", std_t.to(mdevice), persistent=False)

  def forward(self, input):
    """Return `input` normalized with the stored per-channel statistics."""
    return (input - self.mean) / self.std

In [None]:
class ContentLoss(nn.Module):
  """Pass-through layer that measures the distance to fixed content features.

  The layer returns its input untouched, so it can be placed in the middle of
  an `nn.Sequential`. On every forward pass the mean-squared error between
  the input and `target` is stored in `self.loss`.
  """

  def __init__(self, target):
    super().__init__()
    # the target is a constant: cut it out of the autograd graph
    # so that no gradient is ever computed for it
    fixed_target = target.detach()
    self.target = fixed_target

  def forward(self, input):
    # record the loss as a side effect; the activations flow through unchanged
    self.loss = F.mse_loss(input=input, target=self.target)
    return input

In [None]:
class StyleLoss(nn.Module):
  """Pass-through layer that measures the distance to fixed style features.

  Style is summarized by the Gram matrix of the feature maps. On every
  forward pass the mean-squared error between the Gram matrix of the input
  and the Gram matrix of the target features is stored in `self.loss`, and
  the input is returned unchanged.
  """

  @staticmethod
  def gram_matrix(input):
    """Return the normalized Gram matrix of a 4-dimensional feature tensor.

    Args:
      input: tensor of size (a, b, c, d), where a is the batch size,
        b the number of feature maps and (c, d) the size of one feature map.

    Returns:
      An (a * b, a * b) tensor of feature-map dot products divided by
      a * b * c * d.
    """
    a, b, c, d = input.size()
    # one row per feature map; reshape (unlike view) also accepts
    # non-contiguous inputs, copying only when it has to
    features = input.reshape(a * b, c * d)
    G = torch.mm(features, features.t())

    # "normalize" the values of the gram matrix by the number of elements,
    # so the result does not grow with the size of the feature maps
    return G.div(a * b * c * d)

  def __init__(self, target_feature):
    super().__init__()
    # the target Gram matrix is a constant: detach it from the autograd graph
    self.target = StyleLoss.gram_matrix(target_feature).detach()

  def forward(self, input):
    G = StyleLoss.gram_matrix(input)
    # record the loss as a side effect; the activations flow through unchanged
    self.loss = F.mse_loss(G, self.target)
    return input

In [None]:
# pretrained ResNet-34, loaded here to inspect its size and structure
resnet_weights = ResNet34_Weights.DEFAULT
model_resnet = resnet34(weights=resnet_weights)

# number of parameters first, then the layer-by-layer structure
print(count_parameters(model_resnet))
display(model_resnet)

In [None]:
# pretrained VGG-19: keep only its `features` sub-module and switch it to eval mode
vgg_weights = VGG19_Weights.DEFAULT
model_vgg = vgg19(weights=vgg_weights).features
model_vgg = model_vgg.eval()

# number of parameters first, then the layer-by-layer structure
print(count_parameters(model_vgg))
display(model_vgg)

In [None]:
# names of the layers after which a loss layer is inserted by default;
# "conv_N" is the N-th nn.Conv2d found while walking the CNN
default_content_layers = ["conv_4"]
default_style_layers = ["conv_1", "conv_2", "conv_3", "conv_4", "conv_5"]

def create_model(cnn, content, style, content_layers=default_content_layers, style_layers=default_style_layers):
  """Build a style-transfer model by interleaving loss layers with the layers of `cnn`.

  The result starts with a `Normalization` layer, followed by the children of
  `cnn` in order. Right after every layer named in `content_layers` /
  `style_layers`, a `ContentLoss` / `StyleLoss` layer is inserted whose target
  is the activation of `content` / `style` at that point. Everything after the
  last loss layer is dropped.

  Args:
    cnn: module whose direct children are Conv2d, ReLU, MaxPool2d or
      BatchNorm2d layers.
    content: content image tensor, fed through the partially built model.
    style: style image tensor, fed through the partially built model.
    content_layers: layer names that get a `ContentLoss` after them.
    style_layers: layer names that get a `StyleLoss` after them.

  Returns:
    Tuple `(model, content_losses, style_losses)`: the trimmed `nn.Sequential`
    and the lists of loss layers it contains.

  Raises:
    RuntimeError: if `cnn` has a child of any other layer type.
  """
  content_losses = []
  style_losses = []

  model = nn.Sequential(Normalization())

  i = 0
  # walk the cnn and re-use its layers (they are shared with `cnn`, not copied)
  for layer in cnn.children():
    if isinstance(layer, nn.Conv2d):
      i += 1
      name = f"conv_{i}"
    elif isinstance(layer, nn.ReLU):
      name = f"relu_{i}"
      # use a fresh out-of-place ReLU so the activations read by a loss layer
      # placed right before it are not overwritten
      layer = nn.ReLU(inplace=False)
    elif isinstance(layer, nn.MaxPool2d):
      name = f"pool_{i}"
    elif isinstance(layer, nn.BatchNorm2d):
      name = f"bn_{i}"
    else:
      raise RuntimeError(f"Unrecognized layer: {layer.__class__.__name__}")

    model.add_module(name, layer)

    # Add Loss Layers
    # targets are constants: compute them without recording an autograd graph
    if name in content_layers:
      with torch.no_grad():
        target = model(content)
      content_loss = ContentLoss(target)
      model.add_module(f"content_loss_{i}", content_loss)
      content_losses.append(content_loss)

    if name in style_layers:
      with torch.no_grad():
        target_feature = model(style)
      style_loss = StyleLoss(target_feature)
      model.add_module(f"style_loss_{i}", style_loss)
      style_losses.append(style_loss)

  # position of the last content/style loss layer
  # (stays 0 when there is none, which keeps only the first layer)
  last_loss_idx = 0
  for idx, module in enumerate(model):
    if isinstance(module, (ContentLoss, StyleLoss)):
      last_loss_idx = idx
  # trim off model to end at last content or style loss
  model = model[:(last_loss_idx + 1)]

  return model, content_losses, style_losses

In [None]:
# use the GPU when one is available; model and images must share a device
mdevice = "cuda" if torch.cuda.is_available() else "cpu"

model_vgg = model_vgg.to(mdevice)
content_t, style_t = content_t.to(mdevice), style_t.to(mdevice)

model, content_losses, style_losses = create_model(model_vgg, content_t, style_t)
# eval mode + frozen parameters: only the image pixels are optimized
model = model.eval().requires_grad_(False).to(mdevice)

# start from a copy of the content image and make it the tensor being optimized
input_img = content_t.clone().contiguous().to(mdevice).requires_grad_(True)
optim = torch.optim.LBFGS([input_img], lr=0.1)

# sanity check: run the model once and show the shape of its output
model(input_img).shape

In [None]:
# relative importance of the two objectives in the total loss
style_weight = 100000
content_weight = 1

def closure():
  """Evaluate the weighted style + content loss of `input_img` and backpropagate it.

  Passed to `optim.step`, which calls it to re-evaluate the loss.
  """
  # keep the pixel values inside [0, 1] before evaluating the loss
  with torch.no_grad():
    input_img.clamp_(0, 1)

  optim.zero_grad()
  # forward pass only for its side effect: each loss layer refreshes its `.loss`
  model(input_img)

  content_score = content_weight * sum(cl.loss for cl in content_losses)
  style_score = style_weight * sum(sl.loss for sl in style_losses)

  loss = style_score + content_score
  loss.backward()

  return loss

for e in range(100):
  score = optim.step(closure)
  # report progress on every 10th step
  if (e + 1) % 10 == 0:
    print(f"Epoch: {e} Score: {score.item():.4f}")

In [None]:
# no gradients are needed to post-process the optimized pixels
with torch.no_grad():
  # drop size-1 dimensions (the batch), move to the CPU and force values into [0, 1]
  output_img = input_img.squeeze().to("cpu").clamp(0,1)
  to_pil = v2.ToPILImage()
  display(to_pil(output_img))

## Possible Next Steps

- Add a neighboring-pixel difference penalty (total variation loss) to the loss to smooth out noise in the result
- Experiment with different combinations of layers in the loss function