# In [None]:
#### Defining loss functions ####
from torch.nn.functional import binary_cross_entropy
from torch import optim
import math

def Gaussian_density(sample_img,mu_img,log_var_img):
    """Return the diagonal-Gaussian log-density of ``sample_img``, summed over dim 1.

    Each element is scored under N(mu_img, exp(log_var_img)); the resulting
    elementwise log-densities are then added up along dimension 1 (the
    channel axis, per the original author).

    Args:
        sample_img: observed values.
        mu_img: Gaussian means, broadcastable against ``sample_img``.
        log_var_img: Gaussian log-variances, broadcastable against ``sample_img``.

    Returns:
        Tensor of log-densities with dimension 1 reduced by summation.
    """
    log_norm_const = -0.5 * math.log(2 * math.pi)
    variance = torch.exp(log_var_img)
    squared_err = (sample_img - mu_img) ** 2
    log_pdf = log_norm_const - 0.5 * log_var_img - squared_err / (2 * variance)
    return log_pdf.sum(dim=1)

def kl_a_calc(q_a,q_mu, q_log_var,p_mu, p_log_var):
    """Estimate KL(q || p) for the auxiliary variable a from the samples ``q_a``.

    Evaluates log q(q_a) - log p(q_a) under two diagonal Gaussians. The result
    is summed over the sample dimension and the latent features; over samples
    it is a sum, not a mean. The parameters of p are first averaged over the
    sample dimension, so one p is used for every sample of an example.

    Args:
        q_a: samples of a, shape [batch_size, No_samples, latent_features].
        q_mu, q_log_var: parameters of q, shape [batch_size, latent_features].
        p_mu, p_log_var: parameters of p, one set per sample. Anything that
            reshapes to [batch_size, No_samples, -1] is accepted, e.g.
            [batch_size, No_samples, latent_features] or
            [batch_size * No_samples, latent_features].

    Returns:
        Tensor of shape [batch_size] holding the KL estimate per example.
    """
    # Take batch and sample counts from q_a itself, not from the globals
    # batch_size / num_samples. On a short final batch the globals make the
    # reshape below fail, or silently mix samples of different examples.
    n_batch, n_samples = q_a.shape[0], q_a.shape[1]

    # Average p's parameters over the sample dimension -> [batch, latent]
    p_mu = p_mu.reshape(n_batch, n_samples, -1).mean(dim=1)
    p_log_var = p_log_var.reshape(n_batch, n_samples, -1).mean(dim=1)

    def log_gaussian(x, mu, log_var):
        """Diagonal-Gaussian log-density of x, summed over dim 1 twice -> [batch]."""
        log_pdf = - 0.5 * math.log(2 * math.pi) - log_var / 2 - (x - mu)**2 / (2 * torch.exp(log_var))
        log_pdf = torch.sum(log_pdf, dim=1)  # sum over the samples
        log_pdf = torch.sum(log_pdf, dim=1)  # sum over the latent features
        return log_pdf

    # Insert a singleton sample dimension so the parameters broadcast against q_a.
    q_mu = q_mu.unsqueeze(1)
    q_log_var = q_log_var.unsqueeze(1)
    p_mu = p_mu.unsqueeze(1)
    p_log_var = p_log_var.unsqueeze(1)

    # Log-densities of the samples under q and under p; their difference is the estimate.
    qz = log_gaussian(q_a, q_mu, q_log_var)
    pz = log_gaussian(q_a, p_mu, p_log_var)
    return qz - pz

