In [27]:
import os 
os.environ["KERAS_BACKEND"]="jax"

import jax 
import jax.numpy as jnp 
import keras
import matplotlib.pyplot as plt 
import numpy as np

In [31]:
class MyLayer(keras.layers.Layer):
    def __init__(self,**kwargs):
        super().__init__(**kwargs)

    def build(self,input_shape):
        self.w=self.add_weight(
            shape=(input_shape[-1],1),
            initializer="random_normal",
            trainable=True
        )
        self.b=self.add_weight(
            shape=(1,),
            initializer="random_normal",
            trainable=True
        )
        self.built=True
    
    def call(self,x):
        # ret=x**2@self.w+self.b
        ret=keras.ops.matmul(x**2,self.w)+self.b
        return ret 
    
    def compute_output_shape(self,input_shape):
        output_shape=list(input_shape)
        output_shape[-1]=1
        return output_shape

class MyModel(keras.Model):
    def __init__(self,**kwargs):
        super().__init__(**kwargs)
        self.quadratic=MyLayer()

    def call(self,x):
        y=self.quadratic(x)
        return y 
    
    def build(self,input_shape):
        self.quadratic.build(input_shape)
        self.built=True

X=jnp.array([[0.0],[1.0]])
Y=jnp.array([[2.0],[3.0]])

model=MyModel()
model.build(input_shape=(None,1))
# model.summary()


optimizer=keras.optimizers.Adam(learning_rate=1e-1)
model.compile(optimizer=optimizer,loss=None)
model.fit(x=X,y=Y,validation_split=0,epochs=1000,verbose=0)

print(model.quadratic.w.numpy())
print(model.quadratic.b.numpy())

[[0.02263587]]
[-0.07732891]


In [None]:
class BNN(keras.layers.Layer):
    def __init__(self,units=1,**kwargs):
        super().__init__(**kwargs)
        self.units=units
    
    def build(self,input_shape):
        # w 
        self.w=None # shape: (input_shape[-1],units)
        self.w_mu=self.add_weight(shape=(input_shape[-1],self.units),initializer="random_normal",trainable=True)
        self.w_rho=self.add_weight(shape=(input_shape[-1],self.units),initializer="random_normal",trainable=True)

        # b
        self.b=None #shape: (units,)
        self.b_mu=self.add_weight(shape=(self.units,),initializer="random_normal",trainable=True)
        self.b_rho=self.add_weight(shape=(self.units,),initializer="random_normal",trainable=True)

        self.built=True 


    def call(self,x):
        # w 
        w_epsilon=keras.random.normal(shape=self.w_mu.shape)
        self.w=self.w_mu+keras.ops.log(1+keras.ops.exp(self.w_rho))*w_epsilon 

        # b 
        b_epsilon=keras.random.normal(shape=self.b_mu.shape)
        self.b=self.b_mu+keras.ops.log(1+keras.ops.exp(self.b_rho))*b_epsilon 







    def compute_output_shape(self, input_shape):
        output_shape=list(input_shape)
        output_shape[-1]=self.units 
        return output_shape 
    

keras.layers.Dense()