In [1]:
import torch
from torch import Tensor
from jaxtyping import Float
import scipy
import scipy.stats as stats
import numpy as np

In [55]:
def sigmas_from_rhos(rhos: Float[Tensor, "N D"]) -> Float[Tensor, "N D"]:
    return torch.log(1 + torch.exp(rhos))


def logvariational_fn(
    weights: Float[Tensor, "N D"],
    mus: Float[Tensor, "N D"],
    sigmas: Float[Tensor, "N D"],
) -> Float[Tensor, "*N 1"]:
    """
    Computes the log density of a diagonal multivariate gaussian distribution given
    the parameters governing the distribution (mean_vetors, sigmas) and an
    input weights.

    Parameters
    ----------
        weights : N x D tensor
            N D-dimensional vectors representing the weights of the network to be
            evaluated.
        mus : N x D tensor
            N D-dimensional vectors representing the mean vectors for the
            N independent multivariate gaussian distributions to evaluate
            weights under.
        sigmas : N x D tensor
            N D-dimensional vectors representing the diagonal entries of the covariance
            matrix for N independent multivariate gaussian distributions to evaluate
            weights under.

    Returns
    -------
    N x 1 tensor
        the log density for each N weight vectors according to the parameters in
        mus and sigmas.
    """
    # If a 0 is in sigmas -> non-singular since sigmas are diagonals
    if not sigmas.all():
        raise ValueError(f"sigmas need to all be positive, but they are {sigmas}")
    # this is from Daniel W https://stackoverflow.com/questions/48686934/numpy-vectorization-of-multivariate-normal
    D = weights.size(1)
    constant = D * np.log(2 * torch.pi)
    log_determinants = torch.log(torch.prod(sigmas, axis=1)) 
    deviations = weights - mus
    inverses = 1 / sigmas
    return -0.5 * (constant + log_determinants +
        torch.sum(deviations * inverses * deviations, axis=1))



def samplevariational_fn(
    n_samples: int,
    mus: Float[Tensor, "N D"],
    sigmas: Float[Tensor, "N D"],
    epsilons: Float[Tensor, "N D"],
) -> Float[Tensor, "N D"]:
    """
    Samples from a diagonal multivariate gaussian governed by the mus
    and sigmas.

    Parameters
    ----------
        n_samples : int
            number of samples to return.
        mus : N x D tensor
            N D-dimensional vectors representing the mean vectors for N independent
            multivariate gaussian distributions.
        sigmas : N x D tensor
            N D-dimensional vectors representing the diagonal entries of the covariance
            matrix for N independent multivariate gaussian distributions.

    Returns
    -------
    n_samples x N x D tensor
        n_samples of the N x D tensor that represents N samples from N independent
        D-dimensional multivariate gaussian distributions.
    """
    samples = []
    for _ in range(n_samples):
        samples.append(mus + sigmas * epsilons)
    return torch.stack(samples)


def logprior_fn(weights: Float[Tensor, "N D"], pi: float, sigma1: float, sigma2: float):
    print(weights.shape)

    gaussian1_log_prob = stats.multivariate_normal.logpdf(
        x=weights, mean=0, cov=sigma1**2
    )
    gaussian2_log_prob = stats.multivariate_normal.logpdf(
        x=weights, mean=0, cov=sigma2**2
    )

    return torch.log(
        pi * torch.exp(gaussian1_log_prob) + (1 - pi) * torch.exp(gaussian2_log_prob)
    )


In [56]:
weights = torch.ones((2,3))
print(weights)

tensor([[1., 1., 1.],
        [1., 1., 1.]])


In [57]:
mus = torch.ones((2,3)) * 3
mus

tensor([[3., 3., 3.],
        [3., 3., 3.]])

In [61]:
sigmas = torch.ones((2,3))
sigmas = torch.arange(1,7).reshape(2,3)
sigmas

tensor([[1, 2, 3],
        [4, 5, 6]])

In [62]:
logvariational_fn(weights, mus, sigmas)

tensor([-7.3194, -6.3839])

In [63]:
weights[0], mus[0], sigmas[0]

(tensor([1., 1., 1.]), tensor([3., 3., 3.]), tensor([1, 2, 3]))

In [64]:
stats.multivariate_normal.logpdf(weights[0], mus[0], sigmas[0])

-7.319362000894712

In [65]:
stats.multivariate_normal.logpdf(weights[1], mus[1], sigmas[1])

-6.383894804338374