# Neural Networks

> Defines reusable neural network modules built with `haiku`.

In [None]:
# default_exp nets

In [None]:
# hide
%load_ext autoreload
%autoreload 2
from ipynb_path import *

In [None]:
# export
from cfnet.import_essentials import *
from cfnet.utils import validate_configs, sigmoid

  PyTreeDef = type(jax.tree_structure(None))


In [None]:
# export
class DenseBlock(hk.Module):
    """Linear layer followed by leaky ReLU and training-time dropout.

    Args:
        output_size: Number of output units of the linear layer.
        dropout_rate: Dropout rate used when `is_training` is true.
        name: Optional Haiku module name.
    """

    def __init__(self,
                output_size: int,
                dropout_rate: float = 0.3,
                name: Optional[str] = None):
        super().__init__(name=name)
        self.output_size = output_size
        self.dropout_rate = dropout_rate

    def __call__(self,
                x: jnp.ndarray,
                is_training: bool = True) -> jnp.ndarray:
        """Apply linear -> leaky ReLU -> dropout (training only).

        Args:
            x: Input array; the linear layer acts on its last axis.
            is_training: Python bool; dropout is applied only when true.

        Returns:
            The transformed array whose last axis has `output_size` entries.
        """
        # he_uniform
        w_init = hk.initializers.VarianceScaling(2.0, 'fan_in', 'uniform')
        x = hk.Linear(self.output_size, w_init=w_init)(x)
        x = jax.nn.leaky_relu(x)
        # Only request an RNG key when dropout is really applied. Previously a
        # key was drawn even in evaluation mode (rate 0.0), so inference
        # needlessly required an RNG to be passed to `apply`.
        if is_training:
            x = hk.dropout(hk.next_rng_key(), self.dropout_rate, x)
        return x

In [None]:
# export
class MLP(hk.Module):
    """Stack of `DenseBlock` layers applied one after another.

    Args:
        sizes: Output width of each successive `DenseBlock`.
        dropout_rate: Dropout rate handed to every block.
        name: Optional Haiku module name.
    """

    def __init__(self,
                sizes: List[int],
                dropout_rate: float = 0.3,
                name: Optional[str] = None):
        super().__init__(name=name)
        self.sizes = sizes
        self.dropout_rate = dropout_rate

    def __call__(self,
                x: jnp.ndarray,
                is_training: bool = True) -> jnp.ndarray:
        """Feed `x` through one `DenseBlock` per entry of `sizes`.

        With an empty `sizes` the input is returned unchanged.
        """
        hidden = x
        for width in self.sizes:
            block = DenseBlock(width, self.dropout_rate)
            hidden = block(hidden, is_training)
        return hidden

In [None]:
# exporti
class PredictiveMLPConfigs(BaseParser):
    """Configuration schema that `PredictiveMLP` validates its `m_config` against."""
    # Layer widths forwarded to `MLP(sizes=...)` by `PredictiveMLP`.
    sizes: List[int]
    # Dropout rate forwarded to `MLP(dropout_rate=...)`; defaults to 0.3.
    dropout_rate: float = 0.3

In [None]:
# export
class PredictiveMLP(hk.Module):
    """MLP classifier head producing a single sigmoid output per example.

    Args:
        m_config: Mapping validated against `PredictiveMLPConfigs`
            (`sizes` and, optionally, `dropout_rate`).
        name: Optional Haiku module name.
    """

    def __init__(
        self,
        m_config: Dict[str, Any],
        name: Optional[str] = None
    ):
        super().__init__(name=name)
        self.configs = validate_configs(m_config, PredictiveMLPConfigs)

    def __call__(
        self,
        x: jnp.ndarray,
        is_training: bool = True
    ) -> jnp.ndarray:
        """Run the MLP backbone, project to one unit and squash with sigmoid."""
        cfg = self.configs
        backbone = MLP(sizes=cfg.sizes, dropout_rate=cfg.dropout_rate)
        hidden = backbone(x, is_training)
        logit = hk.Linear(1)(hidden)
        return sigmoid(logit)

