In [1]:
import torch
import torch.nn.functional as F
from basic import logvariational_fn, samplevariational_fn, logprior_fn
import torch.nn as nn

In a Bayesian layer we have a matrix 

$\mathbf{W}$ that is of shape `in_features x out_features`. Each weight $w$ comes from the reparameterized variational Gaussian governed by a deterministic function wrt the parameter $\mu$ and $\rho$. So matrices of `in_features x out_features` describe the weight matrix.


Pseudocode:

- initialize mus, rhos, -- these are the parameters we will change in the optimization. 
- in a forward pass, we sample the weights given these mus and rhos and use that to do a linear transformation on the input. 
- meanwhile, we have to calculate the log likelihood wrt to the variational distribution(s) and the log likelihood wrt to the prior over the weights.

In [126]:
class BayesLinear(nn.Module):

    """
    Defines a Bayesian Linear Layer.

    Attributes
    ----------
        in_features : int
            number of features in the input 
    """

    def __init__(
            self,
            in_features: int,
            out_features: int,
            prior_pi: float,
            prior_var1: float,
            prior_var2: float,
    ):
        super().__init__()
        # Layer attributes
        self.in_features = in_features
        self.out_features = out_features
        
        # Parameters governing weights of the layer
        # We add a row for biases 
        self.mus = nn.Parameter(torch.empty(size=(in_features + 1, out_features)).normal_())
        self.rhos = nn.Parameter(torch.empty(size=(in_features + 1, out_features)).normal_()) 

        # Parameters governing weight's prior distribution 
        self.pi = prior_pi 
        self.prior_var1 = prior_var1
        self.prior_var2 = prior_var2

    def __call__(self, x, n_samples):
        # For biases
        column_of_ones = torch.ones(x.size(0), 1)
        x_aug = torch.concat([column_of_ones, x], axis=-1)

        sampled_weights = samplevariational_fn(
            mus=self.mus,
            rhos=self.rhos
            )
        
        # logvariational = logvariational_fn(sampled_weights, self.mus, self.rhos)
        # logprior = logprior_fn(sampled_weights, self.pi, self.var1, self.var2)
        
        return x_aug @ sampled_weights

        

In [127]:
bll = BayesLinear(3, 5, .5, .8, .0025)

In [128]:
x = torch.randn((2,3))
x_aug = torch.concat([torch.ones((x.size(0), 1)), x], axis=-1)
x, x_aug

(tensor([[ 0.5816,  0.5264,  1.0082],
         [-0.6498,  0.4660,  2.0500]]),
 tensor([[ 1.0000,  0.5816,  0.5264,  1.0082],
         [ 1.0000, -0.6498,  0.4660,  2.0500]]))

In [129]:
bll(x, 5)

tensor([[-0.7597, -1.2807, -3.2198,  0.8431,  0.1781],
        [-2.9015,  7.0670,  1.3891,  0.4946, -1.1932]], grad_fn=<MmBackward0>)

In [130]:
x_aug = torch.concat([torch.ones((x.size(0), 1)), x], axis=-1)
W_aug = 


SyntaxError: invalid syntax (3111630757.py, line 2)

In [31]:
def bayes_linear(
            in_features: int, 
            out_features: int, 
            x, 
            prior_pi: float,
            prior_var1: float,
            prior_var2: float, 
        ):

    mus = nn.Parameter(torch.empty(size=(in_features, out_features)).normal_())
    rhos = nn.Parameter(torch.empty(size=(in_features, out_features)).normal_())

    sampled_weights = samplevariational_fn(
        mus=mus,
        rhos=rhos
        )
    
    logvariational_prob = logvariational_fn(sampled_weights, mus, rhos)
    logprior_prob = logprior_fn(sampled_weights, prior_pi, prior_var1, prior_var2)

    return F.linear(x, sampled_weights.T), logvariational_prob.sum(), logprior_prob.sum()





In [5]:
x = torch.randn((1, 5), dtype=torch.float)
bayes_linear(
    in_features=5, 
    out_features=10, 
    x=x, 
    prior_pi = 0.5,
    prior_var1=0.7,
    prior_var2=0.005)

RuntimeError: Can't call numpy() on Tensor that requires grad. Use tensor.detach().numpy() instead.

In [19]:
W = torch.randn((3,5))
W_aug = torch.concat([torch.ones(3,1), W], axis=-1)
W_aug

tensor([[ 1.0000, -0.5345, -1.6729, -0.8928, -2.5430, -0.2705],
        [ 1.0000, -0.7756,  1.6903,  0.6987, -1.0473, -0.1190],
        [ 1.0000,  0.5254, -1.2813,  0.6321,  0.2138, -0.9132]])