<a href="https://colab.research.google.com/github/VKarpick/augmented_style_transfer/blob/main/AugmentTransfer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Overview

This workbook is for pairing image segmentation with neural style transfer.  It allows for different styles to be applied to different elements within images.

The style transfer work comes predominantly from [here](https://nextjournal.com/gkoehler/pytorch-neural-style-transfer), while the segmentation work is courtesy of [this link](https://www.learnopencv.com/pytorch-for-beginners-semantic-segmentation-using-torchvision/).

If you intend to use it, make sure to use the GPU (Runtime -> Change runtime type).

In [None]:
%matplotlib inline

import PIL
import matplotlib.pyplot as plt
import numpy as np
from torchvision import models, transforms
import torch
from google.colab import files
from io import BytesIO

In [None]:
# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = torch.device("cuda")
else:
  device = torch.device("cpu")

# Pre-trained Models

Deeplabv3_ResNet101 is used for the image segmentation and VGG19 for style transfer.  The categories Deeplab allows for and the colors chosen to represent them are as follows:

*   background = not quite black
*   aeroplane = green
*   bicycle = lime
*   bird = cyan
*   boat = navy
*   bottle = blue
*   bus = yellow
*   car = red
*   cat = orange
*   chair = beige
*   cow = grey
*   diningtable = maroon
*   dog = teal
*   horse = brown
*   motorbike = pink
*   person = purple
*   pottedplant = olive
*   sheep = white
*   sofa = apricot
*   train = mint
*   tvmoniter = lavender

In [None]:
# Pre-trained DeepLabV3 (ResNet-101) used for segmentation, placed on the
# selected device and switched to evaluation mode.
dlab = models.segmentation.deeplabv3_resnet101(pretrained=True)
dlab = dlab.to(device)
dlab = dlab.eval()

In [None]:
# only the "features" stack of the pre-trained VGG19 is kept
vgg = models.vgg19(pretrained=True).features

# swap every max pooling layer for an average pooling layer
pool_positions = [int(position) for position, module in vgg.named_children()
                  if isinstance(module, torch.nn.MaxPool2d)]
for position in pool_positions:
  vgg[position] = torch.nn.AvgPool2d(kernel_size=2, stride=2)

# the network's weights are never trained here, so turn off their gradients
for weight in vgg.parameters():
  weight.requires_grad_(False)

# move to the selected device and switch to evaluation mode
vgg.to(device).eval()

# Image Processing

These are all the functions required to manipulate the images.

In [None]:
def open_image(image_file):
  """Opens the first uploaded image as an RGB PIL image.

  Args:
    image_file: Mapping whose values are the raw bytes of uploaded files
      (the notebook passes the result of files.upload()).  Only the first
      value is used.

  Returns:
    The decoded image, converted to RGB mode.

  Raises:
    ValueError: If the mapping is empty, e.g. because the upload was cancelled.
  """

  if not image_file:
    raise ValueError("No image was uploaded; re-run the upload cell and select a file.")

  # take the first value without building a throwaway list of every upload
  first_upload = next(iter(image_file.values()))

  return PIL.Image.open(BytesIO(first_upload)).convert("RGB")

In [None]:
def transform_image(image, shape=(224, 224)):
  """Resizes and normalizes a PIL image into a tensor with a leading batch dimension.

  A shape of (0, 0) means "keep the image's own height and width".
  """

  target_shape = (image.height, image.width) if shape == (0, 0) else shape

  pipeline = transforms.Compose([
    transforms.Resize(target_shape),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
  ])

  first_three_channels = pipeline(image)[:3, :, :]    # drop any channel past the third

  return first_three_channels.unsqueeze(0)

In [None]:
def tensor_to_np(tensor):
  """Turns a normalized, batched image tensor into an array matplotlib can show.

  The batch dimension is removed, channels are moved last, the normalization
  applied by transform_image is undone, and values are clipped to 0-1.
  """

  channel_std = np.array((0.229, 0.224, 0.225))
  channel_mean = np.array((0.485, 0.456, 0.406))

  channels_first = tensor.to("cpu").clone().detach().numpy().squeeze(axis=0)
  channels_last = channels_first.transpose(1, 2, 0)
  denormalized = channels_last * channel_std + channel_mean

  # color data from 0-1 rather than 0-255
  return denormalized.clip(0, 1)

In [None]:
def tensor_to_pil(tensor):
  """Builds a PIL image from a normalized, batched image tensor."""

  # tensor_to_np yields floats in 0-1, but PIL needs 8-bit values in 0-255
  unit_pixels = tensor_to_np(tensor)
  byte_pixels = (unit_pixels * 255).astype(np.uint8)

  return PIL.Image.fromarray(byte_pixels)

In [None]:
def label_to_color(image, label_nos=None):
  """Converts a map of label numbers into an RGB color map.

  Args:
    image: Tensor holding one label number per pixel.
    label_nos: Iterable of the label numbers to color.  Pixels carrying any
      other label are left at 0 (black).  Defaults to all 21 labels.

  Returns:
    Tensor with a new leading dimension of size 3 (red, green, blue) holding
    0-255 color values, using the same dtype and device as image.
  """

  label_colors = (
      (1, 1, 1),  # 0) background = not quite black
      (60, 180, 75),  # 1) aeroplane = green
      (210, 245, 60),  # 2) bicycle = lime
      (70, 240, 240),  # 3) bird = cyan
      (0, 0, 128),  # 4) boat = navy
      (0, 130, 200),  # 5) bottle = blue
      (255, 225, 25),  # 6) bus = yellow
      (230, 25, 75),  # 7) car = red
      (245, 130, 48),  # 8) cat = orange
      (255, 250, 200),  # 9) chair = beige
      (128, 128, 128),  # 10) cow = grey
      (128, 0, 0),  # 11) diningtable = maroon
      (0, 128, 128),  # 12) dog = teal
      (170, 110, 40),  # 13) horse = brown
      (250, 190, 212),  # 14) motorbike = pink
      (145, 30, 180),  # 15) person = purple
      (128, 128, 0),  # 16) pottedplant = olive
      (255, 255, 255),  # 17) sheep = white
      (255, 215, 180),  # 18) sofa = apricot
      (170, 255, 195),  # 19) train = mint
      (220, 190, 255),  # 20) tvmoniter = lavender
  )

  # None instead of a list default, so no list object is shared between calls
  if label_nos is None:
    label_nos = range(len(label_colors))

  # one tensor for all three channels, starting out black
  rgb = torch.zeros((3,) + tuple(image.shape), dtype=image.dtype, device=image.device)

  for label_no in label_nos:
    idx = image == label_no
    for channel, value in enumerate(label_colors[label_no]):
      rgb[channel][idx] = value

  return rgb

In [None]:
def segment(image, label_nos=None):
  """Converts an image into the segmented representation of the image.

  Args:
    image: Normalized image tensor with a leading batch dimension of 1, as
      produced by transform_image.
    label_nos: Label numbers to color in the result; pixels with any other
      label stay black.  Defaults to all 21 labels.

  Returns:
    The color-coded segmentation, normalized the same way as the other images
    and with the batch dimension restored.
  """

  # None instead of a list default, so no list object is shared between calls
  if label_nos is None:
    label_nos = list(range(21))

  # inference only: don't build an autograd graph through Deeplab
  with torch.no_grad():
    class_scores = dlab(image)["out"]

  # squeeze only the batch dimension so a spatial size of 1 is never dropped
  segmented_image = torch.argmax(class_scores.squeeze(0), dim=0)
  segmented_image = label_to_color(segmented_image, label_nos)
  segmented_image = torch.div(segmented_image, 255.0)    # 0-255 colors to 0-1
  segmented_image = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])(segmented_image)

  return segmented_image.unsqueeze(0)

