In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Get_gradient(nn.Module):
    """Per-channel gradient magnitude from fixed 3x3 central-difference kernels.

    Every selected channel is filtered independently with a vertical kernel
    (row below minus row above) and a horizontal kernel (column to the right
    minus column to the left), both with zero padding of 1 so the spatial size
    is preserved. The two responses are combined as
    ``sqrt(v ** 2 + h ** 2 + 1e-6)``.
    """

    def __init__(self):
        super(Get_gradient, self).__init__()
        kernel_v = [[0, -1, 0],
                    [0, 0, 0],
                    [0, 1, 0]]
        kernel_h = [[0, 0, 0],
                    [-1, 0, 1],
                    [0, 0, 0]]
        # Two unsqueezes give shape (1, 1, 3, 3): one output and one input channel.
        kernel_h = torch.FloatTensor(kernel_h).unsqueeze(0).unsqueeze(0)
        kernel_v = torch.FloatTensor(kernel_v).unsqueeze(0).unsqueeze(0)
        # Frozen parameters, so they follow the module in .to(device) and
        # state_dict(). No trailing .cpu(): the tensors are created on the CPU
        # already, and a .cpu() call that returned a copy would hand back a
        # tensor that is no longer registered on the module.
        self.weight_h = nn.Parameter(data=kernel_h, requires_grad=False)
        self.weight_v = nn.Parameter(data=kernel_v, requires_grad=False)

    def forward(self, x, input_nc=3):
        """Return the gradient magnitude of the first ``input_nc`` channels of ``x``.

        Args:
            x: 4-D tensor whose dim 1 is the channel axis, i.e. (B, C, H, W).
            input_nc: number of leading channels to process. Any positive value
                up to ``C`` is accepted (1 and 3 behave as before).

        Returns:
            Tensor of shape (B, input_nc, H, W).

        Raises:
            ValueError: if ``x`` is not 4-D or ``input_nc`` is not positive.
            IndexError: if ``x`` has fewer than ``input_nc`` channels.
        """
        if x.dim() != 4:
            raise ValueError(
                'expected a 4-D (B, C, H, W) tensor, got {} dims'.format(x.dim()))
        if input_nc < 1:
            raise ValueError('input_nc must be positive, got {}'.format(input_nc))

        batch, channels, height, width = x.shape
        if channels < input_nc:
            raise IndexError(
                'input_nc={} but input has only {} channel(s)'.format(input_nc, channels))

        # Fold the selected channels into the batch axis so that one pair of
        # single-channel convolutions filters every channel independently.
        planes = x[:, :input_nc].reshape(batch * input_nc, 1, height, width)
        grad_v = F.conv2d(planes, self.weight_v, padding=1)
        grad_h = F.conv2d(planes, self.weight_h, padding=1)

        # The small epsilon keeps sqrt differentiable where both responses are zero.
        magnitude = torch.sqrt(torch.pow(grad_v, 2) + torch.pow(grad_h, 2) + 1e-6)
        return magnitude.reshape(batch, input_nc, *magnitude.shape[-2:])

# Smoke test: run Get_gradient on an all-zero tensor and show the output shape.
# The tensor is deliberately not called `input`, which would shadow the builtin
# input() for the rest of the kernel session.
# NOTE(review): Get_gradient.forward indexes dim 1 as the channel axis, so this
# (1, 255, 255, 1) tensor is read as 255 channels of 255x1 maps, of which only
# the first three are used. If a single 255x255 image was intended, the shape
# should probably be (1, 3, 255, 255) or (1, 1, 255, 255) -- confirm.
grad_input = torch.zeros(1, 255, 255, 1)
model = Get_gradient()
model(grad_input).shape

torch.Size([1, 3, 255, 1])

: 

In [None]:
from vit_pytorch import ViT
import torch


# Instantiate a ViT classifier with two output classes.
# image_size / patch_size = 256 / 32 = 8, presumably an 8x8 grid of patches per
# image -- confirm against the vit_pytorch documentation.
model = ViT(
    # input geometry and output head
    image_size=256,
    patch_size=32,
    num_classes=2,
    # transformer size
    dim=1024,
    depth=6,
    heads=16,
    mlp_dim=2048,
    # regularisation
    dropout=0.1,
    emb_dropout=0.1,
)


In [None]:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

import copy

class ContentLoss(nn.Module):
    """Pass-through layer that records the MSE between its input and a fixed target.

    The layer leaves the data flowing through it unchanged; after each forward
    pass the most recent loss value is available as ``self.loss``.
    """

    def __init__(self, target):
        super().__init__()
        # The target is a constant reference value, not something to optimise,
        # so it is cut out of the autograd graph here. Leaving it attached
        # would make the criterion's forward pass raise an error.
        self.target = target.detach()

    def forward(self, input):
        """Store the MSE against the target in ``self.loss`` and return ``input`` as is."""
        mse = F.mse_loss(input, self.target)
        self.loss = mse
        return input

def gram_matrix(input):
    """Return the normalised Gram matrix of a 4-D feature tensor.

    Args:
        input: tensor of size (a, b, c, d), where a is the batch size (=1),
            b the number of feature maps and (c, d) the dimensions of one
            feature map.

    Returns:
        Tensor of shape (a * b, a * b): the product of the flattened features
        with their transpose, divided by the total element count a * b * c * d.
    """
    a, b, c, d = input.size()

    # Resize F_XL into \hat F_XL: one row per feature map, N = c * d columns.
    # reshape (rather than view) also accepts non-contiguous inputs such as
    # permuted or channels-last tensors, copying only when it has to.
    features = input.reshape(a * b, c * d)

    G = torch.mm(features, features.t())  # compute the gram product

    # 'Normalize' the values of the gram matrix by dividing by the number of
    # elements in the feature tensor.
    return G.div(a * b * c * d)

