In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import os
from PIL import Image
import matplotlib.pyplot as plt

import numpy as np

import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import sampler
import torchvision.datasets as dset

In [16]:
# All project assets live under one Drive folder. The path is relative, so it
# only resolves when the working directory is the parent of the `gdrive` mount
# point (NOTE(review): presumably /content on Colab -- confirm).
project_dir = "gdrive/My Drive/Colab Notebooks/gatys/"
model_dir = project_dir  # the model weights sit directly in the project folder
style_dir = f"{project_dir}style_images/"
content_dir = f"{project_dir}content_images/"

In [10]:
USE_GPU = False
dtype = torch.float32

# Use CUDA only when it is both requested and actually available;
# otherwise everything runs on the CPU.
device = torch.device('cuda' if USE_GPU and torch.cuda.is_available() else 'cpu')

# Constant to control how frequently we print train loss
print_every = 100

print('using device:', device)

using device: cpu


The architecture used by Gatys et al. is a subset of VGG-19: five convolutional blocks, each of which is really a stack of consecutive convolution–ReLU pairs, with a pooling layer between consecutive blocks. The authors note that average pooling gave better results than max pooling.

In [11]:
class VGG(nn.Module):
    """Convolutional part of VGG-19 with average pooling, for style transfer.

    The network consists of five blocks of 3x3 convolution + ReLU pairs
    (2, 2, 4, 4 and 4 convolutions respectively) with a 2x2 average pooling
    layer between consecutive blocks. ``forward`` returns every intermediate
    activation so that content and style features can be picked out later.

    Activation keys are ``"r<block><index>"`` for ReLU outputs (e.g. ``"r42"``
    is the ReLU after ``conv4_2``) and ``"p<block>"`` for pooling outputs.

    Args:
        content_acts: iterable of activation keys used as content features.
            Defaults to ``["r42"]``.
        style_acts: iterable of activation keys used as style features.
            Defaults to ``["r11", "r21", "r31", "r41", "r51"]``.

    The defaults are the layer choice reported by Gatys et al.
    """

    # Number of conv+ReLU pairs in each of the five blocks.
    _BLOCK_DEPTHS = (2, 2, 4, 4, 4)

    DEFAULT_CONTENT_ACTS = ("r42",)
    DEFAULT_STYLE_ACTS = ("r11", "r21", "r31", "r41", "r51")

    def __init__(self, content_acts=None, style_acts=None):
        super().__init__()
        # Layers are kept as individually named attributes so that the
        # state-dict keys (conv1_1.weight, ...) stay unchanged.
        self.conv1_1 = nn.Conv2d(in_channels=3, out_channels=64, padding=1, kernel_size=3)
        self.conv1_2 = nn.Conv2d(in_channels=64, out_channels=64, padding=1, kernel_size=3)
        self.pool1 = nn.AvgPool2d(kernel_size=2, stride=2)

        self.conv2_1 = nn.Conv2d(in_channels=64, out_channels=128, padding=1, kernel_size=3)
        self.conv2_2 = nn.Conv2d(in_channels=128, out_channels=128, padding=1, kernel_size=3)
        self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2)

        self.conv3_1 = nn.Conv2d(in_channels=128, out_channels=256, padding=1, kernel_size=3)
        self.conv3_2 = nn.Conv2d(in_channels=256, out_channels=256, padding=1, kernel_size=3)
        self.conv3_3 = nn.Conv2d(in_channels=256, out_channels=256, padding=1, kernel_size=3)
        self.conv3_4 = nn.Conv2d(in_channels=256, out_channels=256, padding=1, kernel_size=3)
        self.pool3 = nn.AvgPool2d(kernel_size=2, stride=2)

        self.conv4_1 = nn.Conv2d(in_channels=256, out_channels=512, padding=1, kernel_size=3)
        self.conv4_2 = nn.Conv2d(in_channels=512, out_channels=512, padding=1, kernel_size=3)
        self.conv4_3 = nn.Conv2d(in_channels=512, out_channels=512, padding=1, kernel_size=3)
        self.conv4_4 = nn.Conv2d(in_channels=512, out_channels=512, padding=1, kernel_size=3)
        self.pool4 = nn.AvgPool2d(kernel_size=2, stride=2)

        self.conv5_1 = nn.Conv2d(in_channels=512, out_channels=512, padding=1, kernel_size=3)
        self.conv5_2 = nn.Conv2d(in_channels=512, out_channels=512, padding=1, kernel_size=3)
        self.conv5_3 = nn.Conv2d(in_channels=512, out_channels=512, padding=1, kernel_size=3)
        self.conv5_4 = nn.Conv2d(in_channels=512, out_channels=512, padding=1, kernel_size=3)

        self.content_acts = (
            list(self.DEFAULT_CONTENT_ACTS) if content_acts is None else content_acts
        )
        self.style_acts = (
            list(self.DEFAULT_STYLE_ACTS) if style_acts is None else style_acts
        )

    def forward(self, x):
        """Run the network and return every intermediate activation.

        Args:
            x: image batch of shape (N, 3, H, W).

        Returns:
            dict mapping activation keys ("r11", "r12", "p1", ..., "r54") to
            tensors, in the order in which they were computed.
        """
        out = {}  # cache all layers from forward pass
        last_block = len(self._BLOCK_DEPTHS)

        h = x
        for block, depth in enumerate(self._BLOCK_DEPTHS, start=1):
            # Each conv layer is looked up by its own name, so every layer
            # uses its own weights (convB_1 ... convB_depth).
            for idx in range(1, depth + 1):
                conv = getattr(self, f"conv{block}_{idx}")
                h = F.relu(conv(h))
                out[f"r{block}{idx}"] = h

            # The convolutions keep the spatial size (3x3 kernel, padding 1);
            # the 2x2/stride-2 pooling between blocks halves it. There is no
            # pooling after the last block.
            if block < last_block:
                h = getattr(self, f"pool{block}")(h)
                out[f"p{block}"] = h

        return out

    def get_activations(self, out, mode):
        """Select the content or style activations from a forward-pass cache.

        Args:
            out: dict returned by ``forward``.
            mode: ``"content"`` or ``"style"``.

        Returns:
            list of tensors, one per configured layer, in configured order.

        Raises:
            ValueError: if ``mode`` is neither ``"content"`` nor ``"style"``.
            KeyError: if a configured layer key is not present in ``out``.
        """
        if mode == "content":
            layers = self.content_acts
        elif mode == "style":
            layers = self.style_acts
        else:
            raise ValueError(
                f"mode must be 'content' or 'style', got {mode!r}"
            )
        return [out[layer] for layer in layers]