In [None]:
def convert_to_original_colors(image):
  """Recolors a PIL image with the chroma of the content image.

  Only the luma (Y) channel of the given image is kept; the Cb and Cr channels
  come from the global content image.  The result is handed back as a
  normalized tensor on the active device.
  """

  _, content_cb, content_cr = tensor_to_pil(content_image).convert('YCbCr').split()
  image_luma = image.convert('YCbCr').split()[0]
  recolored = PIL.Image.merge('YCbCr', [image_luma, content_cb, content_cr]).convert('RGB')

  return transform_image(recolored, content_image.shape[2:]).to(device)

In [None]:
def update_final_image(object_image, background_image, segment_image=None):
  """Combines two copies of an image, one with the updated objects and one with the updated background.

  Args:
    object_image: Tensor with a leading batch dimension of 1 that supplies the
      pixels wherever the segmentation shows an object.
    background_image: Tensor of the same shape that supplies every other pixel.
    segment_image: Normalized segmentation tensor with a leading batch
      dimension of 1.  Defaults to the global content_segment.

  Returns:
    The combined image, with the batch dimension restored.
  """

  if segment_image is None:
    segment_image = content_segment

  segmentation = segment_image.squeeze(0)

  # de-normalize the segmented image so that black is (0, 0, 0) again
  mean = torch.tensor([0.485, 0.456, 0.406], dtype=segmentation.dtype,
                      device=segmentation.device).view(-1, 1, 1)
  std = torch.tensor([0.229, 0.224, 0.225], dtype=segmentation.dtype,
                     device=segmentation.device).view(-1, 1, 1)
  segmentation = segmentation * std + mean

  # black = background to be updated, anything else = object to be updated.
  # The decision is made per pixel (any channel above black), not per channel:
  # colors such as navy (0, 0, 128) have zero-valued channels that would
  # otherwise be taken from the background image.  Half of one 8-bit step is
  # used as the threshold so float round-off from de-normalizing can't turn
  # black into "not black".
  where_segment = (segmentation > 0.5 / 255).any(dim=0, keepdim=True)

  # combine the images
  image = torch.where(where_segment, object_image.squeeze(0), background_image.squeeze(0))

  return image.unsqueeze(0)

