In [1]:
from flax import linen as nn 
from flax.linen import initializers
from flax.core import freeze, unfreeze
from flax.linen.module import compact
from basic import logprior_fn, logvariational_fn, samplevariational_fn, sigmas_from_rhos
from typing import (Any, Callable, Iterable, List, Optional, Sequence, Tuple,
                    Union)
import jax
import jax.numpy as jnp
import jax.random as random
from jax import lax
# Loose type aliases. PRNGKey, Dtype and Array are plain `Any`, so they only
# document intent and give no static type checking.
PRNGKey = Any
Shape = Tuple[int, ...]  # tuple of ints of any length
Dtype = Any  # TODO: could be narrowed to a concrete dtype type
Array = Any
# LeCun-normal initializer instance. None of the names defined in this group
# are referenced by the cells visible here (BNNLayer builds its own default).
default_kernel_init = initializers.lecun_normal()

In [4]:
class BNNLayer(nn.Module):
    """Dense layer whose weights are sampled from a learned variational distribution.

    Each call draws ``n_samples`` weight samples via ``samplevariational_fn``
    from the variational parameters ``mus`` and ``rhos``, applies them to the
    inputs, and returns the two log-density terms needed for a KL penalty.

    The bias is folded into the weight matrix: a column of ones is prepended
    to the inputs, so ``mus`` and ``rhos`` both have shape
    ``(in_features + 1, features)``.
    """

    # Number of output features (second dimension of ``mus`` / ``rhos``).
    features: int
    # Prior hyperparameters, forwarded unchanged to ``logprior_fn`` as
    # ``pi``, ``var1`` and ``var2``.
    # NOTE(review): presumably the mixture weight and the two component
    # variances of a scale-mixture prior -- confirm in basic.logprior_fn.
    prior_pi: float
    prior_var1: float
    prior_var2: float
    # Initializer used for BOTH ``mus`` and ``rhos``.
    parameter_init: Callable = nn.initializers.lecun_normal()

    @compact
    def __call__(self, sampling_key: random.PRNGKey, inputs, n_samples: int):
        """Apply the layer with freshly sampled weights.

        Args:
            sampling_key: PRNG key handed to ``samplevariational_fn``.
            inputs: array whose last axis is the feature axis. Any number of
                leading axes is accepted (including none, i.e. a single
                unbatched example).
            n_samples: number of weight samples to draw; must be a concrete
                Python int because it is used to build array shapes.

        Returns:
            A tuple ``(y, log_variational_density, log_prior_density)``:
            ``y`` is ``jnp.dot(augmented_inputs, weights)``; the other two are
            whatever ``logvariational_fn`` / ``logprior_fn`` return for the
            sampled weights.
            NOTE(review): the weights are assumed to carry a leading sample
            axis of length ``n_samples`` (the stacked ``mus``/``rhos`` are
            built to match that) -- confirm in basic.samplevariational_fn.
        """
        input_shape = jnp.shape(inputs)
        # One extra row on top of the input features holds the bias.
        param_shape = (input_shape[-1] + 1, self.features)

        # Variational parameters.
        mus = self.param("mus", self.parameter_init, param_shape)
        rhos = self.param("rhos", self.parameter_init, param_shape)

        # Sample weights.
        weights = samplevariational_fn(
            mus=mus,
            rhos=rhos,
            key=sampling_key,
            n_samples=n_samples,
        )

        # Augment inputs with a leading column of ones so that biases don't
        # need to be created separately. The ones follow every leading axis
        # of ``inputs`` (not just axis 0), so inputs of any rank >= 1 work;
        # for 2-D inputs this is exactly a (batch, 1) column.
        column_of_ones = jnp.ones(input_shape[:-1] + (1,))
        inputs_augmented = jnp.concatenate((column_of_ones, inputs), axis=-1)
        y = jnp.dot(inputs_augmented, weights)

        # Compute terms for the KL penalty. The variational parameters are
        # repeated along a new leading axis, one copy per weight sample.
        sample_shape = (n_samples,) + mus.shape
        stacked_mus = jnp.broadcast_to(mus, sample_shape)
        stacked_rhos = jnp.broadcast_to(rhos, sample_shape)
        log_variational_density = logvariational_fn(
            weights=weights,
            mus=stacked_mus,
            rhos=stacked_rhos,
        )
        log_prior_density = logprior_fn(
            weights=weights,
            pi=self.prior_pi,
            var1=self.prior_var1,
            var2=self.prior_var2,
        )
        return y, log_variational_density, log_prior_density



In [9]:
# Smoke test for BNNLayer on random data.
# One fixed seed is split into three keys so the cell is reproducible:
#   key1 -> synthetic inputs, key2 -> parameter init, key3 -> weight sampling.
key1, key2, key3 = random.split(random.PRNGKey(0), num=3)

# 4 examples with 4 features each.
x = random.uniform(key1, shape=(4, 4))
model = BNNLayer(
    features=3,
    prior_pi=0.5,
    prior_var1=0.9,
    prior_var2=0.001,
)

# The first argument of `init` is the rngs entry used to initialise parameters;
# the remaining keyword arguments are forwarded to BNNLayer.__call__.
params = model.init(key2, sampling_key=key3, inputs=x, n_samples=2)
print(params)

# `y` holds the full return tuple: (outputs, log_variational_density, log_prior_density).
# NOTE(review): in the recorded output below, the two sample slices of the
# outputs are identical and the log-variational vector repeats after its first
# 15 entries, which suggests both weight samples are the same -- verify that
# samplevariational_fn draws distinct noise per sample.
y = model.apply(params, sampling_key=key3, inputs=x, n_samples=2)
y

FrozenDict({
    params: {
        mus: Array([[-0.02930064,  0.22184654, -0.8278271 ],
               [ 0.843966  ,  0.0176126 ,  0.83982617],
               [ 0.6925206 ,  0.13406326,  0.6939206 ],
               [-0.6244742 , -0.9194189 ,  0.11918769],
               [-0.21873963,  0.2695162 ,  0.4354371 ]], dtype=float32),
        rhos: Array([[-0.2408339 ,  0.10038533,  0.661483  ],
               [-1.001577  ,  0.18277498, -0.79746354],
               [ 0.27618888, -0.06239573,  0.05086489],
               [-0.27124298,  0.09979365, -0.6051223 ],
               [ 0.24168347,  0.08925466, -0.7675577 ]], dtype=float32),
    },
})


(Array([[[ 1.3309863 ,  2.221767  , -0.27026787],
         [ 1.3309863 ,  2.221767  , -0.27026787]],
 
        [[ 1.2240943 ,  2.5699005 , -0.3441275 ],
         [ 1.2240943 ,  2.5699005 , -0.3441275 ]],
 
        [[ 1.5596361 ,  2.3207808 ,  0.12029158],
         [ 1.5596361 ,  2.3207808 ,  0.12029158]],
 
        [[ 1.3289882 ,  1.7428572 , -0.19145672],
         [ 1.3289882 ,  1.7428572 , -0.19145672]]], dtype=float32),
 Array([-2.1906397 , -2.877598  , -0.04729718, -1.1645566 , -0.6022583 ,
        -1.1752043 , -1.1737965 , -1.2948349 , -1.0184361 , -1.7649992 ,
        -5.522732  , -0.6626189 , -0.3937053 , -1.6940267 , -0.90627396,
        -2.1906397 , -2.877598  , -0.04729718, -1.1645566 , -0.6022583 ,
        -1.1752043 , -1.1737965 , -1.2948349 , -1.0184361 , -1.7649992 ,
        -5.522732  , -0.6626189 , -0.3937053 , -1.6940267 , -0.90627396],      dtype=float32),
 Array([-2.36871   , -3.584829  , -2.1393044 , -1.6579031 , -1.5836654 ,
        -1.6377769 , -0.23071143, -1.858

In [26]:
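# NOTE: disabled draft of a multi-layer wrapper around BNNLayer. As drafted, every
# layer is called with the same `sampling_key`; revisit before re-enabling.
# class BNN(nn.Module):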

#     features: Sequence[int]
#     prior_pi: float
#     prior_var1: float
#     prior_var2: float
#     # parameter_init: Callable = nn.initializers.lecun_normal()

#     def setup(self):
#         # For now, all layers share the same prior hyperparameters.
#         self.layers = [
#             BNNLayer(
#             features_, 
#             self.prior_pi, 
#             self.prior_var1, 
#             self.prior_var2
#             ) for features_ in self.features
#         ]

#     def __call__(self, sampling_key: random.PRNGKey, inputs, n_samples: int):
#         log_variational_densities = list()
#         log_prior_densities = list()
#         x = inputs
#         for i, lyr in enumerate(self.layers):
            
#             x, log_variational_density, log_prior_density = lyr(sampling_key, x, n_samples)
#             log_variational_densities.append(log_variational_density)
#             log_prior_densities.append(log_prior_density)
        
#             if i != len(self.layers) - 1:
#                 x = nn.relu(x)

#         return x, log_variational_densities, log_prior_densities