**TODO:** Explain the style and content loss functions defined below.


In [12]:
def gram_matrix(feature_maps):
    """Compute per-sample Gram matrices of a batch of feature maps.

    Args:
        feature_maps: tensor of shape (N, C, H, W).

    Returns:
        Tensor of shape (N, C, C) whose entry [n, i, j] is the inner product
        of channels i and j of sample n over all spatial positions, divided
        by H * W.
    """
    batch, channels, height, width = feature_maps.shape
    positions = height * width

    # Flatten the spatial dimensions so each channel becomes one row vector.
    features = feature_maps.reshape(batch, channels, positions)
    gram = torch.bmm(features, features.transpose(1, 2))

    # Normalize by image size (in place, on the freshly created product).
    return gram.div_(positions)

def style_loss(gen_grams, orig_grams):
    """Mean squared error between two sets of Gram matrices, per batch item.

    Args:
        gen_grams: Gram matrices of the generated image.
        orig_grams: Gram matrices of the style image, same shape.

    Returns:
        Scalar tensor: ``F.mse_loss`` of the two inputs divided by
        ``len(gen_grams)``. Note that ``F.mse_loss`` already averages over
        every element by default, so the batch length is divided out a
        second time here.
    """
    mean_sq_err = F.mse_loss(gen_grams, orig_grams)
    return mean_sq_err / len(gen_grams)

In [13]:
image_size = 512  # see section 4 of paper
# Per-channel means subtracted before scaling to [0, 255].
# NOTE(review): presumably the ImageNet channel means in BGR order -- confirm.
BGR_means = [0.40760392, 0.45795686, 0.48501961]

# PIL image -> network input: resize, convert to a (3, H, W) float tensor,
# swap RGB -> BGR, subtract the channel means, then scale by 255.
pre = T.Compose([
    T.Resize(image_size),
    T.ToTensor(),
    T.Lambda(lambda x: x[torch.LongTensor([2,1,0])]),
    T.Normalize(mean=BGR_means, std=[1,1,1]),
    T.Lambda(lambda x : x.mul(255.)),
])

# Network-space tensor -> RGB tensor: exact inverse of `pre` (minus the resize).
# The first stage must NOT be in-place: it receives the caller's tensor (e.g.
# the image being optimised), and `div_` would silently rescale it on every
# call.
post = T.Compose([
    T.Lambda(lambda x : x.div(255.)),
    T.Normalize(mean=[-i for i in BGR_means], std=[1,1,1]),
    T.Lambda(lambda x: x[torch.LongTensor([2,1,0])]),
])

In [14]:
# Mount Google Drive at /content/gdrive. This must run before any cell that
# reads from `project_dir`, whose relative path begins with "gdrive/".
# force_remount=True asks Colab to mount again even if a mount already exists.
from google.colab import drive
drive.mount('/content/gdrive', force_remount=True)

Mounted at /content/gdrive


In [17]:
# Load the style and content images from Drive, forcing 3-channel RGB.
style_img = Image.open(style_dir + "vasjen_katro.png").convert("RGB")
content_img = Image.open(content_dir + "gits.jpg").convert("RGB")
images = (style_img, content_img)  # order matters: (style, content)