# Image Loading

For all of these, if multiple images are selected at once, only one will actually be used.  To change images, these cells will need to be re-run.

All but the segmented image are required, even if one of the style images isn't going to be used.

The first upload is the content image.  This is the image that the segmentation and style transfer will be applied to.

In [None]:
# Upload prompt for the content image; the result is read later by open_image.
content_image_file = files.upload()

The next upload is for when Deeplab doesn't segment the image as desired.  It allows a custom image to be uploaded for use as the segmented image.  To use it, tick the checkbox and upload the image.

For creating a segmented image, any portions of the content image that aren't to be considered objects should be blacked out.  Everything else can be left as is.

In [None]:
use_custom_segment = True #@param {type: "boolean"}

# Only prompt for a segmentation image when the box above is ticked.
custom_segment_file = files.upload() if use_custom_segment else None

The second upload is the object style image.  This is the image containing the style that will be applied to any objects found in the segmentation.  For example, when taking an image of a person in which one style will be applied to the person and another to the background, this is the image whose style will be applied to the person.

In [None]:
# Upload prompt for the object style image; the result is read later by open_image.
object_style_image_file = files.upload()

The final upload is the background style image.  This is the image containing the style that will be applied to the background.

In [None]:
# Upload prompt for the background style image; the result is read later by open_image.
background_style_image_file = files.upload()

# Image Parameters

All images will be reshaped to the same size.  The default has been set at 224.  It's generally recommended to use larger dimensions as they tend to work better for style transfer, but keep in mind that sizes that are too large will overtax the GPU.

In [None]:
# Size that every loaded image is resized to (combined into `shape` when the images are loaded).
height =  277#@param {type: "number"}
width =  474#@param {type: "number"}

These checkboxes decide which objects will be detected by the segmentation.  Any category selected will have the same style applied to it.

If applying style transfer to the entire image, all can be checked and the weights (in a later section) will have to be set accordingly.

If the desired effect is for different styles to be applied to different objects (i.e. one style for a plane, a different style for a bird), this can't be done simultaneously, but it can be accomplished with multiple passes and correctly tuned weights, provided the objects are segmented one at a time.

