In [1]:
import torch
from torch import nn
from collections import namedtuple

import numpy as np
from torchvision import models
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader

In [2]:
class Vgg16(torch.nn.Module):
    """Pretrained VGG-16 feature extractor split into four sequential stages.

    The ``features`` stack of torchvision's VGG-16 is cut at layer indices
    4, 9, 16 and 23. ``forward`` returns the activation at the end of each
    stage as a named tuple with fields ``relu1_2``, ``relu2_2``, ``relu3_3``
    and ``relu4_3``.

    Args:
        requires_grad: when False (the default), every parameter of the
            network is frozen by setting ``requires_grad = False``.
    """

    # Created once for the class instead of once per forward pass, so every
    # output shares the same tuple type.
    Outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3'])

    def __init__(self, requires_grad=False):
        super(Vgg16, self).__init__()
        # NOTE(review): recent torchvision releases deprecate `pretrained=` in
        # favour of `weights=`; kept as-is so older installs keep working.
        vgg_features = models.vgg16(pretrained=True).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        # Each layer keeps its original index as its sub-module name.
        for x in range(4):
            self.slice1.add_module(str(x), vgg_features[x])
        for x in range(4, 9):
            self.slice2.add_module(str(x), vgg_features[x])
        for x in range(9, 16):
            self.slice3.add_module(str(x), vgg_features[x])
        for x in range(16, 23):
            self.slice4.add_module(str(x), vgg_features[x])

        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        """Run ``X`` through the four stages and return all four activations.

        Args:
            X: input passed to the first stage.

        Returns:
            A ``VggOutputs`` named tuple ``(relu1_2, relu2_2, relu3_3,
            relu4_3)``, where each entry is the output of one stage and the
            input of the next.
        """
        h_relu1_2 = self.slice1(X)
        h_relu2_2 = self.slice2(h_relu1_2)
        h_relu3_3 = self.slice3(h_relu2_2)
        h_relu4_3 = self.slice4(h_relu3_3)
        return self.Outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3)
    
def normalize_batch(batch):
    """Scale a batch by 1/255 and standardise it with ImageNet mean and std.

    The input tensor is not modified, so calling this repeatedly on the same
    tensor always gives the same result.

    Args:
        batch: image tensor whose third-from-last axis holds the 3 colour
            channels (e.g. ``(N, 3, H, W)``). It is divided by 255, so it is
            presumably expected to hold 0-255 pixel values.

    Returns:
        A new tensor ``(batch / 255 - mean) / std``, with the per-channel
        constants broadcast over the remaining axes.
    """
    # new_tensor keeps the constants on the same device/dtype as the input.
    mean = batch.new_tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
    std = batch.new_tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)
    # Out-of-place division: an in-place `div_` would rescale the caller's
    # tensor again on every call.
    scaled = batch.div(255.0)
    return (scaled - mean) / std

def gram_matrix(y):
    """Return the normalised Gram matrix of every feature map in a batch.

    Args:
        y: 4-D tensor of shape ``(b, ch, h, w)``.

    Returns:
        Tensor of shape ``(b, ch, ch)``: for each batch element, the product
        of the flattened feature matrix with its transpose, divided by
        ``ch * h * w``.
    """
    (b, ch, h, w) = y.size()
    # reshape (unlike view) also accepts non-contiguous inputs, e.g. tensors
    # that were transposed or permuted.
    features = y.reshape(b, ch, h * w)
    features_t = features.transpose(1, 2)
    return features.bmm(features_t) / (ch * h * w)

In [3]:
# NOTE(review): machine-specific absolute path; adjust for your environment.
STYLE_DIR = "d:/Styles/"

# Convert each image to a tensor, then multiply by 255 because
# normalize_batch divides its input by 255.
style_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda img: img.mul(255)),
])

train_dataset = datasets.ImageFolder(STYLE_DIR, style_transform)

# Keep only the image of each (image, label) pair, add a leading batch
# axis and move it to the GPU.
style_imgs = []
for sample in train_dataset:
    style_imgs.append(sample[0].unsqueeze(0).cuda())

In [4]:
# Frozen feature extractor (requires_grad=False), placed on the GPU.
vgg = Vgg16(requires_grad=False)
vgg = vgg.to("cuda")

In [5]:
# Extract the four feature maps of every style image. Gradients are not
# needed here, so the forward passes run under no_grad.
with torch.no_grad():
    style_fs = [vgg(normalized) for normalized in map(normalize_batch, style_imgs)]

In [6]:
# grams[i][k] is the Gram matrix of layer k for style image i
# (one inner list of 4 entries per image).
grams = []
for f in style_fs:
    grams.append([gram_matrix(layer_features) for layer_features in f])

In [7]:
# For each of the 4 style layers: stack that layer's Gram matrices from all
# images, average over the images, then tile the mean 4 times along a new
# leading axis (presumably the batch size - TODO confirm).
avg = [
    torch.cat([image_grams[k] for image_grams in grams]).mean(0).repeat(4, 1, 1)
    for k in range(4)
]

In [9]:
# Size of the leading axis of the first averaged Gram tensor.
avg[0].size(0)

4

In [None]:
# Shape of the deepest feature map (relu4_3) of the last style image.
# Read it from style_fs directly rather than from a loop variable left over
# from an earlier cell, so the cell does not depend on hidden state.
style_fs[-1].relu4_3.size()