In [1]:
%matplotlib inline
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.optim as optim
from torchvision import transforms,models
import time

In [2]:
# Convolutional feature extractors ("features" part) of the four pretrained
# VGG variants, in the order VGG-11, VGG-13, VGG-16, VGG-19.
# NOTE: newer torchvision releases deprecate `pretrained=True` in favour of `weights=`.
vggList = [
    build(pretrained=True).features
    for build in (models.vgg11, models.vgg13, models.vgg16, models.vgg19)
]
devices = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(devices)

# The networks are used only as fixed feature extractors: freeze all weights
# (Module.requires_grad_ applies to every parameter) and move them to the device.
# `vgg` is intentionally left bound to the last model after the loop.
for vgg in vggList:
    vgg.requires_grad_(False)
    vgg.to(devices)


cuda


In [3]:
def load_image(img_path, max_size=400, shape=None):
    """Load an image file as a normalized (1, 3, H, W) float tensor.

    Args:
        img_path: path of the image file to open.
        max_size: upper bound on the *longer* image side. Larger images are
            downscaled with their aspect ratio preserved; smaller ones keep
            their original size.
        shape: optional explicit size handed straight to ``transforms.Resize``
            (e.g. ``content.shape[-2:]`` to match a content image); when given
            it overrides ``max_size``.

    Returns:
        Tensor of shape (1, 3, H, W), normalized with the ImageNet channel
        mean (0.485, 0.456, 0.406) and std (0.229, 0.224, 0.225).
    """
    # The context manager closes the file handle deterministically.
    with Image.open(img_path) as raw:
        image = raw.convert('RGB')

    if shape is not None:
        size = shape
    else:
        width, height = image.size
        scale = min(1.0, max_size / max(width, height))
        # Resize(int) would scale the *shorter* side to that value (so the
        # longer side could exceed max_size, and small images would be
        # upscaled). An explicit (height, width) pair bounds the longer side.
        size = (max(1, round(height * scale)), max(1, round(width * scale)))

    in_transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    # convert('RGB') guarantees exactly 3 channels; add the batch dimension.
    return in_transform(image).unsqueeze(0)

path = "images/"
contentListName = ['space_needle.jpg','space_needle.jpg','space_needle.jpg']
styleListName = ['space_needle.jpg','space_needle.jpg','space_needle.jpg']

# Load every content image, then every style image, onto the compute device.
contentList = [load_image(path + name).to(devices) for name in contentListName]
styleList = [load_image(path + name).to(devices) for name in styleListName]
# To give a style image the same size as a content image, load it with
# load_image(path + name, shape=<content tensor>.shape[-2:]).


In [4]:
def style_loss(Y_hat, gram_Y):
    """Mean squared error between the Gram matrix of ``Y_hat`` and ``gram_Y``.

    Args:
        Y_hat: feature map of the generated image (as accepted by gram_matrix).
        gram_Y: precomputed Gram matrix of the style image's feature map.

    Uses torch ops (not numpy) so autograd can backpropagate into ``Y_hat``
    and the computation works for tensors on the GPU.
    """
    return torch.mean((gram_matrix(Y_hat) - gram_Y) ** 2)

def content_loss(Y_hat, Y):
    """Mean squared error between generated (``Y_hat``) and content (``Y``) features.

    Uses torch ops (not numpy) so autograd can backpropagate into ``Y_hat``
    and the computation works for tensors on the GPU.
    """
    return torch.mean((Y_hat - Y) ** 2)

def gram_matrix(tensor):
    """Return the (C, C) Gram matrix of a single-image feature map.

    ``tensor`` is unpacked as (batch, C, H, W) and flattened to (C, H*W), so the
    batch dimension must be 1 (``view`` raises otherwise). Entry (i, j) is the
    inner product of channel i with channel j over all spatial positions.
    """
    _, channels, height, width = tensor.shape
    features = tensor.view(channels, height * width)
    return features @ features.t()

