In [1]:
import time
import os 
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import torchvision
from torchvision import transforms
from PIL import Image
from collections import OrderedDict
import matplotlib.pyplot as plt

In [2]:
class VGG(nn.Module):
    def __init__(self):
        super(VGG, self).__init__()
        self.conv1_1 = nn.Conv2d(3, 64, kernel_size = 3, padding = 1)
        self.conv1_2 = nn.Conv2d(64, 64, kernel_size = 3, padding = 1)
        
        self.conv2_1 = nn.Conv2d(64, 128, kernel_size = 3, padding = 1)
        self.conv2_2 = nn.Conv2d(128, 128, kernel_size = 3, padding = 1)
        
        self.conv3_1 = nn.Conv2d(128, 256, kernel_size = 3, padding = 1)
        self.conv3_2 = nn.Conv2d(256, 256, kernel_size = 3, padding = 1)
        self.conv3_3 = nn.Conv2d(256, 256, kernel_size = 3, padding = 1)
        self.conv3_4 = nn.Conv2d(256, 256, kernel_size = 3, padding = 1)
        
        self.conv4_1 = nn.Conv2d(256, 512, kernel_size = 3, padding = 1)
        self.conv4_2 = nn.Conv2d(512, 512, kernel_size = 3, padding = 1)
        self.conv4_3 = nn.Conv2d(512, 512, kernel_size = 3, padding = 1)
        self.conv4_4 = nn.Conv2d(512, 512, kernel_size = 3, padding = 1)
        
        self.conv5_1 = nn.Conv2d(512, 512, kernel_size = 3, padding = 1)
        self.conv5_2 = nn.Conv2d(512, 512, kernel_size = 3, padding = 1)
        self.conv5_3 = nn.Conv2d(512, 512, kernel_size = 3, padding = 1)
        self.conv5_4 = nn.Conv2d(512, 512, kernel_size = 3, padding = 1)
        
        self.pool1 = nn.AvgPool2d(kernel_size = 2, stride = 2)
        self.pool2 = nn.AvgPool2d(kernel_size = 2, stride = 2)
        self.pool3 = nn.AvgPool2d(kernel_size = 2, stride = 2)
        self.pool4 = nn.AvgPool2d(kernel_size = 2, stride = 2)
        self.pool5 = nn.AvgPool2d(kernel_size = 2, stride = 2)
            

    def forward(self, x, out_keys):
        out = {}
        
        out['r11'] = F.relu(self.conv1_1(x))
        out['r12'] = F.relu(self.conv1_2(out['r11']))
        out['p1'] = self.pool1(out['r12'])
        
        out['r21'] = F.relu(self.conv2_1(out['p1']))
        out['r22'] = F.relu(self.conv2_2(out['r21']))
        out['p2'] = self.pool2(out['r22'])
        
        out['r31'] = F.relu(self.conv3_1(out['p2']))
        out['r32'] = F.relu(self.conv3_2(out['r31']))
        out['r33'] = F.relu(self.conv3_3(out['r32']))
        out['r34'] = F.relu(self.conv3_4(out['r33']))
        out['p3'] = self.pool3(out['r34'])
        
        out['r41'] = F.relu(self.conv4_1(out['p3']))
        out['r42'] = F.relu(self.conv4_2(out['r41']))
        out['r43'] = F.relu(self.conv4_3(out['r42']))
        out['r44'] = F.relu(self.conv4_4(out['r43']))
        out['p4'] = self.pool4(out['r44'])
        
        out['r51'] = F.relu(self.conv5_1(out['p4']))
        out['r52'] = F.relu(self.conv5_2(out['r51']))
        out['r53'] = F.relu(self.conv5_3(out['r52']))
        out['r54'] = F.relu(self.conv5_4(out['r53']))
        out['p5'] = self.pool5(out['r54'])
        
        
        return [out[key] for key in out_keys]

In [3]:
class GramMatrix(nn.Module):
    def forward(self, input):
        batchnum, channelnum, width, height = input.size()
        F = input.view(batchnum, channelnum, height * width)
        return torch.bmm(F, F.transpose(1, 2)).div(height*width)


class GramMSELoss(nn.Module):
    def forward(self, input, target):
        Gram = GramMatrix()
        out = nn.MSELoss()(Gram(input), target)
        return out