def ELBO_loss(sample_img, outputs, kl_warmup=None):
    """Return the negative ELBO (a loss to minimise) plus its logged components.

    Two modes, chosen by the type of ``outputs['x_mean']``:

    * list (one decoder output per class): a per-class ELBO is computed for
      each of the ``No_classes`` entries. These are combined with the class
      probabilities ``mean(outputs['y_hat'], dim=1)``, and their entropy H is
      included (see the NOTE on the sign below).
    * tensor: a single ELBO is computed for the one decoder output.

    The KL term for z is the closed-form KL between a diagonal Gaussian
    q(z|x) and a standard normal N(0, I). The analytic form gives a
    lower-variance estimator than a sampled one. When the global
    ``aux_variables > 0``, the KL for the auxiliary variable a (from
    ``kl_a_calc``) is mixed in with weight ``w1``. Likelihood and KL are mixed
    with weight ``w2``, and the KL is further scaled by the warm-up factor beta.

    Args:
        sample_img: batch of observed images; dim 0 is the batch.
        outputs: dict of network outputs. Keys read: 'x_mean', 'x_log_var',
            'mu', 'log_var'; 'y_hat' in list mode; and 'q_a', 'q_a_mu',
            'q_a_log_var', 'p_a_mu', 'p_a_log_var' when aux_variables > 0.
        kl_warmup: beta for deterministic KL warm-up; None means 1.

    Returns:
        List mode, a 7-tuple: (-mean U, -mean H, -mean L, weighted mean KL,
        -w2 * mean likelihood, w1 * mean kl_x, (1 - w1) * mean kl_a). The KL
        and likelihood entries are those of the last class in the loop.
        Tensor mode, a 5-tuple: (-ELBO, weighted KL, -w2 * mean likelihood,
        w1 * mean kl_x, (1 - w1) * mean kl_a).

    Relies on the module-level globals ``No_classes`` and ``aux_variables``.
    """
    # Deterministic warm-up factor for the KL term (beta = 1 means no warm-up).
    beta = 1 if kl_warmup is None else kl_warmup

    # w1 mixes kl_x against kl_a; w2 mixes the likelihood against the KL term.
    w1 = 0.5
    w2 = 0.5

    # Actual size of this batch. The global batch_size would make the
    # likelihood reshape fail, or silently regroup it, on a shorter final batch.
    n_batch = sample_img.shape[0]

    if isinstance(outputs['x_mean'], list):
        ELBO = []
        # Closed-form KL(q(z|x) || N(0, I)), reduced over dim 1 ...
        kl_x = -0.5 * torch.sum(1 + outputs['log_var'] - outputs['mu']**2 - torch.exp(outputs['log_var']), dim=1)
        kl_x = torch.sum(kl_x, dim=1)  # ... and again over the features
        for j in range(No_classes):
            likelihood = Gaussian_density(sample_img, outputs['x_mean'][j], outputs['x_log_var'][j])
            if aux_variables > 0:
                kl_a = kl_a_calc(outputs["q_a"], outputs["q_a_mu"], outputs["q_a_log_var"],
                                 outputs["p_a_mu"][j], outputs["p_a_log_var"][j])
                kl = w1 * kl_x + (1 - w1) * kl_a
            else:
                # Zero placeholder on kl_x's device/dtype so the logged terms stay combinable.
                kl_a = kl_x.new_zeros(1)
                kl = kl_x
            # Sum the log-likelihood over every non-batch dimension -> [batch]
            likelihood = torch.sum(likelihood.view(n_batch, -1), dim=1)
            ELBO.append(w2 * likelihood - (1 - w2) * beta * kl)

        # One column per class (previously hard-coded to exactly two classes).
        L = torch.stack(ELBO, dim=1)
        # Class probabilities and their entropy H(q(y|x)), summed over labels.
        logits = torch.mean(outputs['y_hat'], dim=1)
        H = -torch.sum(torch.mul(logits, torch.log(logits + 1e-8)), dim=-1)
        # Expected per-class ELBO under the class probabilities.
        L = torch.sum(torch.mul(logits, L), dim=-1)

        # NOTE(review): this was commented as "equivalent to -U(x)". In Kingma
        # et al. (2014), -U(x) = sum_y q(y|x) ELBO(x, y) + H(q(y|x)), i.e. L + H.
        # With L - H, minimising the returned loss lowers the entropy of q(y|x)
        # instead of raising it. Confirm this sign is intentional.
        U = L - H

        return (-torch.mean(U), -torch.mean(H), -torch.mean(L),
                (1 - w2) * beta * torch.mean(kl), -w2 * torch.mean(likelihood),
                w1 * torch.mean(kl_x), (1 - w1) * torch.mean(kl_a))

    likelihood = Gaussian_density(sample_img, outputs['x_mean'], outputs['x_log_var'])
    # Closed-form KL(q(z|x) || N(0, I)), reduced over dim 1
    kl_x = -0.5 * torch.sum(1 + outputs['log_var'] - outputs['mu']**2 - torch.exp(outputs['log_var']), dim=1)
    if aux_variables > 0:
        kl_a = kl_a_calc(outputs["q_a"], outputs["q_a_mu"], outputs["q_a_log_var"],
                         outputs["p_a_mu"], outputs["p_a_log_var"])
        kl = w1 * torch.mean(kl_x) + (1 - w1) * torch.mean(kl_a)
    else:
        kl_a = kl_x.new_zeros(1)
        kl = torch.mean(kl_x)
    # Sum the log-likelihood over every non-batch dimension -> [batch]
    likelihood = torch.sum(likelihood.view(n_batch, -1), dim=1)
    ELBO = w2 * torch.mean(likelihood) - (1 - w2) * beta * kl
    # Minus sign: the ELBO is maximised by minimising its negative.
    return -ELBO, (1 - w2) * beta * kl, -w2 * torch.mean(likelihood), w1 * torch.mean(kl_x), (1 - w1) * torch.mean(kl_a)

