# Neural Networks

> Define reusable neural network modules in `haiku`.

In [None]:
# default_exp nets

In [None]:
# hide
%load_ext autoreload
%autoreload 2
from ipynb_path import *

In [None]:
# export
from cfnet.import_essentials import *
from cfnet.utils import validate_configs, sigmoid

In [None]:
# export
class DenseBlock(hk.Module):
    """Linear layer followed by leaky ReLU and dropout.

    Args:
        output_size: Number of output units of the linear layer.
        dropout_rate: Dropout rate applied when `is_training` is true.
        name: Optional module name forwarded to `hk.Module`.
    """

    def __init__(
        self,
        output_size: int,
        dropout_rate: float = 0.3,
        name: Optional[str] = None,
    ):
        super().__init__(name=name)
        self.output_size = output_size
        self.dropout_rate = dropout_rate

    def __call__(
        self,
        x: jnp.ndarray,
        is_training: bool = True,
    ) -> jnp.ndarray:
        """Apply linear -> leaky ReLU -> dropout to `x`.

        Args:
            x: Input array fed to the linear layer.
            is_training: When false, the dropout rate is forced to 0.0.

        Returns:
            The transformed array.
        """
        # Dropout is disabled (rate 0.0) outside of training.
        if is_training:
            rate = self.dropout_rate
        else:
            rate = 0.0

        # he_uniform weight initialisation
        he_uniform = hk.initializers.VarianceScaling(2.0, 'fan_in', 'uniform')
        linear = hk.Linear(self.output_size, w_init=he_uniform)
        activated = jax.nn.leaky_relu(linear(x))

        # A fresh RNG key is requested on every call, whatever the rate is.
        return hk.dropout(hk.next_rng_key(), rate, activated)

In [None]:
# export
class MLP(hk.Module):
    """Stack of `DenseBlock`s applied one after another.

    Args:
        sizes: Output size of each `DenseBlock`, in application order.
        dropout_rate: Dropout rate shared by every block.
        name: Optional module name forwarded to `hk.Module`.
    """

    def __init__(
        self,
        sizes: List[int],
        dropout_rate: float = 0.3,
        name: Optional[str] = None,
    ):
        super().__init__(name=name)
        self.sizes = sizes
        self.dropout_rate = dropout_rate

    def __call__(
        self,
        x: jnp.ndarray,
        is_training: bool = True,
    ) -> jnp.ndarray:
        """Pass `x` through one `DenseBlock` per entry of `sizes`.

        Args:
            x: Input array.
            is_training: Forwarded to every `DenseBlock`.

        Returns:
            Output of the last block, or `x` unchanged when `sizes` is empty.
        """
        hidden = x
        for width in self.sizes:
            block = DenseBlock(width, self.dropout_rate)
            hidden = block(hidden, is_training)
        return hidden

In [None]:
# exporti
class PredictiveModelConfigs(BaseParser):
    """Configuration fields read by `PredictiveModel`.

    NOTE(review): `BaseParser` is defined outside this file; presumably it
    parses/validates these fields -- confirm against its definition.
    """
    # Forwarded to `MLP` as `sizes` (one `DenseBlock` per entry).
    sizes: List[int]
    # Forwarded to `MLP` as `dropout_rate`.
    dropout_rate: float = 0.3

In [None]:
# export
class PredictiveModel(hk.Module):
    """MLP followed by a single-unit linear layer and a sigmoid.

    Args:
        m_config: Model configuration, validated against
            `PredictiveModelConfigs` via `validate_configs`.
        name: Optional module name forwarded to `hk.Module`.
    """

    def __init__(self, m_config: Dict[str, Any], name: Optional[str] = None):
        super().__init__(name=name)
        self.configs = validate_configs(m_config, PredictiveModelConfigs)

    def __call__(self, x: jnp.ndarray, is_training: bool = True) -> jnp.ndarray:
        """Compute the sigmoid output for `x`.

        Args:
            x: Input array.
            is_training: Forwarded to the MLP (controls dropout).

        Returns:
            Sigmoid of the single-unit linear output.
        """
        cfg = self.configs
        backbone = MLP(sizes=cfg.sizes, dropout_rate=cfg.dropout_rate)
        hidden = backbone(x, is_training)
        logits = hk.Linear(1)(hidden)
        return sigmoid(logits)

tests

In [None]:
def model_forward(x: jnp.ndarray, is_training: bool = True) -> jnp.ndarray:
    """Forward function that builds and applies a `PredictiveModel`.

    Args:
        x: Input array.
        is_training: Forwarded to the model (controls dropout).

    Returns:
        The model's sigmoid output.
    """
    # `PredictiveModel.__init__` accepts a single config mapping (`m_config`);
    # passing sizes positionally plus a `dropout_rate` keyword raises TypeError.
    m_config = {"sizes": [10, 10], "dropout_rate": 0.3}
    return PredictiveModel(m_config)(x, is_training)

net = hk.transform(model_forward)

