In [80]:
from torch.nn import functional as F
import torch
import time
import numpy as np

This notebook walks through the feature processing used to build the loss function, which combines a variance term, an invariance term, and a covariance term.

In [83]:
def calc_var_loss(features, epsilon=1e-6):
    """Variance (hinge-on-std) loss averaged over a sequence of feature tensors.

    For each tensor, the standard deviation is taken over dimensions 0 and 1
    (one value per remaining position, i.e. per channel for 3-D input), and
    positions whose std is below 1 are penalised with ``relu(1 - std)``.

    Args:
        features: non-empty iterable of floating-point tensors with at least
            two dimensions (presumably ``(N1, N2, D)`` as in the cells below --
            TODO confirm for real callers).
        epsilon: small constant added to the variance before ``sqrt`` for
            numerical stability. Defaults to the previously hard-coded 1e-6.

    Returns:
        0-dim tensor: the mean of the per-tensor losses. It is built with
        ``torch.stack`` so the result stays attached to the autograd graph
        and gradients flow back to ``features``.

    Raises:
        ValueError: if ``features`` is empty (the mean would otherwise be NaN).
    """
    per_feature_losses = []
    for feature in features:
        dim_std = torch.sqrt(torch.var(feature, dim=(0, 1)) + epsilon)
        per_feature_losses.append(torch.mean(F.relu(1 - dim_std)))
    if not per_feature_losses:
        raise ValueError("calc_var_loss needs at least one feature tensor")
    # torch.tensor(list_of_tensors) would copy the values into a new tensor with
    # no grad history; stacking keeps the graph intact for backprop.
    return torch.stack(per_feature_losses).mean()

def calc_inv_loss(features):
    """Invariance loss: mean MSE between each pair of consecutive feature tensors.

    Args:
        features: sequence of at least two tensors. ``features[i]`` and
            ``features[i + 1]`` must have shapes ``F.mse_loss`` accepts
            (normally identical).

    Returns:
        0-dim tensor: the average of ``F.mse_loss(features[i], features[i + 1])``
        over all ``len(features) - 1`` consecutive pairs. It is built with
        ``torch.stack`` so gradients flow back to ``features``.

    Raises:
        ValueError: if fewer than two tensors are given (there is no pair to
            compare, and the mean would otherwise be NaN).
    """
    if len(features) < 2:
        raise ValueError("calc_inv_loss needs at least two feature tensors")
    pair_losses = [F.mse_loss(current, following)
                   for current, following in zip(features[:-1], features[1:])]
    # torch.tensor(list_of_tensors) would detach from the autograd graph.
    return torch.stack(pair_losses).mean()

def calc_cov_loss(features):
    """Covariance loss: mean over tensors of the normalised squared off-diagonal covariance.

    Each tensor is flattened to ``(num_samples, D)``, where ``D`` is its last
    dimension, and centred per channel. Its ``D x D`` sample covariance is then
    computed, normalised by ``num_samples - 1``. The loss for that tensor is
    the sum of the squared off-diagonal covariance entries divided by ``D``.

    Args:
        features: non-empty iterable of floating-point tensors whose last
            dimension is the channel dimension. Each tensor is flattened with
            its own shape, so shapes may differ between tensors. Each needs
            more than one sample after flattening, otherwise the
            ``num_samples - 1`` denominator is zero.

    Returns:
        0-dim tensor: the mean of the per-tensor losses. It is built with
        ``torch.stack`` so gradients flow back to ``features``.

    Raises:
        ValueError: if ``features`` is empty.
    """
    cov_losses = []
    for feature in features:
        dim = feature.shape[-1]
        # reshape (not view) also handles non-contiguous tensors, and using this
        # tensor's own shape avoids mis-reshaping when shapes differ.
        flat = feature.reshape(-1, dim)
        num_samples = flat.shape[0]
        centred = flat - flat.mean(dim=0)
        cov_sq = ((centred.T @ centred) / (num_samples - 1)).square()
        # Drop the diagonal (per-channel variances); only cross-channel terms count.
        cov_losses.append((cov_sq.sum() - cov_sq.diagonal().sum()) / dim)
    if not cov_losses:
        raise ValueError("calc_cov_loss needs at least one feature tensor")
    # torch.tensor(list_of_tensors) would detach from the autograd graph.
    return torch.stack(cov_losses).mean()
    


Time how long each loss term takes to compute on a batch of features.


In [84]:
# Fixed seed so the benchmark input is reproducible across runs.
torch.manual_seed(0)
features = [torch.rand(size=(8, 12, 512)) for _ in range(16)]

# Average several calls after an untimed warm-up call: a single time.time()
# delta at millisecond scale is noisy and may include one-off overhead.
N_REPEATS = 10
timings = {}
for name, loss_fn in (
    ("var_loss", calc_var_loss),
    ("inv_loss", calc_inv_loss),
    ("cov_loss", calc_cov_loss),
):
    loss_fn(features)  # warm-up, not timed
    start = time.perf_counter()
    for _ in range(N_REPEATS):
        loss_fn(features)
    timings[name] = (time.perf_counter() - start) / N_REPEATS

# cov_loss was the slowest term in the recorded run (it builds a D x D matrix per tensor).
print(", ".join(f"{name}: {seconds}" for name, seconds in timings.items()))

var_loss: 0.004872322082519531, inv_loss: 0.0029745101928710938, conv_loss: 0.011105775833129883


In [87]:
hand_test = torch.tensor([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]])

# Five identical float copies of the hand-built tensor. Every channel has
# values {1,4,7,10} shifted by a constant, so its std over dims (0, 1) is
# sqrt(15) ~= 3.87 > 1. The variance hinge is therefore inactive, and
# identical consecutive tensors give an invariance MSE of 0.
features = [hand_test.float() for _ in range(5)]
var_loss = calc_var_loss(features)
inv_loss = calc_inv_loss(features)
print(var_loss)
print(inv_loss)
assert var_loss.item() == 0.0
assert inv_loss.item() == 0.0

# First two channels only: their sample covariance with each other is 15, so
# each squared off-diagonal entry is 225, and (225 + 225) / D (= 2) = 225.
features = [hand_test[..., :2].float() for _ in range(5)]
cov_loss = calc_cov_loss(features)
print(cov_loss)
assert abs(cov_loss.item() - 225.0) < 1e-3


tensor(0.)
tensor(0.)
tensor(225.)
