In [3]:
from flax import linen as nn 
from flax.linen import initializers
from flax.core import freeze, unfreeze
from flax.linen.module import compact
from basic import logprior_fn, logvariational_fn, samplevariational_fn, sigmas_from_rhos
from typing import (Any, Callable, Iterable, List, Optional, Sequence, Tuple,
                    Union)
import jax.numpy as jnp
import jax.random as random
from jax import lax
PRNGKey = Any
Shape = Tuple[int, ...]
Dtype = Any  # this could be a real type?
Array = Any
default_kernel_init = initializers.lecun_normal()

In [44]:
class BNNLayer(nn.Module):

    features: int
    # prior_pi: float
    # prior_var1: float
    # prior_var2: float
    parameter_init: Callable = nn.initializers.lecun_normal()

    @compact
    def __call__(self, sampling_key, inputs, n_samples: int):
        mus = self.param(
            "mus", 
            self.parameter_init, 
            (jnp.shape(inputs)[-1] + 1, self.features))
        
        rhos = self.param(
            "rhos",
            self.parameter_init,
            (jnp.shape(inputs)[-1] + 1, self.features)
        )

        weights = samplevariational_fn(
            mus=mus,
            rhos=rhos,
            # KEY IS CONSTANT RIGHT NOW
            key=sampling_key,
            n_samples=n_samples
        )

        # Augment inputs by adding a column of 1s so that
        # biases don't need to be separately created.
        column_of_ones = jnp.ones((jnp.shape(inputs)[0], 1))
        inputs_augmented = jnp.concatenate((column_of_ones, inputs), axis=-1)

        y = jnp.dot(inputs_augmented, weights)

        return y



In [52]:
key1, key2, key3 = random.split(random.PRNGKey(0), 3)

x = random.uniform(key1, (4,4))
model = BNNLayer(features=3)
params = model.init(
    rngs=key2, 
    sampling_key=key3,  
    inputs=x, 
    n_samples=1)
print(params)
y = model.apply(params, key3, x, 1)
y


FrozenDict({
    params: {
        mus: Array([[-0.02930064,  0.22184654, -0.8278271 ],
               [ 0.843966  ,  0.0176126 ,  0.83982617],
               [ 0.6925206 ,  0.13406326,  0.6939206 ],
               [-0.6244742 , -0.9194189 ,  0.11918769],
               [-0.21873963,  0.2695162 ,  0.4354371 ]], dtype=float32),
        rhos: Array([[-0.2408339 ,  0.10038533,  0.661483  ],
               [-1.001577  ,  0.18277498, -0.79746354],
               [ 0.27618888, -0.06239573,  0.05086489],
               [-0.27124298,  0.09979365, -0.6051223 ],
               [ 0.24168347,  0.08925466, -0.7675577 ]], dtype=float32),
    },
})


Array([[ 1.3309863 ,  2.221767  , -0.27026787],
       [ 1.2240943 ,  2.5699005 , -0.3441275 ],
       [ 1.5596361 ,  2.3207808 ,  0.12029155],
       [ 1.3289882 ,  1.7428572 , -0.19145672]], dtype=float32)

In [48]:
class SimpleDense(nn.Module):
  features: int
  kernel_init: Callable = nn.initializers.lecun_normal()
  bias_init: Callable = nn.initializers.zeros_init()

  @nn.compact
  def __call__(self, inputs):
    kernel = self.param('kernel',
                        self.kernel_init, # Initialization function
                        (inputs.shape[-1], self.features))  # shape info.
    
    y = lax.dot_general(inputs, kernel,
                        (((inputs.ndim - 1,), (0,)), ((), ())),) # TODO Why not jnp.dot?
    bias = self.param('bias', self.bias_init, (self.features,))
    y = y + bias
    return y

key1, key2 = random.split(random.PRNGKey(0), 2)
x = random.uniform(key1, (4,4))

model = SimpleDense(features=3)
params = model.init(key2, x)
y = model.apply(params, x)

print('initialized parameters:\n', params)
print('output:\n', y)

initialized parameters:
 FrozenDict({
    params: {
        kernel: Array([[ 0.61506   , -0.22728713,  0.6054702 ],
               [-0.29617992,  1.1232013 , -0.879759  ],
               [-0.35162622,  0.3806491 ,  0.6893246 ],
               [-0.1151355 ,  0.04567898, -1.091212  ]], dtype=float32),
    },
})
output:
 [[ 0.61506    -0.22728713  0.6054702 ]
 [-0.29617992  1.1232013  -0.879759  ]
 [-0.35162622  0.3806491   0.6893246 ]
 [-0.1151355   0.04567898 -1.091212  ]]