# Preview both images, one figure each.
for img in images:
    plt.imshow(np.asarray(img))
    plt.show()

FileNotFoundError: ignored

ls: cannot access 'style_dir': No such file or directory


In [None]:
# NOTE(review): this mounts Drive a second time, at the relative path 'drive'.
# `project_dir` points under "gdrive/...", which is a different mount point, so
# this cell looks like a leftover experiment -- confirm before keeping it.
from google.colab import drive
drive.mount('drive')

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly&response_type=code

Enter your authorization code:
··········
Mounted at drive


In [None]:
# Pre-process images: apply `pre`, add a leading batch dimension, and move the
# result to the device chosen in the configuration cell. Using `device` rather
# than branching on USE_GPU keeps this working when USE_GPU is True but CUDA
# is unavailable.
imgs_torch = [pre(img).unsqueeze(0).to(device) for img in images]

# unpack processed images (same order as `images`: style first, then content)
style_img, content_img = imgs_torch

In [None]:
# Build the network. VGG needs to know which activations serve as content and
# style features; these are the layers reported by Gatys et al. (content:
# conv4_2, style: conv1_1 ... conv5_1) -- adjust here to experiment.
vgg = VGG(
    content_acts=["r42"],
    style_acts=["r11", "r21", "r31", "r41", "r51"],
)

# map_location loads the weights straight onto the configured device, so a
# checkpoint saved on a GPU machine still loads on a CPU-only runtime.
vgg.load_state_dict(torch.load(model_dir + 'vgg_conv.pth', map_location=device))

# The network is a fixed feature extractor: only the image is optimised.
for param in vgg.parameters():
    param.requires_grad = False

# Follow `device` (which already honours USE_GPU) so that the model and the
# pre-processed images always end up on the same device.
vgg.to(device)

In [None]:
# NOTE(review): third Drive mount, this time at /content/drive. `project_dir`
# points under "gdrive/...", not "drive/...", so nothing visible in this
# notebook reads from this mount point -- likely a leftover; confirm.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [None]:
# Scratch check of the batched product used in gram_matrix: two samples with
# two channels of four values each, multiplied by their own transpose.
a = torch.arange(16).reshape(2, 2, 4)
b = a.transpose(1, 2)
ab = torch.bmm(a, b)

print(ab.shape)
print(ab.numel())  # total number of entries across all Gram matrices

torch.Size([2, 2, 2])
8


In [None]:
# Scratch check of F.mse_loss: every element of b exceeds the matching element
# of a by exactly 8, so the mean squared error is 8**2 = 64.
a = torch.arange(8, dtype=dtype).reshape(2, 2, 2)
b = a + 8
F.mse_loss(a, b)

tensor(64.)

In [None]:
# Scratch check of how T.Lambda transforms reshuffle a (3, 4, 5) tensor.
a = torch.arange(60, dtype=dtype).reshape(3,4,5)

# foo reverses the order of all axes: (3, 4, 5) -> (5, 4, 3). This is spelled
# as an explicit permute because `Tensor.T` is deprecated for tensors that
# are not 2-D.
foo = T.Compose([
    T.Lambda(lambda x : x.permute(*reversed(range(x.dim()))))
])

# bar reverses the first (channel) axis -- the same channel swap that the
# `pre` and `post` pipelines apply.
bar = T.Compose([
    T.Lambda(lambda x: x[torch.LongTensor([2,1,0])]),
])
print(a)
print(foo(a))
print(bar(a))

tensor([[[ 0.,  1.,  2.,  3.,  4.],
         [ 5.,  6.,  7.,  8.,  9.],
         [10., 11., 12., 13., 14.],
         [15., 16., 17., 18., 19.]],

        [[20., 21., 22., 23., 24.],
         [25., 26., 27., 28., 29.],
         [30., 31., 32., 33., 34.],
         [35., 36., 37., 38., 39.]],

        [[40., 41., 42., 43., 44.],
         [45., 46., 47., 48., 49.],
         [50., 51., 52., 53., 54.],
         [55., 56., 57., 58., 59.]]])
tensor([[[ 0., 20., 40.],
         [ 5., 25., 45.],
         [10., 30., 50.],
         [15., 35., 55.]],

        [[ 1., 21., 41.],
         [ 6., 26., 46.],
         [11., 31., 51.],
         [16., 36., 56.]],

        [[ 2., 22., 42.],
         [ 7., 27., 47.],
         [12., 32., 52.],
         [17., 37., 57.]],

        [[ 3., 23., 43.],
         [ 8., 28., 48.],
         [13., 33., 53.],
         [18., 38., 58.]],

        [[ 4., 24., 44.],
         [ 9., 29., 49.],
         [14., 34., 54.],
         [19., 39., 59.]]])
tensor([[[40., 41., 42., 43., 44