In [None]:
# One checkbox per Deeplab category, in the same order as the color table in label_to_color.
background = True #@param {type: "boolean"}
aeroplane = True #@param {type: "boolean"}
bicycle = True #@param {type: "boolean"}
bird = True #@param {type: "boolean"}
boat = True #@param {type: "boolean"}
bottle = True #@param {type: "boolean"}
bus = True #@param {type: "boolean"}
car = True #@param {type: "boolean"}
cat = True #@param {type: "boolean"}
chair = True #@param {type: "boolean"}
cow = True #@param {type: "boolean"}
diningtable = True #@param {type: "boolean"}
dog = True #@param {type: "boolean"}
horse = True #@param {type: "boolean"}
motorbike = True #@param {type: "boolean"}
person = True #@param {type: "boolean"}
pottedplant = True #@param {type: "boolean"}
sheep = True #@param {type: "boolean"}
sofa = True #@param {type: "boolean"}
train = True #@param {type: "boolean"}
tvmoniter = True #@param {type: "boolean"}

This cell loads and displays the images.

In [None]:
# Checkbox flags in Deeplab label order, so each flag's position is its label number.
# `background` (label 0) is included: ticking it colors the background like an
# object, which is what styling the entire image with a single style relies on.
labels = (background, aeroplane, bicycle, bird, boat, bottle, bus, car, cat,
          chair, cow, diningtable, dog, horse, motorbike, person, pottedplant,
          sheep, sofa, train, tvmoniter)

shape = (height, width)
content_objects = [label_no for label_no, is_selected in enumerate(labels) if is_selected]

content_image = transform_image(open_image(content_image_file), shape).to(device)

# use the uploaded segmentation when there is one, otherwise let Deeplab produce it
if custom_segment_file is None:
  content_segment = segment(content_image, content_objects)
else:
  content_segment = transform_image(open_image(custom_segment_file), shape).to(device)

object_style_image = transform_image(open_image(object_style_image_file), shape).to(device)
background_style_image = transform_image(open_image(background_style_image_file), shape).to(device)

# show the four images side by side
fig, axarr = plt.subplots(1, 4, figsize=(15, 15))
previews = (
  ("Content Image", content_image),
  ("Segmented Image", content_segment),
  ("Object Style Image", object_style_image),
  ("Background Style Image", background_style_image),
)
for axis, (title, image) in zip(axarr, previews):
  axis.title.set_text(title)
  axis.imshow(tensor_to_pil(image))
  axis.axis("off")

# Features and Grams

Style transfer requires capturing the "features" of an image.  An example of a feature could be any horizontal line in the image.  The features present in an image are found using the pre-trained VGG19 model.

In [None]:
# Position of each convolution inside vgg -> the conventional name of that layer.
layers = {
  "0": "conv1_1",
  "2": "conv1_2",
  "5": "conv2_1",
  "7": "conv2_2",
  "10": "conv3_1",
  "12": "conv3_2",
  "14": "conv3_3",
  "16": "conv3_4",
  "19": "conv4_1",
  "21": "conv4_2",
  "23": "conv4_3",
  "25": "conv4_4",
  "28": "conv5_1",
  "30": "conv5_2",
  "32": "conv5_3",
  "34": "conv5_4",
  }

def get_features(image, feature_layers, model=None):
  """Collects the activations of the requested layers for an image.

  Args:
    image: Input tensor fed to the first layer of the model.
    feature_layers: Layer name(s) to capture (values of `layers`).  May be a
      single name or any collection of names; for a dict its keys are used.
    model: Sequential-style module whose children are applied in order.
      Defaults to the global vgg.

  Returns:
    Dict mapping each captured layer name to that layer's output, in the
    order the layers are applied.
  """
  if model is None:
    model = vgg

  # a lone name is treated as exactly one layer, never as a substring pattern
  if isinstance(feature_layers, str):
    wanted = {feature_layers}
  else:
    wanted = set(feature_layers)
  wanted &= set(layers.values())

  features = {}

  x = image
  for name, layer in model.named_children():
    # nothing past the deepest requested layer is needed
    if len(features) == len(wanted):
      break

    x = layer(x)
    layer_name = layers.get(name)
    if layer_name in wanted:
      features[layer_name] = x

  return features

Style representations are obtained by measuring the correlation between different feature map responses of a given layer. The dot product of two vectors can be seen as how similar two vectors are to each other - the more similar they are, the lesser the angle between them. In style transfer, this can be done by flattening the feature map's spatial dimensions at each depth and computing its dot product. The result is the Gram Matrix:

$$G_{ij}^l = \sum_{k}^{} F_{ik}^l F_{jk}^l$$

