In [1]:
import time
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import torchvision
from torchvision import transforms
from io import BytesIO
from PIL import Image
from collections import OrderedDict
#from google.colab import files

In [4]:
class VGG16(nn.Module):
    """Convolution-only network that exposes every intermediate activation.

    The model is five blocks of 3x3, padding-1 convolutions (2, 2, 4, 4 and 4
    layers with 64, 128, 256, 512 and 512 output channels). Each block is
    followed by a 2x2, stride-2 pooling layer. There are no fully connected
    layers.

    NOTE(review): four convolutions in blocks 3-5 looks like the VGG-19
    layout rather than VGG-16 -- confirm. The class name is kept unchanged so
    existing callers keep working.

    Args:
        pool: 'max' to pool with nn.MaxPool2d, 'avg' to pool with
            nn.AvgPool2d.

    Raises:
        ValueError: if ``pool`` is neither 'max' nor 'avg'.
    """

    def __init__(self, pool='max'):
        # Validate before allocating any weights. Previously an unknown value
        # built a model without pool1..pool5, which only failed later inside
        # forward() with an AttributeError.
        if pool == 'max':
            pool_cls = nn.MaxPool2d
        elif pool == 'avg':
            pool_cls = nn.AvgPool2d
        else:
            raise ValueError(
                "pool must be 'max' or 'avg', got {!r}".format(pool))

        super().__init__()
        # Attribute names double as state_dict keys, so they must not change.
        self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv3_4 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
        self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv4_4 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv5_4 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        # One pooling layer per block; each halves the spatial resolution.
        self.pool1 = pool_cls(kernel_size=2, stride=2)
        self.pool2 = pool_cls(kernel_size=2, stride=2)
        self.pool3 = pool_cls(kernel_size=2, stride=2)
        self.pool4 = pool_cls(kernel_size=2, stride=2)
        self.pool5 = pool_cls(kernel_size=2, stride=2)

    def forward(self, x, layers):
        """Run the full network and return the requested activations.

        Args:
            x: input tensor fed to ``conv1_1`` (3 input channels).
            layers: iterable of activation names to return. Valid names are
                'relu<block>_<index>' for every convolution and 'pool1' ..
                'pool5'.

        Returns:
            A list with one tensor per entry of ``layers``, in the same order.

        Raises:
            KeyError: if ``layers`` contains a name that is not produced here.
        """
        # Every activation is computed and stored, regardless of which ones
        # were requested; selection happens only in the return statement.
        out = {}
        out['relu1_1'] = F.relu(self.conv1_1(x))
        out['relu1_2'] = F.relu(self.conv1_2(out['relu1_1']))
        out['pool1'] = self.pool1(out['relu1_2'])
        out['relu2_1'] = F.relu(self.conv2_1(out['pool1']))
        out['relu2_2'] = F.relu(self.conv2_2(out['relu2_1']))
        out['pool2'] = self.pool2(out['relu2_2'])
        out['relu3_1'] = F.relu(self.conv3_1(out['pool2']))
        out['relu3_2'] = F.relu(self.conv3_2(out['relu3_1']))
        out['relu3_3'] = F.relu(self.conv3_3(out['relu3_2']))
        out['relu3_4'] = F.relu(self.conv3_4(out['relu3_3']))
        out['pool3'] = self.pool3(out['relu3_4'])
        out['relu4_1'] = F.relu(self.conv4_1(out['pool3']))
        out['relu4_2'] = F.relu(self.conv4_2(out['relu4_1']))
        out['relu4_3'] = F.relu(self.conv4_3(out['relu4_2']))
        out['relu4_4'] = F.relu(self.conv4_4(out['relu4_3']))
        out['pool4'] = self.pool4(out['relu4_4'])
        out['relu5_1'] = F.relu(self.conv5_1(out['pool4']))
        out['relu5_2'] = F.relu(self.conv5_2(out['relu5_1']))
        out['relu5_3'] = F.relu(self.conv5_3(out['relu5_2']))
        out['relu5_4'] = F.relu(self.conv5_4(out['relu5_3']))
        out['pool5'] = self.pool5(out['relu5_4'])
        return [out[key] for key in layers]


In [None]:
# Build the network and restore its weights from 'vgg_conv.pth'.
vgg = VGG16()
# map_location='cpu' lets a checkpoint that was saved from a GPU load on a
# CPU-only machine; the model is moved to the GPU below when one is available.
# NOTE(review): torch.load unpickles the file -- only load checkpoints that
# come from a trusted source.
vgg.load_state_dict(torch.load('vgg_conv.pth', map_location='cpu'))
# Freeze every weight so no gradients are computed for the network itself.
for param in vgg.parameters():
    param.requires_grad = False
if torch.cuda.is_available():
    vgg.cuda()

In [6]:
class GramMatrix(nn.Module):
    """Gram matrix of a 4-D feature map, computed separately for each sample."""

    def forward(self, input):
        """Return the channel-by-channel Gram matrix of ``input``.

        Args:
            input: tensor whose size unpacks into four dimensions
                (b, c, h, w). It must be viewable as (b, c, h*w).

        Returns:
            Tensor of shape (b, c, c): the batched product of the flattened
            features with their own transpose, divided by h*w.
        """
        batch, channels, height, width = input.size()
        positions = height * width
        # Collapse the two trailing dimensions so each channel is one row.
        flat = input.view(batch, channels, positions)
        gram = torch.bmm(flat, flat.transpose(1, 2))
        # In-place division; div_ returns the same tensor it modifies.
        return gram.div_(positions)
class GramMSELoss(nn.Module):
    """Mean-squared error between the Gram matrix of ``input`` and ``target``."""

    def forward(self, input, target):
        """Compute the loss.

        Args:
            input: feature tensor; it is passed through GramMatrix first.
            target: tensor compared against the Gram matrix of ``input``. It
                is used as given, with no Gram computation applied to it here.

        Returns:
            The value produced by nn.MSELoss with its default settings.
        """
        gram_of_input = GramMatrix()(input)
        criterion = nn.MSELoss()
        return criterion(gram_of_input, target)