In [16]:
#imports
import numpy as np
import torch
from torch import nn
import torchvision

In [20]:
# Loss hyper-parameters and a random placeholder style image.
torch.manual_seed(0)  # make the random placeholder style reproducible across re-runs
style = torch.rand((1, 3, 640, 360))  # (N=1, C, H, W): gram_matrix expects a batch dimension
lmbda = 1e4   # weight of the temporal term in hybrid_loss
alpha = 1     # weight of the content term in spatial_loss
beta = 10     # weight of the style term in spatial_loss
gamma = 1e-3  # weight of the total-variation term in spatial_loss
c_layers = [22]  # ReLU4_2 -- a list because spatial_loss iterates over it
s_layers = [3, 8, 13, 22]  # ReLU1_2, ReLU2_2, ReLU3_2, ReLU4_2
eta = 1       # exponent passed to tv_reg

In [18]:
class HybridLoss(nn.Module):
    """nn.Module wrapper around ``hybrid_loss`` that stores the loss hyper-parameters.

    Args:
        style: style image tensor used as the style target.
        lmbda: weight of the temporal term.
        alpha, beta, gamma: weights of the content, style and total-variation terms.
        c_layers: feature-layer index/indices used for the content loss.
        s_layers: feature-layer indices used for the style loss.
        eta: exponent of the total-variation regularizer.
    """

    def __init__(self, style, lmbda, alpha, beta, gamma, c_layers, s_layers, eta):
        # Zero-argument super() avoids hard-coding a class name.
        super().__init__()
        self.style = style
        self.lmbda = lmbda
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.c_layers = c_layers
        self.s_layers = s_layers
        self.eta = eta

    def forward(self, input, output, input_prev, output_prev):
        """Evaluate ``hybrid_loss`` for a pair of consecutive (input, stylized output) frames."""
        return hybrid_loss(input, output, input_prev, output_prev, style=self.style,
                           lmbda=self.lmbda, alpha=self.alpha, beta=self.beta, gamma=self.gamma,
                           c_layers=self.c_layers, s_layers=self.s_layers, eta=self.eta)

In [2]:
def hybrid_loss(input, output, input_prev, output_prev, style, lmbda, alpha, beta, gamma, c_layers, s_layers, eta):
    """Hybrid video style-transfer loss for two consecutive frames.

    Returns spatial_loss(current frame) + spatial_loss(previous frame)
    + lmbda * temporal_loss(output, output_prev).

    Args:
        input, output: current content frame and its stylized output.
        input_prev, output_prev: previous content frame and its stylized output.
        style, alpha, beta, gamma, c_layers, s_layers, eta: forwarded to ``spatial_loss``.
        lmbda: weight of the temporal term.
    """
    sp_loss = spatial_loss(input, output, style, alpha, beta, gamma, c_layers, s_layers, eta)
    sp_loss_prev = spatial_loss(input_prev, output_prev, style, alpha, beta, gamma, c_layers, s_layers, eta)
    # Both frames' spatial losses contribute to the objective.
    sp_loss_total = sp_loss + sp_loss_prev
    tem_loss = temporal_loss(output, output_prev)

    return sp_loss_total + lmbda * tem_loss

