# In [45]:
import abc
import warnings
import numpy as np
import torch
import os
import typing
import numpy as np
import torch
import torch.optim
from matplotlib import pyplot as plt
from sklearn.metrics import roc_auc_score, average_precision_score
from torch import nn
from torch.nn import functional as F
from tqdm import trange
from torch.distributions import Normal
from torch import distributions as dis

class ParameterDistribution(torch.nn.Module, metaclass=abc.ABCMeta):
    """
    Abstract base class for a distribution over model parameters,
    intended for Bayes by Backprop.

    Subclass it to build priors and variational posteriors from whatever
    distribution you like. Because this is a torch.nn.Module, any
    torch.nn.Parameter a subclass assigns as an attribute in its __init__
    is registered automatically and therefore known to PyTorch
    (e.g. it shows up in .parameters()).
    """

    def __init__(self):
        super(ParameterDistribution, self).__init__()

    @abc.abstractmethod
    def log_likelihood(self, values: torch.Tensor) -> torch.Tensor:
        """
        Return the log-likelihood of ``values`` under this distribution.

        :param values: Values at which to evaluate the log-likelihood
        :return: Log-likelihood
        """

    @abc.abstractmethod
    def sample(self) -> torch.Tensor:
        """
        Draw a sample from this distribution.

        Subclasses must override this method because it is abstract, but only
        variational posteriors need a meaningful implementation; priors are
        never sampled from in Bayes by Backprop.

        :return: A sample; its shape is defined by the concrete subclass.
        """

    def forward(self, values: torch.Tensor) -> torch.Tensor:
        # Not part of the intended API. torch.nn.Module demands a forward
        # method, so this one only warns and then defers to log_likelihood.
        message = 'ParameterDistribution should not be called! Use its explicit methods!'
        warnings.warn(message)
        return self.log_likelihood(values)

# In [None]:
    # Implemented by Dean
class UniveriateGaussianPrior(ParameterDistribution):
    """
    Univariate Gaussian distribution N(mu, sigma^2), used as a prior.

    ``mu`` and ``sigma`` are broadcast against the values passed to
    ``log_likelihood``, and the element-wise log-densities are summed.
    NOTE(review): the class name is misspelled ("Univeriate"); it is left
    unchanged because renaming it would break any code that refers to it.
    """

    def __init__(self, mu, sigma):
        super().__init__()
        # sigma is the standard deviation, NOT the variance.
        self.mu, self.sigma = mu, sigma

    def _normal(self) -> Normal:
        """Build the torch Normal distribution described by (mu, sigma)."""
        return Normal(loc=self.mu, scale=self.sigma)

    def log_likelihood(self, values: torch.Tensor) -> torch.Tensor:
        """Summed log-density of ``values`` under N(mu, sigma^2)."""
        per_element = self._normal().log_prob(values)
        return per_element.sum()

    def sample(self):
        """Draw a (non-differentiable) sample from N(mu, sigma^2)."""
        return self._normal().sample()

# In [None]:
class MultivariateDiagonalGaussian(ParameterDistribution):
    """
    Multivariate diagonal Gaussian distribution,
    i.e., assumes all elements to be independent Gaussians
    but with different means and standard deviations.

    The standard deviation is parameterized through an unconstrained
    parameter rho as

        sigma = softplus(rho) * sigma_scale + sigma_min

    (defaults sigma_scale=0.05 and sigma_min=1e-5, the constants this class
    has always used). sigma is recomputed from rho on every access instead of
    being cached as a detached tensor at construction time. As a result,
    gradients reach rho, and optimizer updates to rho are reflected in
    sample() and log_likelihood().
    """

    def __init__(self, mu: torch.Tensor, rho: torch.Tensor,
                 sigma_scale: float = 0.05, sigma_min: float = 1e-5):
        """
        :param mu: Means, one per element.
        :param rho: Unconstrained std parameters; must have the same size as mu.
        :param sigma_scale: Multiplier applied to softplus(rho).
        :param sigma_min: Small offset that keeps sigma bounded away from 0.
        """
        super(MultivariateDiagonalGaussian, self).__init__()  # always make sure to include the super-class init call!
        assert mu.size() == rho.size(), "mu and rho must have the same size"
        # If mu / rho are torch.nn.Parameter instances, these assignments
        # register them with the module; plain tensors are stored as attributes.
        self.mu = mu
        self.rho = rho
        self.sigma_scale = sigma_scale
        self.sigma_min = sigma_min

    @property
    def sig(self) -> torch.Tensor:
        """Current standard deviations, differentiable with respect to rho."""
        return F.softplus(self.rho) * self.sigma_scale + self.sigma_min

    def log_likelihood(self, values: torch.Tensor) -> torch.Tensor:
        """Sum of the element-wise Normal(mu, sig) log-densities of ``values``."""
        return Normal(self.mu, self.sig).log_prob(values).sum()

    def sample(self) -> torch.Tensor:
        """
        Draw one sample using the reparameterization mu + sig * eps, where
        eps ~ N(0, I), so the result is differentiable with respect to mu and rho.
        """
        # randn_like follows mu's shape, dtype and device. Normal(0, 1).sample()
        # always produced a CPU float32 tensor.
        epsilon = torch.randn_like(self.mu)
        return self.mu + self.sig * epsilon