In [None]:
# export
class PredictivConvNet(hk.Module):
    """Small convolutional classifier with a single sigmoid output.

    Two 3x3 'SAME' convolutions (32 then 64 channels) with leaky ReLU, followed
    by flatten -> Linear(256) -> leaky ReLU -> Linear(1) -> sigmoid.

    Args:
        name: Optional Haiku module name.
    """

    def __init__(
        self, name: Optional[str] = None
    ):
        super().__init__(name=name)

    def __call__(self, x: jnp.ndarray, is_training: bool = True):
        """Return the sigmoid prediction for a batch of images.

        Args:
            x: Either a rank-3 batch of single-channel images
                (batch, height, width) or a rank-4 array that already has a
                trailing channel axis.
            is_training: Accepted for interface parity with the other
                networks; this module does not use it.
        """
        # hk.Conv2D treats a rank-3 input as ONE unbatched (H, W, C) image, so
        # a (batch, H, W) array would be convolved across the batch axis with
        # the image width acting as channels. Add the channel axis explicitly,
        # the same way CounterNetConv does.
        if x.ndim == 3:
            x = jnp.expand_dims(x, axis=-1)
        x = hk.Sequential([
            hk.Conv2D(output_channels=32, kernel_shape=(3, 3), padding="SAME"),
            jax.nn.leaky_relu,
            hk.Conv2D(output_channels=64, kernel_shape=(3, 3), padding="SAME"),
            jax.nn.leaky_relu,
            hk.Flatten(),
            hk.Linear(256),
            jax.nn.leaky_relu,
            hk.Linear(1),
        ])(x)
        return sigmoid(x)


In [None]:
# hide
from cfnet.utils import make_model
from cfnet.datasets import MNISTDataModule

# Smoke test for PredictivConvNet: build, initialise, run once, inspect shapes.
# NOTE(review): `make_model` presumably wraps the module in a Haiku transform
# exposing `init`/`apply` — confirm in cfnet.utils.
net = make_model(None, PredictivConvNet)
# Deterministic stream of PRNG keys; each `next(key)` below yields a fresh key,
# so the order of the following statements determines the random values.
key = hk.PRNGSequence(42)

dm = MNISTDataModule({})

# Random stand-in batch: 1000 arrays of size 28x28.
xs = random.normal(next(key), (1000, 28, 28))

# Parameters are initialised from the data module's sample batch, then the
# network is applied to the random batch.
params = net.init(next(key), dm.get_sample_X(), is_training=True)
y = net.apply(params, next(key), xs, is_training=True)
# Last expression of the cell: displays the shape of every parameter array.
jax.tree_util.tree_map(lambda x: x.shape, params)