In [3]:
def spatial_loss(input, output, style, alpha, beta, gamma, c_layers, s_layers, eta, feat_net=None):
    """Per-frame loss: alpha * content + beta * style + gamma * total variation.

    Args:
        input: content frame(s), 4-D (N, C, H, W); features are compared with ``output``'s.
        output: stylized frame(s), 4-D (N, C, H, W).
        style: style image, (C, H, W) or (N_s, C, H, W); a 3-D tensor gets a batch dim added,
            and a batch of 1 is compared against every sample of ``output``.
        alpha, beta, gamma: weights of the content, style and TV terms.
        c_layers: index (int) or list of indices into ``extract_features`` output for the content loss.
        s_layers: index (int) or list of indices for the style (Gram) loss.
        eta: exponent forwarded to ``tv_reg``.
        feat_net: optional pre-built feature extractor. When None, ``extractor_net()`` is
            called, which rebuilds the network each time -- pass one in to avoid that.

    Returns:
        Scalar loss tensor.
    """
    if feat_net is None:
        feat_net = extractor_net()
    # Accept a single layer index as well as a list of indices.
    if isinstance(c_layers, int):
        c_layers = [c_layers]
    if isinstance(s_layers, int):
        s_layers = [s_layers]
    # The feature net and gram_matrix expect a leading batch dimension.
    if style.dim() == 3:
        style = style.unsqueeze(0)

    feats_in = extract_features(input, feat_net)
    feats_out = extract_features(output, feat_net)
    # The style target is a constant; no graph is needed for its features.
    with torch.no_grad():
        feats_style = extract_features(style, feat_net)

    content_acc = sum(content_loss(feats_in[l], feats_out[l]) for l in c_layers)

    style_acc = 0.0
    for l in s_layers:
        target = feats_style[l]
        n_out = feats_out[l].size(0)
        # A single style image is compared against every frame of the output batch.
        if target.size(0) == 1 and n_out != 1:
            target = target.expand(n_out, -1, -1, -1)
        style_acc = style_acc + style_loss(target, feats_out[l])

    reg = tv_reg(output, eta)

    return alpha * content_acc + beta * style_acc + gamma * reg
    

In [6]:
# Feature extractor net
def extractor_net():
    """Return the convolutional part (``.features``) of an ImageNet-pretrained VGG19.

    The returned module is in eval mode and every parameter has requires_grad=False,
    so it acts as a fixed feature extractor.
    """
    # NOTE(review): newer torchvision deprecates `pretrained=` in favour of
    # `weights=torchvision.models.VGG19_Weights.IMAGENET1K_V1` -- confirm against the installed version.
    # eval() and requires_grad_() both return the module itself, so they can be chained.
    return torchvision.models.vgg19(pretrained=True).features.eval().requires_grad_(False)

In [22]:
# Random stand-ins for a batch of 5 content frames and their stylized outputs.
test = torch.rand((5, 3, 640, 360))
test_out = torch.rand(test.shape)
test.shape

torch.Size([5, 3, 640, 360])

In [35]:
# Batched content loss: squared error averaged over all N*C*H*W elements.
content_loss(input=test, output=test_out)

tensor(0.1666)

In [36]:
# Sum of per-sample content losses -- comparable to 5x the batched loss above,
# because content_loss averages over the batch dimension too.
acc = sum(content_loss(test[i:i + 1], test_out[i:i + 1]) for i in range(5))
acc

tensor(0.8330)

In [11]:
# Run the test batch through the frozen VGG19 extractor, keeping every module's output.
feats = extract_features(x=test, net=extractor_net())

In [13]:
# One entry per module of VGG19's `features` Sequential (extract_features appends after every module).
len(feats)

37

In [21]:
# Uncomment to display the full VGG19 architecture (useful for looking up layer indices):
# torchvision.models.vgg19(pretrained=True)

In [7]:
# Feature extractor layer util function
def extract_features(x, net, last_layer=None):
    """Feed ``x`` through the child modules of ``net`` in order, recording each output.

    Args:
        x: input tensor for the first module.
        net: module whose registered children are applied sequentially (e.g. an nn.Sequential).
        last_layer: optional non-negative index; stop after the module at this index
            (inclusive) so layers that are never used are not computed. None runs every module.

    Returns:
        List where entry i is the output of the i-th module.
    """
    # NOTE(review): entries are references, not copies. If a module modifies its input in place
    # (torchvision's VGG uses ReLU(inplace=True) -- confirm), the entry for the module before it
    # ends up holding the post-activation values. Entries at activation indices are unaffected.
    features = []
    feat = x
    for idx, module in enumerate(net._modules.values()):
        feat = module(feat)
        features.append(feat)
        if last_layer is not None and idx >= last_layer:
            break

    return features