img_size = 512
prep = transforms.Compose([transforms.Resize(img_size),
                           transforms.ToTensor(),
                           transforms.Lambda(lambda x: x[torch.LongTensor([2, 1, 0])]),
                           transforms.Normalize(mean = [0.40760392, 0.45795686, 0.48501961], std = [1, 1, 1]),
                           transforms.Lambda(lambda x: x.mul(255)), ])
postpa = transforms.Compose([transforms.Lambda(lambda x: x.mul(1./255)),
                            transforms.Normalize(mean = [-0.40760392, -0.45795686, -0.48501961], std = [1, 1, 1]),
                            transforms.Lambda(lambda x: x[torch.LongTensor([2,1,0])]), ])
postpb = transforms.Compose([transforms.ToPILImage()])
def postp(tensor):
    t = postpa(tensor)
    t[t>1], t[t<0] = 1, 0
    img = postpb(t)
    return img

In [None]:
vgg = VGG()
vgg.load_state_dict(torch.load("vgg_conv_weights.pth"))            
for param in vgg.parameters():
    param.requires_grad = False
if torch.cuda.is_available():
    vgg.cuda()
img_names = ['style_monet_sunset.jpg', 'IMG_0995.JPG']
imgs = [Image.open("/content/gdrive/My Drive/cs194/StyleTransfer/Images/" + name) for name in img_names]
imgs_torch = [prep(img) for img in imgs]
if torch.cuda.is_available():
    imgs_torch = [Variable(img.unsqueeze(0)).cuda() for img in imgs_torch]
else:
    imgs_torch = [Variable(img.unsqueeze(0)) for img in imgs_torch]
style_img, content_img = imgs_torch
opt_img = Variable(content_img.clone(), requires_grad = True)
style_layers = ['r11', 'r12', 'r31', 'r41', 'r51']
content_layers = ['r42']
loss_layers = style_layers + content_layers
loss_fns = [GramMSELoss()] * len(style_layers) + [nn.MSELoss()] * len(content_layers)
if torch.cuda.is_available():
    loss_fns = [loss_fn.cuda() for loss_fn in loss_fns] 
style_weights = [1e3/n**2 for n in [64, 128, 256, 512, 512]]
content_weights = [1e0]
weights = style_weights + content_weights
style_targets = [GramMatrix()(A).detach() for A in vgg(style_img, style_layers)]
content_targets = [A.detach() for A in vgg(content_img, content_layers)]
targets = style_targets + content_targets

In [None]:
max_iter = 500
show_iter = 30
optimizer = optim.LBFGS([opt_img])
print(opt_img.size())
print(content_img.size())
n_iter = [0]
while n_iter[0] <= max_iter:
  
    def closure():
        optimizer.zero_grad()
        out = vgg(opt_img, loss_layers)
        layer_losses = [weights[a] * loss_fns[a](A, targets[a]) for a,A in enumerate(out)]
        loss = sum(layer_losses)
        loss.backward()
        n_iter[0] += 1
        if n_iter[0] % show_iter == (show_iter - 1):
            print('Iteration: %d,\tLoss: %f' % (n_iter[0] + 1, loss.item()))
            
        return loss
    
    optimizer.step(closure)

torch.Size([1, 3, 512, 682])
torch.Size([1, 3, 512, 682])
Iteration: 30,	Loss: 357097.281250
Iteration: 60,	Loss: 66715.140625
Iteration: 90,	Loss: 50542.398438
Iteration: 120,	Loss: 45328.500000
Iteration: 150,	Loss: 42671.773438
Iteration: 180,	Loss: 40939.046875
Iteration: 210,	Loss: 39777.316406
Iteration: 240,	Loss: 38941.500000
Iteration: 270,	Loss: 38325.320312
Iteration: 300,	Loss: 37881.812500
Iteration: 330,	Loss: 37527.433594
Iteration: 360,	Loss: 37239.613281
Iteration: 390,	Loss: 37009.386719
Iteration: 420,	Loss: 36818.699219
Iteration: 450,	Loss: 36657.097656
Iteration: 480,	Loss: 36518.562500
Iteration: 510,	Loss: 36406.750000


In [None]:
out_img = postp(opt_img.data[0].cpu().squeeze())
# plt.grid(None)
# plt.imshow(out_img)
# plt.gcf().set_size_inches(10, 10)

In [None]:
import skimage.io as skio
import numpy as np
skio.imsave("a.jpg", np.asarray(out_img))

In [None]:
skio.imsave("/content/gdrive/My Drive/cs194/StyleTransfer/Images/claire2.JPG", np.asarray(out_img))