In [2]:
import torch
import torch.nn.functional as F
from basic import (log_variational_per_scalar, 
                         log_variational_per_vector, 
                         sample_variational_scalars, 
                         sample_variational_vectors,
                         log_prior_per_scalar, 
                         log_prior_per_vector,)

### `log_variational_per_scalar`

In [3]:
weights = torch.randn((25, 34)) 
mus = torch.randn((25, 34)) + 3
rhos = torch.randn((25, 34))

weight_matrix_flattened = weights.ravel()
mus_flattened = mus.ravel()
rhos_flattened = rhos.ravel()

assert log_variational_per_scalar(weight_matrix_flattened, mus_flattened, rhos_flattened).shape[0], 25 * 34

### `log_variational_per_vector`

In [4]:
# 5 instances of a 13-dimensional diagonal multivariate normal distribution
weight_vectors = torch.randn((13, 5)) 
mu_vectors = torch.randn((13, 5)) + 3
rho_vectors = torch.randn((13, 5))

assert log_variational_per_vector(weight_vectors, mu_vectors, rho_vectors).shape[0] == 13

In [4]:
class BayesLinear:

    def __init__(
            self, 
            in_features: int, 
            out_features: int, 
            prior_pi: float,
            prior_sigma1: float,
            prior_sigma2: float, 
            init_params: torch.Tensor = None
            ):
        
        if init_params:
            self.mu_vectors = init_params
            self.rho_vectors = init_params
        else:
            self.mu_vectors = torch.empty(size=(in_features, out_features)).normal_()
            self.rho_vectors = torch.empty(size=(in_features, out_features)).normal_()
        
        self.prior_pi = prior_pi
        self.prior_sigma1 = prior_sigma1
        self.prior_sigma2 = prior_sigma2

    def __call__(self, x: torch.Tensor, n_samples: int = 1):

        sampled_weight_vectors = sample_variational_vectors(
            n_samples=n_samples,
            mu_vectors=mu_vectors,
            rho_vectors=rho_vectors
        )

        mean_linear_out = torch.stack(
            [F.linear(x, sampled_weight_vectors[i].T) for i in range(n_samples)]
        ).mean()

        return mean_linear_out
    
        
        

In [5]:
b = BayesLinear(5, 10, {1,2,3}, )

TypeError: BayesLinear.__init__() missing 2 required positional arguments: 'prior_sigma1' and 'prior_sigma2'

In [6]:
def bayes_linear(
            in_features, 
            out_features, 
            x, 
            prior_pi: float,
            prior_sigma1: float,
            prior_sigma2: float, 
            n_samples=1
        ):

    mu_vectors = torch.empty(size=(in_features, out_features)).normal_()
    rho_vectors = torch.empty(size=(in_features, out_features)).normal_()

    sampled_weight_vectors = sample_variational_vectors(
        n_samples=n_samples,
        mu_vectors=mu_vectors,
        rho_vectors=rho_vectors
        )
    
    mean_linear = torch.stack(
        [F.linear(x, sampled_weight_vectors[i].T) for i in range(n_samples)]
        ).mean()
    
    # mean_logprior = [logprior_fn(
    #     sampled_weight_vectors[i], 
    #     pi=prior_pi, 
    #     sigma1=prior_sigma1, 
    #     sigma2=prior_sigma2) for i in range(n_samples)].mean()
    
    return mean_linear
    
    # bias_vector = torch.empty(size=(out_features,)).normal_()




In [7]:
x = torch.randn((1, 5), dtype=torch.float)
bayes_linear(
    5, 
    10, 
    x, 
    n_samples=1,
    prior_pi = 0.5,
    prior_sigma1=0.7,
    prior_sigma2=0.005)

tensor(0.1453)

In [9]:
weight_vector_samples = torch.randn(2, 3, 4)
x = torch.randn((1,3))


weight_vector_samples[0].shape

log_prior_per_vector(weight_vector_samples[0], 0.5, 0.9, 0.001)

tensor([-6.0180, -9.5373, -8.2781], dtype=torch.float64)

In [11]:
log_variational_per_vector(weight_vector_samples[0], torch.randn((3,4)), torch.randn((3,4)))

tensor([-82.1949, -81.2742, -47.7652])

In [None]:
torch.stack([F.linear(x, weight_vector_samples[i].T) for i in range(2)]).mean()

tensor(-0.3835)

In [None]:
import scipy.stats as stats

stats.multivariate_normal