def tv_loss(Y_hat):
    """Total-variation penalty of a 4-D image tensor indexed as (N, C, H, W).

    Half the sum of the mean absolute difference between vertically adjacent
    pixels and that between horizontally adjacent pixels; penalizing it
    discourages high-frequency noise in the generated image.
    """
    vertical = (Y_hat[:, :, 1:, :] - Y_hat[:, :, :-1, :]).abs().mean()
    horizontal = (Y_hat[:, :, :, 1:] - Y_hat[:, :, :, :-1]).abs().mean()
    return 0.5 * (vertical + horizontal)


def compute_loss(X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram, tv_weight=10):
    """Weighted content, style and total-variation losses for a generated image.

    Args:
        X: generated image tensor whose total variation is penalized.
        contents_Y_hat, contents_Y: parallel sequences of generated and content
            feature maps, compared pairwise with content_loss().
        styles_Y_hat, styles_Y_gram: parallel sequences of generated feature
            maps and precomputed style Gram matrices, compared with style_loss().
        tv_weight: weight of the total-variation term.

    Returns:
        (contents_l, styles_l, tv_l, l): per-layer weighted content losses,
        per-layer weighted style losses, the weighted TV loss, and their sum.

    The content and style weights are read from the notebook globals
    ``content_weight`` and ``style_weight``.
    """
    contents_l = [content_loss(Y_hat, Y) * content_weight
                  for Y_hat, Y in zip(contents_Y_hat, contents_Y)]
    styles_l = [style_loss(Y_hat, Y) * style_weight
                for Y_hat, Y in zip(styles_Y_hat, styles_Y_gram)]
    tv_l = tv_loss(X) * tv_weight
    # Add up all the losses
    l = sum(styles_l + contents_l + [tv_l])
    return contents_l, styles_l, tv_l, l

def im_convert(tensor):
    """Turn a normalized image tensor back into an (H, W, 3) numpy array for imshow.

    Inverts the ``transforms.Normalize`` mean/std used in ``load_image`` and
    clips to [0, 1]. Accepts a (1, 3, H, W) batch of one or a (3, H, W) image.
    The input tensor is detached and moved to the CPU; it is not modified.
    """
    image = tensor.detach().cpu()
    if image.dim() == 4:
        image = image[0]
    image = image.numpy().transpose(1, 2, 0)
    image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
    return image.clip(0, 1)


def showDiffrenses(content, target):
    """Plot the content image (left) next to the generated target image (right)."""
    _, axes = plt.subplots(1, 2, figsize=(20, 10))
    for ax, image, title in zip(axes, (content, target), ("Content", "Target")):
        ax.imshow(im_convert(image))
        ax.set_title(title)
        ax.axis("off")
    plt.show()


# Correctly spelled alias; the original name is kept for existing callers.
showDifferences = showDiffrenses


In [None]:

# NOTE(review): this cell is unfinished and does not run as written:
#   * `get_features`, `content` and `style` are not defined anywhere in the
#     visible notebook (the loaded images live in `contentList` / `styleList`).
#   * `vgg` here is whatever the model-loading loop left bound, i.e. the last
#     entry of `vggList` (VGG-19), not a per-model choice.
#   * `style_weights=` below has no right-hand side (SyntaxError), and the three
#     bare `append()` calls would raise TypeError.
content_features = list()
content_features.append(get_features(content,vgg))

style_features=get_features(style,vgg)

# Gram matrix of each style feature map, keyed by layer name; create() reads
# this global as `style_grams`.
style_grams={layer: gram_matrix(style_features[layer]) for layer in style_features}


# NOTE(review): if `content` is not already on `devices`, `.to()` returns a
# non-leaf copy that optim.Adam refuses to optimize; moving to the device before
# `requires_grad_` avoids that. Here the loaded images are already on `devices`.
target=content.clone().requires_grad_(True).to(devices)

# TODO: incomplete assignment (SyntaxError) — finish or remove.
style_weights=

# Presumably one style-layer weight dict per model, in `vggList` order — confirm.
style_weightsList = list()