# Optimizer: Adam with learning rate 1e-3 over every parameter of `net`;
# Adam is a dependable default for training VAEs.
optimizer = optim.Adam(params=net.parameters(), lr=1e-3)
# Training objective: the negative ELBO computed by ELBO_loss.
loss_function = ELBO_loss

# In [None]:
#### Test loss functions ####
# Smoke test: push one unlabelled and one labelled batch through `net` and the
# loss function so shape / NaN problems show up before training starts.

from torch.autograd import Variable
import gc
gc.collect()
# empty_cache must be *called*; a bare attribute reference does nothing.
torch.cuda.empty_cache()

x, y = next(iter(train_loader_labelled))
u, _ = next(iter(train_loader))

# One-hot targets: [0, 1] where y == 1, [1, 0] otherwise. Build a plain tensor.
# Writing rows into a leaf created with requires_grad=True raises a
# RuntimeError, and constant labels need no gradient. The row count follows
# len(y), so a short final batch also works.
y_hot = torch.nn.functional.one_hot((y == 1).long(), num_classes=2).float()

# Variable() wrappers dropped: they are deprecated no-ops since PyTorch 0.4.
if cuda:
    # They need to be on the same device and be synchronized.
    x, y, y_hot = x.cuda(device=0), y.cuda(device=0), y_hot.cuda(device=0)
    u = u.cuda(device=0)

# Unlabelled batch: loss_function returns its 7-tuple form here.
outputs = net(u)
elbo_u, elbo_H, elbo_L, kl_u, likelihood_u, kl_u_x, kl_u_a = loss_function(u, outputs)

# Labelled batch: the one-hot labels are passed to the network together with x.
outputs = net(x, y_hot)

x_hat = outputs["x_hat"]
mu, log_var = outputs["mu"], outputs["log_var"]
mu_img, log_var_img = outputs["x_mean"], outputs["x_log_var"]
z = outputs["z"]
logits = outputs["y_hat"]
# Tile the targets along dim 1 to match the second dimension of y_hat.
y_hot = y_hot.unsqueeze(dim=1).repeat(1, logits.shape[1], 1)
# L1 distance between targets and y_hat, summed over all entries.
classification_loss = torch.sum(torch.abs(y_hot - logits))

if False:  # switch to True to print shape / NaN diagnostics
    # These outputs were produced from x, so evaluate the loss on x (not u).
    loss, kl, likelihood, kl_l_x, kl_l_a = loss_function(x, outputs)
    print('mu:          ',mu.shape,torch.sum(torch.isnan(mu)))
    print('log_var:     ',log_var.shape,torch.sum(torch.isnan(log_var)))
    print('mu_img:      ',mu_img.shape,torch.sum(torch.isnan(mu_img)))
    print('log_var_img: ',log_var_img.shape,torch.sum(torch.isnan(log_var_img)))
    print('x:           ',x.shape,torch.sum(torch.isnan(x)))
    print('x_hat:       ',x_hat.shape,torch.sum(torch.isnan(x_hat)))
    print('z:           ',z.shape,torch.sum(torch.isnan(z)))
    print('Total loss:  ',loss)
    print('kl:          ',kl)
    print('Class. loss: ',classification_loss)
    print('Likelihood:  ',likelihood)