# VAEGAN NEGATIVE BINOMIAL

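> The VAEGAN model takes a variational encoder, two decoders (one for the negative binomial `total_count` parameter, one for its `probs` parameter) and a classifier as inputs. The reconstructed data is modelled with a negative binomial distribution rather than a Gaussian; the latent variable itself is still sampled from a Gaussian via the reparameterization trick.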

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| default_exp Models.VAEGAN_NEG_BI

In [None]:
#| export
import torch 
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import scvi
from scvi.models.distributions import NegativeBinomial
from torch.distributions import Normal

class VAEGAN_NEG_BI(nn.Module):
    """VAEGAN whose decoder outputs the parameters of a negative binomial.

    The latent code ``z`` is drawn from a Gaussian with the reparameterization
    trick. Two decoder heads map ``z`` to the negative binomial ``total_count``
    (``h_r``) and ``probs`` (``h_p``) parameters, and a classifier predicts a
    label from ``z``.
    """

    def __init__(self, encoder, decoder_r, decoder_p, classifier, log=True):
        """Store the sub-modules that make up the model.

        Args:
            encoder: module mapping ``x`` to a ``(mu, logvar)`` pair.
            decoder_r: module mapping ``z`` to pre-activation ``total_count`` values.
            decoder_p: module mapping ``z`` to pre-activation ``probs`` values.
            classifier: module mapping ``z`` to class predictions.
            log: kept as ``self.log`` (it used to be silently discarded).
                NOTE(review): ``forward`` does not read it; the log transform
                is controlled by ``forward``'s own ``log`` argument, whose
                default is ``False``.
        """
        super().__init__()
        self.encoder = encoder
        self.decoder_r = decoder_r
        self.decoder_p = decoder_p
        self.classifier = classifier
        self.log = log

    def reparameterize(self, mu, logvar):
        """Draw ``z ~ N(mu, exp(logvar))`` with the reparameterization trick.

        ``rsample`` keeps ``z`` differentiable with respect to ``mu`` and
        ``logvar``. The ``1e-8`` added to the variance keeps the standard
        deviation strictly positive.
        """
        var = logvar.exp() + 1e-8
        return Normal(mu, var.sqrt()).rsample()

    def decode(self, z):
        """Map ``z`` to negative binomial parameters and draw a sample.

        Returns:
            ``(x_hat, h_r, h_p)``:
            - ``x_hat`` is a sample from ``NegativeBinomial(h_r, h_p)``. It is
              drawn with ``.sample()``, so it carries no gradient.
            - ``h_r`` is the ``total_count`` parameter.
            - ``h_p`` is the ``probs`` parameter.
        """
        # sigmoid bounds total_count to (0, 1).
        # NOTE(review): softplus/exp would allow total_count > 1 -- confirm intent.
        h_r = torch.sigmoid(self.decoder_r(z))
        # Softmax over the last axis (presumably features). The previous call
        # omitted ``dim``, which is deprecated; -1 resolves to the same axis as
        # the implicit choice for 1-D and 2-D inputs.
        # NOTE(review): this makes probs sum to 1 across that axis instead of
        # being independent per entry -- confirm intent.
        h_p = F.softmax(self.decoder_p(z), dim=-1)
        x_hat = NegativeBinomial(total_count=h_r, probs=h_p).sample()
        return x_hat, h_r, h_p

    def forward(self, x, log=False):
        """Encode ``x``, sample ``z``, decode it and classify it.

        Args:
            x: input counts.
            log: if True, apply ``log(1 + x)`` to ``x`` before encoding.

        Returns:
            ``(x_hat, y_hat, mu, logvar, h_r, h_p)``.
        """
        if log:
            # log1p is the numerically accurate form of log(x + 1) for small x.
            x = torch.log1p(x)
        mu, logvar = self.encoder(x)
        z = self.reparameterize(mu, logvar)
        x_hat, h_r, h_p = self.decode(z)
        y_hat = self.classifier(z)
        return x_hat, y_hat, mu, logvar, h_r, h_p



Global seed set to 0
  new_rank_zero_deprecation(
  return new_rank_zero_deprecation(*args, **kwargs)


In [None]:
#| hide
# nbdev: write this notebook's `#| export` cells out to the module named by `#| default_exp`.
import nbdev; nbdev.nbdev_export()

In [None]:
def NLL_loss(data, h_r, h_p):
    """Negative log-likelihood of ``data`` under ``NegativeBinomial(h_r, h_p)``.

    Args:
        data: observed non-negative integer counts. If ``data`` has a different
            shape but the same number of elements as ``h_r``, it is reshaped to
            ``h_r``'s shape. Otherwise it must broadcast against ``h_r``/``h_p``.
        h_r: ``total_count`` parameter (> 0), e.g. the ``h_r`` returned by
            ``VAEGAN_NEG_BI.decode``.
        h_p: ``probs`` parameter in [0, 1).

    Returns:
        Scalar tensor: the log-likelihood summed over the last axis, averaged
        over all remaining entries, and negated.
    """
    # The previous ``data.view(-1, 1)`` turned data into a single column, which
    # does not broadcast against (batch, features) parameters. Align data with
    # the parameter shape instead; this also covers the column case.
    if data.shape != h_r.shape and data.numel() == h_r.numel():
        data = data.reshape(h_r.shape)
    dist = torch.distributions.negative_binomial.NegativeBinomial(total_count=h_r, probs=h_p)
    # Score the observations. The previous code summed the distribution object
    # itself, which raises a TypeError.
    log_lik = dist.log_prob(data)
    return -torch.mean(torch.sum(log_lik, dim=-1))