# TODO: these appends have no argument; fill in the weight dicts for the
# remaining models.
style_weightsList.append()
style_weightsList.append()
style_weightsList.append()
# NOTE(review): torchvision VGG feature stacks have five conv blocks, so
# 'conv6_1' presumably never matches a layer — confirm against get_features.
style_weightsList.append({'conv1_1':1.,
               'conv2_1':0.8,
               'conv3_1':0.5,
               'conv4_1':0.3,
               'conv5_1':0.1,
               'conv6_1':0.1})
# Presumably the content layer for each model (vgg11..vgg19). NOTE(review): the
# second conv block of VGG-11 has a single conv, so 'conv2_2' may not exist for it.
target_weightsList = list()

target_weightsList.append({'conv2_2':1})
target_weightsList.append({'conv3_2':1})
target_weightsList.append({'conv4_2':1})
target_weightsList.append({'conv5_2':1})

# Global loss weights; read by compute_loss() and create().
content_weight=1
style_weight=1e6

In [8]:
def create(vgg, target, convLayerKey, styleWeights, contentFeatures, steps=2000, showEvery=400,
           showTime=True, showLoss=True, styleGrams=None, lr=0.003):
    """Optimize ``target`` with Adam so it keeps the content and takes on the style.

    Args:
        vgg: frozen feature extractor passed to ``get_features``.
        target: image tensor being optimized; must be a leaf tensor with
            ``requires_grad=True`` (Adam updates it in place).
        convLayerKey: feature-map key used for the content loss.
        styleWeights: mapping layer key -> weight of that layer's style loss.
        contentFeatures: feature maps of the content image, keyed like the
            output of ``get_features``.
        steps: number of optimization steps.
        showEvery: show the intermediate image (and losses/time) every
            ``showEvery`` steps; 0 or None disables intermediate output.
        showTime: print elapsed seconds at each display and at the end.
        showLoss: print the losses at each display.
        styleGrams: mapping layer key -> style Gram matrix; defaults to the
            notebook-global ``style_grams``.
        lr: Adam learning rate.

    Returns:
        The optimized ``target`` (the same tensor object that was passed in).

    The global ``content_weight`` and ``style_weight`` scale the two loss terms.
    """
    if styleGrams is None:
        styleGrams = style_grams

    if showTime:
        start = time.perf_counter()

    optimizer = optim.Adam([target], lr=lr)

    for ii in range(1, steps + 1):
        targetFeatures = get_features(target, vgg)
        contentLoss = torch.mean((targetFeatures[convLayerKey] - contentFeatures[convLayerKey]) ** 2)

        # Named styleLoss so the module-level style_loss() helper is not shadowed.
        styleLoss = 0
        for layer, layerWeight in styleWeights.items():
            targetFeature = targetFeatures[layer]
            _, d, h, w = targetFeature.shape
            targetGram = gram_matrix(targetFeature)
            layerStyleLoss = layerWeight * torch.mean((targetGram - styleGrams[layer]) ** 2)
            # Each layer's term is divided by its feature-map size d*h*w.
            styleLoss += layerStyleLoss / (d * h * w)

        total_loss = content_weight * contentLoss + style_weight * styleLoss

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        # A falsy showEvery disables intermediate output instead of dividing by zero.
        if showEvery and ii % showEvery == 0:
            if showLoss:
                print('Content loss:', float(contentLoss))
                print('Style loss:', float(styleLoss))
                print('Total loss:', total_loss.item())
            plt.imshow(im_convert(target))
            plt.show()
            if showTime:
                print('Time:', time.perf_counter() - start)

    if showTime:
        print(time.perf_counter() - start)

    return target

# vgg11

In [None]:
 # NOTE(review): this cell cannot run as written. The leading space raises
 # IndentationError. `vgg(0)` calls the leftover loop variable `vgg` (the last
 # entry of vggList) with an int instead of indexing `vggList[0]` (VGG-11).
 # create() also requires target, convLayerKey, styleWeights and contentFeatures.
 target = create(vgg(0),)

# vgg13

# vgg16

# vgg19