In [34]:
def content_loss(input, output):
    """Mean squared difference between two 4-D tensors of identical shape.

    The sum of squared element-wise differences is divided by N*C*H*W, taken from
    ``input``. A non-4-D ``input`` raises ValueError when its size is unpacked.
    """
    n, c, h, w = input.size()
    squared_error = (input - output).pow(2).sum()
    return squared_error / float(n * c * h * w)

In [12]:
# Gram matrix util function
def gram_matrix(features):
    """Batched Gram matrices of feature maps, normalized by the number of spatial positions.

    Args:
        features: 4-D tensor (N, C, H, W); need not be contiguous.

    Returns:
        (N, C, C) tensor with G[n] = F_n @ F_n.T / (H*W), where F_n is features[n]
        reshaped to (C, H*W). It has the same device and dtype as ``features``.
    """
    N, C, H, W = features.size()
    # reshape (unlike view) also accepts non-contiguous inputs, e.g. transposed or expanded maps.
    flat = features.reshape(N, C, H * W)
    # One batched matmul keeps the input's device/dtype and avoids a Python loop over the batch.
    gram = torch.bmm(flat, flat.transpose(1, 2))

    return gram / float(H * W)

In [13]:
def style_loss(style, output):
    """Squared Frobenius distance between Gram matrices of style and output features.

    Args:
        style: 4-D feature tensor of the style target. A batch of 1 is broadcast against
            every sample of ``output``; otherwise its batch size must match ``output``'s.
        output: 4-D feature tensor (N, C, H, W) of the generated frames.

    Returns:
        Scalar tensor: sum over the batch of ||G_style - G_output||_F^2, divided by C**2.
    """
    style_gram = gram_matrix(style)
    output_gram = gram_matrix(output)
    C = output_gram.size(-1)
    # Broadcasting lets one style Gram matrix be compared with every output sample.
    diff = style_gram - output_gram

    return diff.pow(2).sum() / float(C ** 2)

In [14]:
def tv_reg(frame, eta):
    """Total-variation regularizer: (sum of squared neighbour differences) ** (eta / 2).

    Args:
        frame: image tensor whose last two dims are (H, W), e.g. (N, C, H, W).
        eta: exponent; eta=2 gives the plain sum of squares, eta=1 its square root.

    Returns:
        Scalar tensor.
    """
    # Differences between vertically adjacent pixels (along H) and horizontally adjacent ones (along W).
    d_h = frame[..., 1:, :] - frame[..., :-1, :]
    d_w = frame[..., :, 1:] - frame[..., :, :-1]
    total = d_h.pow(2).sum() + d_w.pow(2).sum()
    # NOTE: for eta < 2 the gradient is infinite when total == 0 (a perfectly flat frame).
    return torch.pow(total, eta / 2)

In [None]:
# Optical flow network
# TODO: use FlowNet 2.0 to estimate the flow between consecutive frames (needed by temporal_loss).


In [None]:
# Optical Flow util function
def optical_flow(frame_prev, frame):
    """Estimate optical flow between two consecutive frames (not implemented yet).

    NOTE(review): the sketch in temporal_loss unpacks the result as ``warp, c``. Presumably that is
    the previous frame warped onto ``frame`` plus a per-pixel confidence mask -- confirm when implementing.

    Raises:
        NotImplementedError: always. Failing loudly is preferable to returning None,
            which would surface later as a confusing unpacking error.
    """
    raise NotImplementedError("optical_flow is not implemented yet (planned: FlowNet 2.0)")

In [None]:
def temporal_loss(frame, frame_prev):
    """Temporal-consistency loss between consecutive stylized frames -- placeholder.

    Always returns 0 for now, so hybrid_loss reduces to its spatial terms.
    Intended implementation once optical_flow exists (D = C * H * W):

        warp, c = optical_flow(frame_prev, frame)
        loss = torch.sum(c * (frame - warp) ** 2) / D
    """
    return 0