### Based on: https://medium.com/pytorch/pystiche-a-framework-for-neural-style-transfer-1ea6e4825f32

In [3]:
!pip install pystiche

Collecting pystiche
  Downloading pystiche-1.0.1-py3-none-any.whl (67 kB)
[K     |████████████████████████████████| 67 kB 4.4 MB/s 
Installing collected packages: pystiche
Successfully installed pystiche-1.0.1


In [4]:
import torch
import pystiche
from pystiche import demo, enc, loss, ops, optim
from torchvision.utils import save_image

In [None]:
import ssl

# SECURITY: swapping ssl's default HTTPS context factory for the unverified one
# switches off TLS certificate checking for every client in this kernel that
# relies on that default, so anything downloaded afterwards could be tampered
# with in transit. The workaround is therefore opt-in: leave the flag False
# unless downloads fail with certificate errors, and prefer fixing the local
# CA certificates over enabling it.
ALLOW_UNVERIFIED_HTTPS = False

if ALLOW_UNVERIFIED_HTTPS:
    ssl._create_default_https_context = ssl._create_unverified_context

In [5]:
# Pick the compute device once: CUDA when PyTorch reports it as available,
# otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [6]:
# Build the VGG19 multi-layer encoder. It is reused below for both the content
# loss and the style loss. The recorded output shows the VGG19 weights being
# downloaded when this cell ran.
multi_layer_encoder = enc.vgg19_multi_layer_encoder()
# Bare trailing expression so the notebook displays the encoder's repr.
multi_layer_encoder

Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth


  0%|          | 0.00/548M [00:00<?, ?B/s]

VGGMultiLayerEncoder(
  arch=vgg19, framework=torch
  (preprocessing): TorchPreprocessing(
    (0): Normalize(mean=('0.485', '0.456', '0.406'), std=('0.229', '0.224', '0.225'))
  )
  (conv1_1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu1_1): ReLU()
  (conv1_2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu1_2): ReLU()
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2_1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu2_1): ReLU()
  (conv2_2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu2_2): ReLU()
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv3_1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu3_1): ReLU()
  (conv3_2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu3_2): ReLU()
  (conv3_3): Conv2d(256, 256, kernel_size=(3, 3

In [7]:
# Name of the encoder layer whose features define the content representation.
content_layer = "relu4_2"
# Encoder for that one layer, obtained from the multi-layer encoder.
# NOTE(review): presumably shares weights with `multi_layer_encoder` — confirm in pystiche docs.
encoder = multi_layer_encoder.extract_encoder(content_layer)
# Content term of the criterion. No score_weight is passed here, so whatever
# default the library defines is used.
content_loss = pystiche.loss.FeatureReconstructionLoss(encoder)

In [8]:
# Encoder layers handed to MultiLayerEncodingLoss for the style term.
style_layers = ("relu1_1", "relu2_1", "relu3_1", "relu4_1", "relu5_1")
# Passed as `score_weight` to the multi-layer style loss.
style_weight = 1e3

In [11]:
def get_encoding_op(encoder, layer_weight):
    """Create the style-loss term for a single encoder layer.

    This is the per-layer factory that the notebook hands to
    ``MultiLayerEncodingLoss``.

    Args:
        encoder: Encoder for one layer; forwarded unchanged to ``GramLoss``.
        layer_weight: Weight for this layer, passed on as ``score_weight``.

    Returns:
        The ``pystiche.loss.GramLoss`` instance built from the two arguments.
    """
    gram_loss = pystiche.loss.GramLoss(encoder, score_weight=layer_weight)
    return gram_loss

In [12]:
# Style term of the criterion: `get_encoding_op` is supplied as the factory for
# the per-layer losses over `style_layers`, and the whole term is weighted by
# `style_weight`.
style_loss = pystiche.loss.MultiLayerEncodingLoss(
    multi_layer_encoder, style_layers, get_encoding_op, score_weight=style_weight,
)

In [13]:
# Combine the content and style terms into the optimisation criterion, then
# move it onto the selected device.
criterion = pystiche.loss.PerceptualLoss(content_loss, style_loss)
criterion = criterion.to(device)

In [14]:
# Value passed as `size=` when the content and style images are read below.
size = 500
# Demo image collection from pystiche; entries are looked up by name below.
images = demo.images()

In [15]:
# Read the "bird1" demo image at the configured size onto the selected device
# and register it with the criterion as the content target.
content_image = images["bird1"].read(size=size, device=device)
criterion.set_content_image(content_image)

In [16]:
# Read the "paint" demo image with the same size/device settings and register
# it with the criterion as the style target.
style_image = images["paint"].read(size=size, device=device)
criterion.set_style_image(style_image)

In [17]:
# Start the optimisation from a copy of the content image.
# NOTE(review): the clone presumably keeps `content_image` unchanged in case the
# optimisation modifies its input in place (content_image is saved to in.png
# later) — confirm against the pystiche docs.
input_image = content_image.clone()
# Optimise the image against the perceptual criterion for 500 steps.
output_image = optim.image_optimization(input_image, criterion, num_steps=500)

Image optimization:   0%|          | 0/500 [00:00<?, ?it/s]

In [18]:
# Write the optimised image to out.png in the current working directory.
save_image(output_image, 'out.png')

In [19]:
# Also write the content image (as read at `size`) to in.png, presumably so it
# can be compared with out.png.
save_image(content_image, 'in.png')