In [9]:
import dataclasses 
from typing import Sequence, Optional, Tuple

import jax.numpy as np 
import jax.random as npr 
import jax.nn as nn 

In [16]:
@dataclasses.dataclass 
class MLPConfig: 
    layer_sizes: Sequence[int] 
    init_scale: Optional[float]=1e-2
    activation: Optional[callable]=nn.relu
    use_bias: Optional[bool]=True

class MLP: 
    def __init__(self, config: MLPConfig): 
        self.config = config 
        self.initialize_params()

    @property 
    def params(self) -> Sequence[Tuple[np.ndarray]]: 
        return self._params

    @params.setter 
    def params(self, new_params: Sequence[Tuple[np.ndarray]]): 
        self._params = new_params

    def initialize_params(self, seed: Optional[int]=0): 
        self._params = [] 
        key: np.ndarray = npr.PRNGKey(seed)

        for i, j in zip(self.config.layer_sizes[:-1], self.config.layer_sizes[1:]): 
            weight_key, bias_key, current_key = npr.split(key, 3)
            weight = self.config.init_scale * npr.normal(weight_key, shape=(i, j))

            if self.config.use_bias: 
                bias = self.config.init_scale * npr.normal(bias_key, shape=(j,))
                self._params.append((weight, bias))
            else:
                self._params.append((weight,))


    def __call__(self, params, x: np.ndarray) -> np.ndarray: 
        if self.config.use_bias: 
            weight, bias = params[0]
            activation: np.ndarray = weight @ x + bias 
            for weight, bias in params[1:-1]: 
                activation = self.config.activation(weight @ x + bias)

            weight, bias = params[-1]
            return weight @ activation + bias
        else: 
            weight = params[0]
            activation: np.ndarray = weight @ x

            for weight in params[1:-1]: 
                activation = self.config.activation(weight @ x)

            weight = params[-1]
            return weight @ activation 

In [17]:
config = MLPConfig(
    layer_sizes=[1, 4, 4, 1], 
)
mlp = MLP(config)

In [18]:
mlp.params

[(Array([[ 0.00962801,  0.00155307, -0.00261443,  0.00898887]], dtype=float32),
  Array([-0.00368388,  0.00359172,  0.00011448, -0.00124997], dtype=float32)),
 (Array([[-0.0028582 , -0.000159  , -0.02473603, -0.00186446],
         [-0.0065618 , -0.01294555, -0.00973209,  0.00565696],
         [ 0.02238912, -0.00535222, -0.00336169,  0.00941552],
         [-0.00351734, -0.01040756,  0.00415718,  0.01074744]],      dtype=float32),
  Array([-0.00368388,  0.00359172,  0.00011448, -0.00124997], dtype=float32)),
 (Array([[ 0.00962801],
         [ 0.00155307],
         [-0.00261443],
         [ 0.00898887]], dtype=float32),
  Array([0.00578149], dtype=float32))]