class StyleLoss(nn.Module):
    """Pass-through layer that records the MSE between Gram matrices.

    The Gram matrix of ``target_feature`` is computed once at construction
    time. Each forward pass compares the Gram matrix of the incoming tensor
    against it and stores the result in ``self.loss``.
    """

    def __init__(self, target_feature):
        super().__init__()
        target_gram = gram_matrix(target_feature)
        # Fixed reference value: keep it out of the autograd graph.
        self.target = target_gram.detach()

    def forward(self, input):
        """Update ``self.loss`` and return ``input`` unchanged."""
        current_gram = gram_matrix(input)
        self.loss = F.mse_loss(current_gram, self.target)
        return input

class Normalization(nn.Module):
    """Normalise an image with per-channel mean and standard deviation.

    Args:
        mean: per-channel means, as a tensor or anything ``torch.tensor`` accepts.
        std: per-channel standard deviations, same accepted types as ``mean``.

    Both statistics are copied and reshaped to (-1, 1, 1) so that they
    broadcast over the last two dimensions of the image.
    NOTE(review): they are stored as plain attributes, not buffers, so they do
    not follow the module through ``.to(device)`` -- confirm this is intended.
    """

    def __init__(self, mean, std):
        super(Normalization, self).__init__()
        self.mean = self._as_channel_stat(mean)
        self.std = self._as_channel_stat(std)

    @staticmethod
    def _as_channel_stat(values):
        """Copy ``values`` into a detached tensor of shape (-1, 1, 1)."""
        if isinstance(values, torch.Tensor):
            # torch.tensor(existing_tensor) emits a copy-construct UserWarning;
            # clone().detach() is the recommended equivalent.
            stat = values.clone().detach()
        else:
            stat = torch.tensor(values)
        return stat.reshape(-1, 1, 1)

    def forward(self, img):
        """Return ``(img - mean) / std`` with the statistics broadcast per channel."""
        return (img - self.mean) / self.std

def get_style_model_and_losses(cnn, normalization_mean, normalization_std,
                               style_img,
                               style_layers=['conv_2', 'conv_4', 'conv_7', 'conv_10']):
    """Rebuild ``cnn`` as a named ``nn.Sequential`` with StyleLoss layers inserted.

    Args:
        cnn: module whose direct children are Conv2d / ReLU / MaxPool2d /
            BatchNorm2d layers (at runtime, the features of a pretrained VGG19).
        normalization_mean: per-channel mean passed to ``Normalization``.
        normalization_std: per-channel std passed to ``Normalization``.
        style_img: image whose features define the style targets.
        style_layers: names of the layers after which a StyleLoss is inserted.

    Returns:
        ``(model, style_losses)``: the assembled model, cut off after its last
        StyleLoss layer, and the list of inserted StyleLoss modules.

    Raises:
        RuntimeError: if ``cnn`` contains a child of an unrecognised type.
    """
    # Work on a private copy so the caller's network is left untouched.
    cnn = copy.deepcopy(cnn)

    model = nn.Sequential(Normalization(normalization_mean, normalization_std))
    style_losses = []

    # Layer type -> name template; the number is the count of convs seen so far.
    name_templates = (
        (nn.Conv2d, 'conv_{}'),
        (nn.ReLU, 'relu_{}'),
        (nn.MaxPool2d, 'pool_{}'),
        (nn.BatchNorm2d, 'bn_{}'),
    )

    conv_count = 0  # increment every time we see a conv
    for layer in cnn.children():
        template = next(
            (tpl for layer_type, tpl in name_templates if isinstance(layer, layer_type)),
            None,
        )
        if template is None:
            raise RuntimeError('Unrecognized layer: {}'.format(layer.__class__.__name__))

        if isinstance(layer, nn.Conv2d):
            conv_count += 1
        elif isinstance(layer, nn.ReLU):
            # Swap the in-place ReLU for an out-of-place one: ContentLoss and
            # StyleLoss do not work well with inplace=True. This costs a small
            # amount of performance.
            layer = nn.ReLU(inplace=False)

        name = template.format(conv_count)
        # add_module registers the layer on the Sequential under this name.
        model.add_module(name, layer)

        if name in style_layers:
            # Style target = features of the style image at this depth.
            target_feature = model(style_img).detach()
            style_loss = StyleLoss(target_feature)
            model.add_module("style_loss_{}".format(conv_count), style_loss)
            style_losses.append(style_loss)

    # Trim off everything after the last style loss. If there is none, only
    # the first module (the normalization) is kept.
    last_loss_idx = 0
    for idx in reversed(range(len(model))):
        if isinstance(model[idx], StyleLoss):
            last_loss_idx = idx
            break

    model = model[:(last_loss_idx + 1)]

    return model, style_losses

In [None]:
# Build the style model around VGG19 features, using an all-zero style image.
img = torch.zeros(1, 3, 256, 256)
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour of
# `weights=` -- confirm the installed version before switching.
cnn = models.vgg19(pretrained=True).features.eval()
cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406])
cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225])
temp_model, style_losses = get_style_model_and_losses(
    cnn, cnn_normalization_mean, cnn_normalization_std,
    img
)

# All-ones test image, clamped to [0, 1]. It is not called `input`, which
# would shadow the builtin input() for the rest of the session.
input_img = torch.ones(1, 3, 256, 256)
with torch.no_grad():
    input_img.clamp_(0, 1)

# One forward pass refreshes `.loss` on every StyleLoss layer; the total style
# score is their sum.
temp_model(input_img)
style_score = sum(sl.loss for sl in style_losses)


In [None]:
# Display the assembled model (normalization, VGG layers, style-loss layers).
temp_model