Where $F_{ij}^l$ is the activation of the ith filter at position j in layer l.

In [None]:
def gram_matrix(tensor):
  """Returns the Gram matrix of one image's feature maps."""
  channels, rows, cols = tensor.size()[1:]    # the leading batch size isn't needed
  flattened = tensor.view(channels, rows * cols)    # one row per feature map

  # product of the flattened maps with their own transpose
  return flattened @ flattened.t()

# Target Image and Optimizer

The target image is the result of the style transfer.  It can be composed of random white noise, but better results seem to be obtained by setting it to the content image.  In this case, there are two target images to allow for the two different types of styles.

In [None]:
# Each target starts as an independent copy of the content image.  The copy is
# moved to the device BEFORE gradients are switched on, so the tensor handed to
# the optimizer is always a leaf tensor (a device copy made after
# requires_grad_ would be a non-leaf, which optimizers reject).
object_target = content_image.detach().clone().to(device).requires_grad_(True)
background_target = content_image.detach().clone().to(device).requires_grad_(True)

Adam and L-BFGS are the two common types of optimizers used in style transfer.  The Adam optimizer is faster and allows for smaller updates to the target image so that is what's provided here.  To use L-BFGS, all that should be required is changing torch.optim.Adam to torch.optim.LBFGS in the two lines below.

In [None]:
# Learning rates for the two optimizers created below.
object_learning_rate = 1e-2 #@param {type: "number"}
background_learning_rate = 1e-2 #@param {type: "number"}

# One optimizer per target image; each updates only its own target tensor.
object_optimizer = torch.optim.Adam([object_target], lr=object_learning_rate)
background_optimizer = torch.optim.Adam([background_target], lr=background_learning_rate)

# Calculating Loss

## Content Loss

Ensures the activations of higher layers are similar between the content image and the generated image.$${\cal L}_{content}(\vec{p},\vec{x},l) = \frac{1}{2} \sum_{i,j}(F_{ij}^l - P_{ij}^l)^2$$Where $\vec{p}$ and $\vec{x}$ are the original image and the image that is generated and $P^l$ and $F^l$ their respective feature representation in layer $l$.

Convolutional feature maps are generally a good representation of an input image's features. They capture spatial information without containing the style information. Therefore, mean squared difference between the target and content features is used.

In [None]:
def compute_content_loss(image):
  """Mean squared difference between the image's and the content image's features at content_layer."""
  image_activation = get_features(image, content_layer)[content_layer]
  content_activation = content_features[content_layer]

  squared_difference = (image_activation - content_activation)**2
  return torch.mean(squared_difference)

## Style Loss

Ensures the correlation of activations in all layers are similar between the style image and the generated image.

**Contribution of layer $l$ to total loss**
$$E_l = \frac{1}{4N_l^2M_l^2} \sum_{i,j}(G_{i,j}^l - A_{i,j}^l)^2$$
Where $\vec{a}$ and $\vec{x}$ are the original image and the image that is generated and $A^l$ and $G^l$ their respective style representation in layer $l$. $N_l$ is the number of feature maps and $M_l$ is the height * width of the feature map.

**Total style loss**
$${\cal L}_{style}(\vec{a},\vec{x}) = \sum_{l=0}^L w_l E_l$$
Where $w_l$ are the weighting factors of the contribution of each layer of the total loss.

Style loss is the mean squared difference between the gram matrix of the input and the gram matrix of the style image.

In [None]:
def compute_style_loss(image, style_features, style_weights, style_grams):
  """Weighted sum, over the style layers, of the Gram matrix error of an image.

  For each layer named in style_weights, the mean squared difference between
  the image's Gram matrix and the matching entry of style_grams is scaled by
  the layer's weight and divided by the size of the layer's feature maps.
  style_features is accepted but not used.
  """
  image_features = get_features(image, style_weights)

  total = 0
  for layer_name, layer_weight in style_weights.items():
    activation = image_features[layer_name]
    _, depth, height, width = activation.shape

    gram_error = torch.mean((gram_matrix(activation) - style_grams[layer_name])**2)
    total += (layer_weight * gram_error) / (depth * height * width)

  return total

## Total Variation Loss