[autoreload of cfnet.datasets failed: Traceback (most recent call last):
  File "/home/birk/mambaforge-pypy3/envs/cfnet/lib/python3.7/site-packages/IPython/extensions/autoreload.py", line 245, in check
    superreload(m, reload, self.old_objects)
  File "/home/birk/mambaforge-pypy3/envs/cfnet/lib/python3.7/site-packages/IPython/extensions/autoreload.py", line 410, in superreload
    update_generic(old_obj, new_obj)
  File "/home/birk/mambaforge-pypy3/envs/cfnet/lib/python3.7/site-packages/IPython/extensions/autoreload.py", line 347, in update_generic
    update(a, b)
  File "/home/birk/mambaforge-pypy3/envs/cfnet/lib/python3.7/site-packages/IPython/extensions/autoreload.py", line 317, in update_class
    update_instances(old, new)
  File "/home/birk/mambaforge-pypy3/envs/cfnet/lib/python3.7/site-packages/IPython/extensions/autoreload.py", line 280, in update_instances
    ref.__class__ = new
  File "pydantic/main.py", line 358, in pydantic.main.BaseModel.__setattr__
ValueError: "MNISTD

train_X.shape: (13007, 28, 28); train_y.shape: (13007,) 
test_X.shape: (2163, 28, 28); test_y.shape: (2163,) 


  leaves, treedef = jax.tree_flatten(tree)
  return jax.tree_unflatten(treedef, leaves)


{'predictiv_conv_net/conv2_d': {'b': (32,), 'w': (3, 3, 28, 32)},
 'predictiv_conv_net/conv2_d_1': {'b': (64,), 'w': (3, 3, 32, 64)},
 'predictiv_conv_net/linear': {'b': (256,), 'w': (1792, 256)},
 'predictiv_conv_net/linear_1': {'b': (1,), 'w': (256, 1)}}

In [None]:
# Display the sample batch returned by the data module (the same call is used
# to initialise the network parameters).
dm.get_sample_X()

array([[[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]],

       [[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]],

       [[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]],

       ...,

       [[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]],

       [[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 

In [None]:
# exporti
class CounterNetMLPConfigs(BaseParser):
    """Configuration schema that `CounterNetMLP` validates its `m_config` against."""
    # Layer widths of the "Encoder" MLP.
    enc_sizes: List[int]
    # Layer widths of the "Predictor" MLP (applied to the encoder output).
    dec_sizes: List[int]
    # Layer widths of the "Explainer" MLP.
    exp_sizes: List[int]
    # Dropout rate shared by all three MLPs; defaults to 0.3.
    dropout_rate: float = 0.3

In [None]:
# export
class CounterNetMLP(hk.Module):
    """Joint predictor / counterfactual generator built from MLPs.

    A shared encoder feeds a predictor branch (sigmoid output) and an
    explainer branch whose output has as many features as the input.

    Args:
        m_config: Mapping validated against `CounterNetMLPConfigs`.
        name: Optional Haiku module name.
    """

    def __init__(self,
                m_config: Dict[str, Any],
                name: Optional[str] = None):
        super().__init__(name=name)
        self.configs = validate_configs(m_config, CounterNetMLPConfigs)

    def __call__(self,
                x: jnp.ndarray,
                is_training: bool = True) -> jnp.ndarray:
        """Compute the prediction and the counterfactual for `x`.

        NOTE: despite the return annotation, this returns the pair
        `(y_hat, cf)`: the sigmoid prediction and an array whose last axis
        matches the last axis of `x`.
        """
        cfg = self.configs
        n_features = x.shape[-1]

        # Shared encoder.
        latent = MLP(cfg.enc_sizes, cfg.dropout_rate, name="Encoder")(x, is_training)

        # Predictor branch: hidden representation, then a one-unit sigmoid head.
        pred_hidden = MLP(cfg.dec_sizes, cfg.dropout_rate, name="Predictor")(latent, is_training)
        y_hat = sigmoid(hk.Linear(1, name='Predictor')(pred_hidden))

        # Explainer branch sees both the latent code and the predictor's hidden
        # representation, and maps back to the input feature dimension.
        explainer_in = jnp.concatenate((latent, pred_hidden), axis=-1)
        cf_hidden = MLP(cfg.exp_sizes, cfg.dropout_rate, name="Explainer")(explainer_in, is_training)
        cf = hk.Linear(n_features, name='Explainer')(cf_hidden)
        return y_hat, cf

In [None]:
# exporti
class ConvExplainer(hk.Module):
    """Decoder that maps a flat vector back to a single-channel feature map.

    Args:
        flaten_shape: Shape of the flattened encoder output; only its last
            entry (the flat feature count) is used.
        z_shape_3d: Shape the projected vector is reshaped to before the
            transposed convolutions.
        name: Optional Haiku module name.
    """

    def __init__(
        self, 
        flaten_shape, 
        z_shape_3d,
        name: Optional[str] = None
    ):
        super().__init__(name=name)
        self.z_shape_3d = z_shape_3d
        self.flatten_shape = flaten_shape

    def __call__(self, x): 
        """Project, reshape and decode `x`; the output is squashed with tanh."""
        flat_width = self.flatten_shape[-1]
        projected = hk.Linear(flat_width, name="Explainer")(x)
        feature_map = jax.nn.leaky_relu(projected.reshape(self.z_shape_3d))
        hidden = hk.Conv2DTranspose(output_channels=4, kernel_shape=(3, 3), padding='SAME')(feature_map)
        hidden = jax.nn.leaky_relu(hidden)
        decoded = hk.Conv2DTranspose(output_channels=1, kernel_shape=(3, 3), padding='SAME')(hidden)
        return jnp.tanh(decoded)

In [None]:
# export
class CounterNetConv(hk.Module):
    """Convolutional predictor / counterfactual generator.

    A convolutional encoder feeds a predictor branch (sigmoid output) and a
    `ConvExplainer` branch that decodes a counterfactual image.

    Args:
        name: Optional Haiku module name.
    """

    def __init__(
        self,
        name: Optional[str] = None
    ):
        super().__init__(name=name)

    def __call__(self,
                x: jnp.ndarray,
                is_training: bool = True) -> jnp.ndarray:
        """Compute the prediction and the counterfactual image for `x`.

        Args:
            x: Images without a channel axis; a trailing channel axis of size
                one is appended before the convolutions.
            is_training: Accepted for interface parity with the other
                networks; this module does not use it.

        Returns:
            NOTE: despite the annotation, the pair `(y_hat, cf)` where `cf`
            has the explainer's channel axis removed.
        """
        x = jnp.expand_dims(x, axis=-1)
        # encoder
        z = hk.Sequential([
            hk.Conv2D(output_channels=4, kernel_shape=(3, 3), padding="SAME"),
            jax.nn.leaky_relu, 
            hk.Conv2D(output_channels=16, kernel_shape=(3, 3), padding="SAME"),
            jax.nn.leaky_relu,
        ], name='Encoder')(x)
        # Remember both shapes so the explainer can undo the flattening.
        z_shape_3d = z.shape
        z = hk.Flatten()(z)
        z_shape_flattened = z.shape

        # prediction
        pred = hk.Sequential([
            hk.Linear(50),
            jax.nn.leaky_relu, 
        ], name='Predictor')(z)
        y_hat = hk.Linear(1, name='Predictor')(pred)
        y_hat = sigmoid(y_hat)

        # explain
        z_exp = jnp.concatenate((z, pred), axis=-1)
        cf = ConvExplainer(z_shape_flattened, z_shape_3d, name='Explainer')(z_exp)
        # Drop only the single-channel axis produced by the explainer. A bare
        # `squeeze()` would also remove the batch axis for a batch of one.
        return y_hat, cf.squeeze(axis=-1)

In [None]:
# hide
# Minimal CounterNetMLP configuration: two 10-unit layers per branch.
m_configs = dict(
    enc_sizes=[10, 10],
    dec_sizes=[10, 10],
    exp_sizes=[10, 10],
    dropout_rate=0.3,
)

In [None]:
from cfnet.utils import make_model

# Smoke test for CounterNetMLP: build, initialise, run once, inspect shapes.
net = make_model(m_configs, CounterNetMLP)
# Each `next(key)` yields a fresh PRNG key, so statement order matters here.
key = hk.PRNGSequence(42)

# Random batch of 1000 examples with 10 features each.
xs = random.normal(next(key), (1000, 10))

params = net.init(next(key), xs, is_training=True)
y = net.apply(params, next(key), xs, is_training=True)
# `jax.tree_map` is deprecated; use the `jax.tree_util` spelling, as the
# PredictivConvNet cell already does. Displays every parameter's shape.
jax.tree_util.tree_map(lambda x: x.shape, params)


{'counter_net_mlp/Encoder/dense_block/linear': {'b': (10,), 'w': (10, 10)},
 'counter_net_mlp/Encoder/dense_block_1/linear': {'b': (10,), 'w': (10, 10)},
 'counter_net_mlp/Explainer/dense_block/linear': {'b': (10,), 'w': (20, 10)},
 'counter_net_mlp/Explainer/dense_block_1/linear': {'b': (10,), 'w': (10, 10)},
 'counter_net_mlp/Explainer_1': {'b': (10,), 'w': (10, 10)},
 'counter_net_mlp/Predictor/dense_block/linear': {'b': (10,), 'w': (10, 10)},
 'counter_net_mlp/Predictor/dense_block_1/linear': {'b': (10,), 'w': (10, 10)},
 'counter_net_mlp/Predictor_1': {'b': (1,), 'w': (10, 1)}}