In [1]:
from bayesian_nn import sigmas_from_rhos, logvariational_fn, samplevariational_fn, rhos_from_sigmas
import numpy as np
import scipy.stats as stats
import pytest
import torch

In [2]:

@pytest.mark.parametrize(
    "weights, mus, sigmas",
    [
        (
            torch.ones((3, 2)),
            torch.ones((3, 2)) * 2,
            torch.arange(1, 7, dtype=torch.float32).reshape(3, 2),
        ),
        # Single-row case. logvariational_fn returns one log-density per row,
        # so inputs must be 2-D (rows x dims) for the per-row comparison below.
        (
            torch.ones((1, 2)),
            torch.ones((1, 2)) * 2,
            torch.arange(2, 4, dtype=torch.float32).reshape(1, 2),
        ),
    ],
)
def test_logvariational_fn(weights, mus, sigmas):
    """Compare logvariational_fn row by row against scipy's diagonal Gaussian logpdf.

    ``logvariational_fn`` is parameterised by rhos rather than sigmas, so the
    test converts with ``rhos_from_sigmas`` first. scipy is given the
    per-dimension variances ``sigmas**2``; a 1-D ``cov`` is interpreted as the
    diagonal of the covariance matrix.
    """
    rhos = rhos_from_sigmas(sigmas)
    results = logvariational_fn(weights, mus, rhos)
    assert results.shape == (weights.shape[0],)
    for i, result in enumerate(results):
        ground_truth = stats.multivariate_normal.logpdf(weights[i], mus[i], sigmas[i] ** 2)
        # float32 (torch) vs float64 (scipy): allow a small relative tolerance.
        np.testing.assert_allclose(float(result), ground_truth, rtol=1e-4)



In [6]:
# Small 2x2 example inputs: weights all 1, means all 2, rhos all 1.
weights = torch.ones(2, 2)
mus = torch.full((2, 2), 2.0)
rhos = torch.ones(2, 2)

In [7]:
# Display the shapes of the three example tensors side by side.
tuple(t.shape for t in (weights, mus, rhos))

(torch.Size([2, 2]), torch.Size([2, 2]), torch.Size([2, 2]))

In [8]:
# Evaluate the vectorised log-density for the 2x2 example. Note the third
# argument is rhos, not sigmas; the displayed result has one value per row.
logvariational_fn(weights, mus, rhos)

tensor([-2.9627, -2.9627])

In [27]:
def rhos_from_sigmas(sigmas):
    """Invert the softplus parameterisation: return rho with log(1 + exp(rho)) == sigma.

    Algebraically equal to ``log(exp(sigma) - 1)``, but written as
    ``sigma + log(-expm1(-sigma))`` so it does not overflow to ``inf`` for
    large sigma (``exp`` overflows float32 near 89) and keeps precision for
    small sigma (``exp(s) - 1`` cancels catastrophically as s -> 0).

    NOTE: this cell shadows ``rhos_from_sigmas`` imported from ``bayesian_nn``.

    Args:
        sigmas: tensor of standard deviations (must be > 0 for a finite result).

    Returns:
        Tensor of rhos with the same shape as ``sigmas``.
    """
    return sigmas + torch.log(-torch.expm1(-sigmas))

In [31]:
# Round trip rho -> sigma -> rho should recover the original rhos.
# The result goes into a new name so `rhos` is not overwritten and the cell
# stays idempotent on re-run.
print(rhos)
sigmas = sigmas_from_rhos(rhos)
print(sigmas)
rhos_roundtrip = rhos_from_sigmas(sigmas)
print(rhos_roundtrip)
assert torch.allclose(rhos_roundtrip, rhos, atol=1e-5)

tensor([[1., 1.],
        [1., 1.]])
tensor([[1.3133, 1.3133],
        [1.3133, 1.3133]])
tensor([[1.0000, 1.0000],
        [1.0000, 1.0000]])


In [11]:
# Cross-check the first row of logvariational_fn(weights, mus, rhos) with scipy.
# scipy treats a 1-D `cov` as the covariance diagonal, so pass the variances
# sigma**2 = sigmas_from_rhos(rho)**2. (`variances` was never defined before.)
stats.multivariate_normal.logpdf(weights[0], mus[0], sigmas_from_rhos(rhos[0]) ** 2)

-2.9627304762584226

In [None]:
n_samples = 10
mus = torch.ones((2, 3))
sigmas = torch.ones((2, 3))
# Standard-normal noise, presumably for reparameterised sampling.
# NOTE(review): assumes samplevariational_fn expects one draw per sample with
# shape (n_samples, *mus.shape) -- TODO confirm against bayesian_nn.
# A dedicated seeded generator keeps the draw reproducible without touching
# torch's global RNG state.
epsilons = torch.randn((n_samples, *mus.shape), generator=torch.Generator().manual_seed(0))

samplevariational_fn(n_samples, mus, sigmas, epsilons)

In [4]:
import time

In [27]:
def test_vectorized_gaussian_logpdf(n=128**2, d=4, seed=0):
    """Check the vectorised ``logvariational_fn`` against scipy and report the speedup.

    Draws ``n`` random ``d``-dimensional points, means and per-dimension
    (diagonal) variances. It evaluates the log-density row by row with
    ``scipy.stats.multivariate_normal.logpdf`` as the reference, and in one call
    with ``logvariational_fn``, which takes rhos rather than variances. It then
    prints both timings and asserts that the results agree.

    All parameters have defaults, so pytest does not treat them as fixtures.

    Args:
        n: number of rows (independent Gaussians) to evaluate.
        d: dimensionality of each row.
        seed: seed for the local random generator, making the test deterministic.
    """
    gen = torch.Generator().manual_seed(seed)

    means = torch.empty((n, d)).uniform_(-1, 1, generator=gen)
    # Diagonal variances, bounded away from 0. A zero variance gives
    # rho = -inf and a singular covariance for scipy, which made the test flaky.
    variances = torch.empty((n, d)).uniform_(0.1, 2, generator=gen)
    rhos = rhos_from_sigmas(variances.sqrt())
    X = torch.empty((n, d)).uniform_(-1, 1, generator=gen)

    ref_start = time.perf_counter()
    refs = [
        stats.multivariate_normal.logpdf(x, mean, variance)
        for x, mean, variance in zip(X, means, variances)
    ]
    ref_time = time.perf_counter() - ref_start

    fast_start = time.perf_counter()
    results = logvariational_fn(X, means, rhos)
    fast_time = time.perf_counter() - fast_start

    print("Reference time:", ref_time)
    print("Vectorized time:", fast_time)
    # Guard against a zero reading on coarse clocks.
    print("Speedup:", ref_time / fast_time if fast_time > 0 else float("inf"))

    refs = np.array(refs)

    print(results)
    print(refs)
    assert np.allclose(results, refs, atol=1e-2)

In [28]:
# Run the vectorised-vs-scipy consistency check directly in the notebook:
# prints both timings and the speedup, and raises AssertionError on mismatch.
test_vectorized_gaussian_logpdf()

Reference time: 0.7334489822387695
Vectorized time: 0.000988006591796875
Speedup: 742.3523166023166
tensor([-4.5360, -7.6813, -9.2519,  ..., -4.6330, -3.9440, -7.1242])
[-4.53596129 -7.68133213 -9.25190601 ... -4.63296877 -3.94398697
 -7.12425061]