Not mentioned in the original paper is a factor for total variation loss. This measures how much noise is in the images and can be used to smooth the image. There are different calculations for it; in this case, the sum of the means of the absolute differences between adjacent pixels is used.

In [None]:
def compute_total_variation_loss(image):
  """Sum of the mean absolute differences between horizontally and vertically adjacent pixels."""
  # neighbours along the last dimension (left/right)
  horizontal_steps = image[:,:,:,:-1] - image[:,:,:,1:]
  # neighbours along the third dimension (up/down)
  vertical_steps = image[:,:,:-1,:] - image[:,:,1:,:]

  return torch.mean(torch.abs(horizontal_steps)) + torch.mean(torch.abs(vertical_steps))

# Style Transfer Loop

The style transfer loop works by applying style transfer to each of the two target images (one for the objects, one for the background) and keeping a separate image that is a combination of the two.

In [None]:
# Most recent weighted value of every loss term, kept for display only.
loss_tracker = {
    "content (object)": 0.0,
    "style (object)": 0.0,
    "total variation (object)": 0.0,
    "content (background)": 0.0,
    "style (background)": 0.0,
    "total variation (background)": 0.0,
}

def _compose_transfer_image():
  """Blends the current object and background targets into one PIL image."""
  transfer_object = convert_to_original_colors(tensor_to_pil(object_target)) if is_converting_object_colors else object_target
  transfer_background = convert_to_original_colors(tensor_to_pil(background_target)) if is_converting_background_colors else background_target

  return tensor_to_pil(update_final_image(transfer_object, transfer_background))

def style_transfer_loop(total_iterations, display_iterations=0):
  """Optimizes both target images and returns the combined result.

  Args:
    total_iterations: Number of optimizer steps to take for each target.
    display_iterations: Print the losses and show the current combined image
      every this many iterations.  0 shows them only after the last iteration.

  Returns:
    The combined image (PIL) built from the two targets after the last
    iteration, or from the untouched targets when no iteration is run.
  """
  for key in loss_tracker:
    loss_tracker[key] = 0.0

  if display_iterations == 0:
    display_iterations = total_iterations

  def closure(is_object):
    """Computes the total loss for one target and back-propagates it."""
    object_optimizer.zero_grad()
    background_optimizer.zero_grad()

    # pick the settings of the target being optimized
    if is_object:
      label, target = "object", object_target
      content_weight = object_content_weight
      style_weight = object_style_weight
      total_variation_weight = object_total_variation_weight
      style_args = (object_features, object_weights, object_grams)
    else:
      label, target = "background", background_target
      content_weight = background_content_weight
      style_weight = background_style_weight
      total_variation_weight = background_total_variation_weight
      style_args = (background_features, background_weights, background_grams)

    total_loss = 0

    # terms with a weight of 0 are skipped entirely; tracked values are
    # detached so the tracker doesn't hold on to the autograd graph
    if content_weight != 0:
      content_loss = compute_content_loss(target) * content_weight
      loss_tracker["content ({})".format(label)] = content_loss.detach()
      total_loss = total_loss + content_loss

    if style_weight != 0:
      style_loss = compute_style_loss(target, *style_args) * style_weight
      loss_tracker["style ({})".format(label)] = style_loss.detach()
      total_loss = total_loss + style_loss

    if total_variation_weight != 0:
      total_variation_loss = compute_total_variation_loss(target) * total_variation_weight
      loss_tracker["total variation ({})".format(label)] = total_variation_loss.detach()
      total_loss = total_loss + total_variation_loss

    total_loss.backward()

    return total_loss

  def object_closure():
    return closure(True)

  def background_closure():
    return closure(False)

  transfer_image = None

  for iteration in range(1, total_iterations + 1):
    # a target whose weights are all 0 is left untouched
    if object_content_weight or object_style_weight or object_total_variation_weight:
      object_optimizer.step(object_closure)
    if background_content_weight or background_style_weight or background_total_variation_weight:
      background_optimizer.step(background_closure)

    if iteration % display_iterations == 0 or iteration == total_iterations:
      # the combined image is only built when it is shown; the last
      # iteration always gets here, so the returned image is up to date
      transfer_image = _compose_transfer_image()

      print("Iteration: ", iteration)
      print("Content (object): {}, Style (object): {}, TV (object): {},\nContent (background): {}, Style (background): {}, TV (background): {}"
        .format(float(loss_tracker["content (object)"]),
                float(loss_tracker["style (object)"]),
                float(loss_tracker["total variation (object)"]),
                float(loss_tracker["content (background)"]),
                float(loss_tracker["style (background)"]),
                float(loss_tracker["total variation (background)"])
                ))
      plt.imshow(transfer_image)
      plt.axis("off")
      plt.show()

  # no iterations were run: still hand back an image instead of failing
  if transfer_image is None:
    transfer_image = _compose_transfer_image()

  return transfer_image

