In [25]:
import tensorflow as tf
from tensorflow.keras.regularizers import Regularizer
from tensorflow.keras import layers as kl
from tensorflow import keras
import tensorflow_probability as tfp
from tensorflow_probability import distributions as tfd

from typing import Callable

In [None]:
# NOTE(review): leftover inspection cell (never executed: `In [None]`); it only
# displays the `tfd.Distribution` class object. Consider deleting it.
tfd.Distribution

In [6]:
class LambdaRegularizer(Regularizer):
    """Keras regularizer whose penalty is produced by a user-supplied callable.

    Args:
        loss: Callable applied to the regularized weight tensor. Whatever it
            returns is used as the penalty. It presumably should be a scalar
            tensor; any reduction is the callable's responsibility.
    """

    def __init__(self, loss: Callable):
        # Kept as a public attribute; `__call__` looks it up on every call.
        self.loss = loss

    def __call__(self, x):
        """Return the penalty for weight tensor `x` by delegating to `self.loss`."""
        penalty_fn = self.loss
        penalty = penalty_fn(x)
        return penalty
        

In [43]:
class GaussPriorRegularizer(LambdaRegularizer):
    """Penalize weights by their negative log-likelihood under a Gaussian prior.

    The penalty is ``-sum(log Normal(w | mu, sigma))`` over every element of the
    regularized weight tensor ``w``. Each element is treated as an independent
    draw from ``Normal(mu, sigma)``.

    Args:
        mu: Mean (``loc``) of the Normal prior.
        sigma: Standard deviation (``scale``) of the Normal prior.
    """

    def __init__(self, mu, sigma):
        self.mu = mu
        self.sigma = sigma
        # Pass the *bound* method so the loss can read the prior parameters
        # from this instance.
        super().__init__(loss=self._nll)

    def _nll(self, x):
        """Return the summed Gaussian negative log-likelihood of ``x``."""
        # `log_prob` is element-wise; summing yields the single scalar penalty
        # a Keras regularizer must return. For a scalar weight it is a no-op.
        return -tf.reduce_sum(tfd.Normal(self.mu, self.sigma).log_prob(x))

    def get_config(self):
        """Return the constructor arguments so Keras can re-create this object.

        The default ``Regularizer.from_config`` calls ``cls(**config)``.
        """
        return {"mu": self.mu, "sigma": self.sigma}

In [44]:
class LinearLayer(kl.Layer):
    """Scale the input by a single learned scalar weight ``a`` (no bias).

    ``a`` is zero-initialised. It carries a
    ``GaussPriorRegularizer(prior_mu, prior_sigma)`` regularizer, so its
    penalty is reported through the layer's ``losses``.

    Args:
        prior_mu: Mean of the Gaussian prior on ``a`` (default ``0.``).
        prior_sigma: Standard deviation of that prior (default ``1.``).
        **kwargs: Forwarded to ``keras.layers.Layer`` (e.g. ``name``, ``dtype``).
    """

    def __init__(self, prior_mu=0., prior_sigma=1., **kwargs):
        super().__init__(**kwargs)
        self.prior_mu = prior_mu
        self.prior_sigma = prior_sigma
        # Use keyword arguments: the positional order of `add_weight` differs
        # between tf.keras 2.x (name first) and Keras 3 (shape first).
        self.a = self.add_weight(
            name="a",
            shape=(),
            initializer="zeros",
            regularizer=GaussPriorRegularizer(prior_mu, prior_sigma),
        )

    def call(self, inputs):
        """Return ``a * inputs`` (the scalar weight broadcasts over the input)."""
        return self.a * inputs

    def get_config(self):
        """Include the prior hyper-parameters so the layer can be re-created."""
        config = super().get_config()
        config.update(prior_mu=self.prior_mu, prior_sigma=self.prior_sigma)
        return config

In [45]:
# Instantiate the layer. Its scalar weight `a` (zero-initialised, with a
# Gaussian-prior regularizer) is created inside `LinearLayer.__init__`.
ll = LinearLayer()

In [46]:
# Wrap the single layer in a Sequential model. The layer's regularization
# penalty is then visible via `model.losses`.
model = keras.Sequential()
model.add(ll)

In [47]:
# Ten draws from a standard Normal used as model input. The fixed op-level
# seed makes Restart & Run All reproduce the same inputs. Note: the saved
# output below was produced without a seed, so its values will differ.
x = tfd.Normal(0., 1.).sample(10, seed=42)
x

<tf.Tensor: id=349, shape=(10,), dtype=float32, numpy=
array([ 0.03183443,  1.0133514 ,  0.5784182 , -0.25800017, -0.11047503,
       -0.7759998 , -0.95392644, -0.3988249 ,  1.6209925 , -0.88297296],
      dtype=float32)>

In [48]:
# Forward pass: y = a * x element-wise. `a` starts at zero, so before any
# training y is all zeros. `y` is not used by any later visible cell.
y = model(x)

In [49]:
# Regularization losses collected from the model's layers. The saved value
# 0.9189385 equals 0.5 * log(2 * pi) = -log N(0 | 0, 1), i.e. the prior
# negative log-likelihood of the zero-initialised weight `a`.
model.losses

[<tf.Tensor: id=366, shape=(), dtype=float32, numpy=0.9189385>]