Transformed(init=<function without_state.<locals>.init_fn at 0x7f8b944037a0>, apply=<function without_state.<locals>.apply_fn at 0x7f8b94403830>)

In [None]:
#| exporti
class CounterNetModelConfigs(BaseParser):
    """Configuration fields read by `CounterNetModel`.

    NOTE(review): `BaseParser` is defined outside this file; presumably it
    parses/validates these fields -- confirm against its definition.
    """
    # Layer sizes of the "Encoder" MLP.
    enc_sizes: List[int]
    # Layer sizes of the "Predictor" MLP.
    dec_sizes: List[int]
    # Layer sizes of the "Explainer" MLP.
    exp_sizes: List[int]
    # Dropout rate shared by the three MLPs.
    dropout_rate: float = 0.3

In [None]:
#| export
class CounterNetModel(hk.Module):
    """Shared encoder feeding a predictor head and an explainer head.

    Args:
        m_config: Model configuration, validated against
            `CounterNetModelConfigs` via `validate_configs`.
        name: Optional module name forwarded to `hk.Module`.
    """

    def __init__(self,
                 m_config: Dict[str, Any],
                 name: Optional[str] = None):
        super().__init__(name=name)
        self.configs = validate_configs(m_config, CounterNetModelConfigs)

    # The return annotation is quoted so it does not need `Tuple` to be
    # resolvable when the class body is executed.
    def __call__(self,
                 x: jnp.ndarray,
                 is_training: bool = True) -> "Tuple[jnp.ndarray, jnp.ndarray]":
        """Run the encoder, predictor and explainer on `x`.

        Args:
            x: Input array; the size of its last axis fixes the explainer's
                output width.
            is_training: Forwarded to every MLP (controls dropout).

        Returns:
            A pair `(y_hat, cf)`: `y_hat` is the sigmoid of a single-unit
            linear layer on the predictor features, and `cf` is the explainer
            output whose last axis has the same size as that of `x`.
        """
        n_features = x.shape[-1]
        # encoder
        z = MLP(self.configs.enc_sizes, self.configs.dropout_rate, name="Encoder")(x, is_training)

        # prediction
        pred = MLP(self.configs.dec_sizes, self.configs.dropout_rate, name="Predictor")(z, is_training)
        # NOTE(review): this Linear reuses the name 'Predictor' already taken by
        # the MLP above; names are left as-is so parameter keys do not change.
        y_hat = hk.Linear(1, name='Predictor')(pred)
        y_hat = sigmoid(y_hat)

        # explain: the explainer sees both the latent code and the predictor features
        z_exp = jnp.concatenate((z, pred), axis=-1)
        cf = MLP(self.configs.exp_sizes, self.configs.dropout_rate, name="Explainer")(z_exp, is_training)
        # Project back to the input's feature size.
        cf = hk.Linear(n_features, name='Explainer')(cf)
        return y_hat, cf

In [None]:
def model_forward(
    x: jnp.ndarray, is_training: bool = True
) -> "Tuple[jnp.ndarray, jnp.ndarray]":
    """Forward function that builds and applies a `CounterNetModel`.

    Args:
        x: Input array.
        is_training: Forwarded to the model (controls dropout).

    Returns:
        The `(y_hat, cf)` pair returned by `CounterNetModel`.
    """
    m_config = {
        "enc_sizes": [10, 10],
        "dec_sizes": [10, 10],
        "exp_sizes": [10, 10],
        "dropout_rate": 0.3
    }
    # `CounterNetModel.__init__` takes the config as one mapping (`m_config`);
    # unpacking it with `**` passes unknown keyword arguments and raises TypeError.
    return CounterNetModel(m_config)(x, is_training)

net = hk.transform(model_forward)


In [None]:
# Smoke test: initialise and apply the transformed network on random inputs.
key = hk.PRNGSequence(42)

xs = random.normal(next(key), (1000, 10))

params = net.init(next(key), xs, is_training=True)
y = net.apply(params, next(key), xs, is_training=True)
# `jax.tree_map` is a deprecated alias that newer JAX releases no longer
# provide; `jax.tree_util.tree_map` is the equivalent stable entry point.
jax.tree_util.tree_map(lambda x: x.shape, params)


{'counter_net_model/Encoder/dense_block/linear': {'b': (10,), 'w': (10, 10)},
 'counter_net_model/Encoder/dense_block_1/linear': {'b': (10,), 'w': (10, 10)},
 'counter_net_model/Explainer/dense_block/linear': {'b': (10,), 'w': (20, 10)},
 'counter_net_model/Explainer/dense_block_1/linear': {'b': (10,),
  'w': (10, 10)},
 'counter_net_model/Predictor/dense_block/linear': {'b': (10,), 'w': (10, 10)},
 'counter_net_model/Predictor/dense_block_1/linear': {'b': (10,),
  'w': (10, 10)},
 'counter_net_model/linear': {'b': (1,), 'w': (10, 1)},
 'counter_net_model/linear_1': {'b': (10,), 'w': (10, 10)}}