In [44]:
import torch
from torch import nn
import torchvision

from PIL import Image

In [45]:
content_img = Image.open('final_style.png')
style_img = Image.open('final_content.png')

In [46]:
## Preprocessing

rgb_mean = torch.tensor([0.485, 0.456, 0.406])
rgb_std = torch.tensor([0.229, 0.224, 0.225])

def preprocess(img,image_shape):
    transforms = torchvision.transforms.Compose([
        torchvision.transforms.Resize(image_shape),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=rgb_mean,std=rgb_std)
    ])
    return transforms(img).unsqueeze(0)

def postprocess(img):
    img = img[0].to(rgb_std.device)
    img = torch.clamp(img.permute(1,2,0)*rgb_std+rgb_mean,0,1)
    return torchvision.transforms.ToPILImage()(img.permute(2,0,1))

In [47]:
from lucent.modelzoo.inceptionv1 import InceptionV1
from lucent.misc.io import show

import lucent.optvis.objectives as objectives
import lucent.optvis.param as param
import lucent.optvis.render as render

In [48]:
model = InceptionV1()

In [49]:
style_layers = [
  'conv2d2',
  'mixed3a',
  'mixed4a',
  'mixed4b',
  'mixed4c',
]

content_layers = [
  'mixed3b',
]

In [68]:
image_shape =  (250, 325) 

def style_transfer_param(content_image, style_image, decorrelate=True, fft=True):
    
    style_transfer_input = param.image(*preprocess(content_image,image_shape).shape[:2], decorrelate=decorrelate, fft=fft)[0]
    content_input = content_image
#     print(style_transfer_input)
    return style_transfer_input, content_input, style_image

# these constants help remember which image is at which batch dimension
TRANSFER_INDEX = 0
CONTENT_INDEX = 1
STYLE_INDEX = 2

In [69]:
style_transfer_param(content_img, style_img)

([tensor([[[[[ 0.0051,  0.0295]],
  
            [[-0.0161, -0.0074]],
  
            [[-0.0209,  0.0009]]],
  
  
           [[[-0.0036, -0.0135]],
  
            [[ 0.0133,  0.0056]],
  
            [[ 0.0128,  0.0058]]],
  
  
           [[[ 0.0051, -0.0207]],
  
            [[-0.0102,  0.0028]],
  
            [[-0.0168, -0.0018]]]]], device='cuda:0', requires_grad=True)],
 <PIL.PngImagePlugin.PngImageFile image mode=RGB size=512x512 at 0x22167C8DA88>,
 <PIL.PngImagePlugin.PngImageFile image mode=RGB size=645x512 at 0x22167C8D448>)

In [70]:
def mean_L1(a, b):
    return torch.mean(torch.abs(a-b))

In [75]:
def activation_difference(layer_names, activation_loss_f=mean_L1, transform_f=None, difference_to=CONTENT_INDEX):
    def inner(T):
        # first we collect the (constant) activations of image we're computing the difference to
        image_activations = [T(layer_name)[difference_to] for layer_name in layer_names]
        if transform_f is not None:
            image_activations = [transform_f(act) for act in image_activations]

        # we also set get the activations of the optimized image which will change during optimization
        optimization_activations = [T(layer)[TRANSFER_INDEX] for layer in layer_names]
        if transform_f is not None:
            optimization_activations = [transform_f(act) for act in optimization_activations]

        # we use the supplied loss function to compute the actual losses
        losses = [activation_loss_f(a, b) for a, b in zip(image_activations, optimization_activations)]
        return tf.add_n(losses)

    return inner

In [76]:
# def gram_matrix(array, normalize_magnitue=True):
#     channels = tf.shape(array)[-1]
#     array_flat = tf.reshape(array, [-1, channels])
#     gram_matrix = tf.matmul(array_flat, array_flat, transpose_a=True)
#     if normalize_magnitue:
#     length = tf.shape(array_flat)[0]
#     gram_matrix /= tf.cast(length, tf.float32)
#     return gram_matrix

def gram_matrix(X):
    num_channels, n = X.shape[1], X.numel() // X.shape[1]
    X = X.reshape((num_channels, n))
    return torch.matmul(X, X.T) / (num_channels * n)

In [77]:
param_f = lambda: style_transfer_param(content_image, style_image)

content_obj = 100 * activation_difference(content_layers, difference_to=CONTENT_INDEX)
content_obj.description = "Content Loss"

style_obj = activation_difference(style_layers, transform_f=gram_matrix, difference_to=STYLE_INDEX)
style_obj.description = "Style Loss"

objective = - content_obj - style_obj

vis = render.render_vis(model, objective, param_f=param_f, thresholds=[512], verbose=False, print_objectives=[content_obj, style_obj])[-1]

TypeError: unsupported operand type(s) for *: 'int' and 'function'