# In [46]:
class GaussianMixturePrior(ParameterDistribution):
    """
    Mixture of two Gaussian distributions as described in Blundell et al., 2015
    ("Weight Uncertainty in Neural Networks").

    Every element w is treated as an independent draw from

        p(w) = pi * N(w | mu_0, sigma_0^2) + (1 - pi) * N(w | mu_1, sigma_1^2)

    The constructor arguments may be tensors or Python numbers. They are
    broadcast against each other, and against ``values`` in log_likelihood.
    """
    def __init__(self, mu_0: torch.Tensor, sigma_0: torch.Tensor, mu_1: torch.Tensor, sigma_1: torch.Tensor, pi: torch.Tensor):
        super(GaussianMixturePrior, self).__init__()  # always make sure to include the super-class init call!
        self.mu_0 = mu_0  # mean of component 0
        self.sigma_0 = sigma_0  # standard deviation (not variance) of component 0
        self.mu_1 = mu_1  # mean of component 1
        self.sigma_1 = sigma_1  # standard deviation (not variance) of component 1
        self.pi = pi  # mixing weight of component 0; component 1 gets 1 - pi

    def _components(self):
        """Return the two mixture components as (Normal_0, Normal_1)."""
        return (Normal(loc=self.mu_0, scale=self.sigma_0),
                Normal(loc=self.mu_1, scale=self.sigma_1))

    def log_likelihood(self, values: torch.Tensor) -> torch.Tensor:
        """
        Sum over elements of log(pi * N_0(v) + (1 - pi) * N_1(v)).

        The sum is evaluated in log space with logsumexp. The direct form
        log(pi * exp(ll_0) + (1 - pi) * exp(ll_1)) underflows to log(0) = -inf
        once a value lies many standard deviations from both components.
        """
        dist_0, dist_1 = self._components()
        ll_0 = dist_0.log_prob(values)
        ll_1 = dist_1.log_prob(values)
        # as_tensor accepts a float pi and keeps the autograd graph of a tensor pi.
        pi = torch.as_tensor(self.pi, dtype=ll_0.dtype, device=ll_0.device)
        weighted = torch.stack(torch.broadcast_tensors(
            torch.log(pi) + ll_0,
            torch.log1p(-pi) + ll_1,
        ))
        return torch.logsumexp(weighted, dim=0).sum()

    def sample(self) -> torch.Tensor:
        """
        Draw a sample, choosing the mixture component independently for every
        element. This matches the per-element mixture in log_likelihood. Uses
        torch's RNG, so torch.manual_seed controls it.
        """
        dist_0, dist_1 = self._components()
        draw_0, draw_1 = torch.broadcast_tensors(dist_0.sample(), dist_1.sample())
        pi = torch.as_tensor(self.pi, dtype=draw_0.dtype, device=draw_0.device)
        # A single coin flip for the whole tensor (the previous behaviour) would
        # take every element from the same component.
        use_0 = torch.rand(draw_0.shape, dtype=draw_0.dtype, device=draw_0.device) < pi
        return torch.where(use_0, draw_0, draw_1)