# Weight Parameters

In the logical flow of the notebook, these should be higher up.  They were left until the end to simplify updating them before running the style transfer loop when making multiple passes.

When applying style transfer to an entire image or to only the objects of an image, the background weights should all be set to 0.

In [None]:
# Multipliers applied to the content, style and total variation losses in
# style_transfer_loop.  A target whose three weights are all 0 is not optimized.
#@markdown Object Weights
object_content_weight = 1e3 #@param {type: "number"}
object_style_weight = 1e2 #@param {type: "number"}
object_total_variation_weight = 1e-1 #@param {type: "number"}

#@markdown Background Weights
background_content_weight =  0#@param {type: "number"}
background_style_weight =  0#@param {type: "number"}
background_total_variation_weight =  0#@param {type: "number"}

Which layer of VGG will be used to represent the content image's features.  The names of all possible layers that can be included can be found in the Features and Grams section.

In [None]:
# Name of the layer (a value of `layers`) whose activations are compared in compute_content_loss.
content_layer = "conv4_2"  #@param {type: "string"}

These are meant to be tuneable parameters.  Using Colab's parameter settings makes a mess of things, so they're left as a code cell, but they are meant to be adjusted.  They determine how much each layer contributes to the style of the image.

Weighting earlier layers more (eg conv1_1) results in larger style artifacts in the generated image. Weighting later layers more emphasizes smaller features.

In [None]:
# Per-layer style weights for the object target: layer name -> weight.
# The keys also decide which layers' features and Gram matrices are computed.
object_weights = {
    "conv1_1": 1e4,
    "conv2_1": 1e3,
    "conv3_1": 1e2,
    "conv4_1": 1e1,
    "conv5_1": 1e1,
}

# Per-layer style weights for the background target, tuned independently.
background_weights = {
    "conv1_1": 1e4,
    "conv2_1": 1e3,
    "conv3_1": 1e2,
    "conv4_1": 1e1,
    "conv5_1": 1e1,
}

It may be preferable to keep the colors consistent with the original content image.  These can be toggled to do so for the objects and/or the background of the image.

In [None]:
# When set, the matching target is passed through convert_to_original_colors
# before the two targets are combined in style_transfer_loop.
is_converting_object_colors = False #@param {type: "boolean"}
is_converting_background_colors = False #@param {type: "boolean"}

# Running

Finally, this is where the actual transfer happens.  The two parameters are how many times to run through the style transfer loop and on which of those iterations the current image should be displayed.  The first cell only needs to be run once unless something above it has been changed.

In [None]:
# Activations of the content image at the chosen content layer.
content_features = get_features(content_image, content_layer)

# Activations and Gram matrices of the object style image at its weighted layers.
object_features = get_features(object_style_image, object_weights)
object_grams = {name: gram_matrix(activation) for name, activation in object_features.items()}

# The same for the background style image.
background_features = get_features(background_style_image, background_weights)
background_grams = {name: gram_matrix(activation) for name, activation in background_features.items()}

In [None]:
# Arguments passed to style_transfer_loop: how many iterations to run and how
# often (in iterations) to print the losses and show the current image.
total_iterations = 500 #@param {type: "number"}
display_iterations = 20 #@param {type: "number"}

In [None]:
# Run the transfer; the returned combined image is kept for saving.
final_image = style_transfer_loop(total_iterations, display_iterations)

# Saving

In [None]:
# Write the final image to a file, then hand that file to files.download.
save_file_name = "generated_image.png"
final_image.save(save_file_name)
files.download(save_file_name)