# Module

> Modules used for defining model architecture and training procedure, which are passed to `train_model`.

In [None]:
#| default_exp module

In [None]:
#| include: false
%load_ext autoreload
%autoreload 2
from ipynb_path import *
from nbdev import show_doc

In [None]:
#| export
from __future__ import annotations
from cfnet.import_essentials import *
from cfnet.nets import PredictiveModel, CounterNetModel
from cfnet.interfaces import BaseCFExplanationModule
from cfnet.datasets import TabularDataModule
from cfnet.logger import TensorboardLogger
from cfnet.utils import (
    validate_configs,
    sigmoid,
    cat_normalize,
    accuracy,
    proximity,
    make_model,
    init_net_opt,
    grad_update,
)
from fastcore.basics import patch
from functools import partial
from abc import ABC, abstractmethod
from copy import deepcopy

## Networks

Networks are [haiku.Module](https://dm-haiku.readthedocs.io/en/latest/api.html#common-modules) subclasses, 
which define model architectures.

In [None]:
#| export
class BaseNetwork(ABC):
    """Interface for networks whose call takes a keyword-only `is_training` flag.

    Subclasses provide the actual forward computation in `__call__`; the base
    implementation is a no-op that returns `None`.
    """

    def __call__(self, *, is_training: bool):
        # Placeholder body meant to be overridden by concrete networks.
        return None


In [None]:
#| export
#| hide
class DenseBlock(hk.Module):
    def __init__(
        self,
        output_size: int,  # Output dimensionality.
        dropout_rate: float = 0.3,  # Dropout rate.
        name: str | None = None,  # Name of the Module
    ):
        """A `DenseBlock` is a dense layer followed by a Leaky ReLU and a dropout layer."""
        super().__init__(name=name)
        self.output_size = output_size
        self.dropout_rate = dropout_rate

    def __call__(self, x: jnp.ndarray, is_training: bool = True) -> jnp.ndarray:
        # Dropout only perturbs activations during training; evaluation uses rate 0.
        rate = self.dropout_rate if is_training else 0.0
        # He-uniform initialisation (variance scale 2.0, fan-in, uniform).
        he_uniform = hk.initializers.VarianceScaling(2.0, "fan_in", "uniform")
        hidden = jax.nn.leaky_relu(hk.Linear(self.output_size, w_init=he_uniform)(x))
        # An RNG key is drawn on every call, including evaluation (rate 0).
        return hk.dropout(hk.next_rng_key(), rate, hidden)


In [None]:
# Render the nbdev API documentation for the `DenseBlock` constructor.
show_doc(DenseBlock.__init__)

---

[source](https://github.com/birkhoffg/cfnet/tree/master/blob/master/cfnet/nets.py#L14){target="_blank" style="float:right; font-size:smaller"}

### DenseBlock.__init__

>      DenseBlock.__init__ (output_size:int, dropout_rate:float=0.3,
>                           name:Union[str,NoneType]=None)

A `DenseBlock` consists of a dense layer, followed by Leaky Relu and a dropout layer.

|    | **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| output_size | int |  | Output dimensionality. |
| dropout_rate | float | 0.3 | Dropout rate. |
| name | str \| None | None | Name of the Module |

In [None]:
#| export
#| hide
class MLP(hk.Module):
    def __init__(
        self,
        sizes: Iterable[int],  # Sequence of layer sizes.
        dropout_rate: float = 0.3,  # Dropout rate.
        name: str | None = None,  # Name of the Module
    ):
        """A `MLP` consists of a list of `DenseBlock` layers."""
        super().__init__(name=name)
        # Materialise `sizes` so that a one-shot iterable (e.g. a generator) is
        # not exhausted by the first `__call__`, which would silently build an
        # MLP with no layers on any later call.
        self.sizes = list(sizes)
        self.dropout_rate = dropout_rate

    def __call__(self, x: jnp.ndarray, is_training: bool = True) -> jnp.ndarray:
        """Apply one `DenseBlock` per entry of `sizes`, in order."""
        for size in self.sizes:
            x = DenseBlock(size, self.dropout_rate)(x, is_training)
        return x


In [None]:
# Render the nbdev API documentation for the `MLP` constructor.
show_doc(MLP.__init__)

---

[source](https://github.com/birkhoffg/cfnet/tree/master/blob/master/cfnet/nets.py#L37){target="_blank" style="float:right; font-size:smaller"}

### MLP.__init__

>      MLP.__init__ (sizes:Iterable[int], dropout_rate:float=0.3,
>                    name:Union[str,NoneType]=None)

A `MLP` consists of a list of `DenseBlock` layers.

|    | **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| sizes | Iterable[int] |  | Sequence of layer sizes. |
| dropout_rate | float | 0.3 | Dropout rate. |
| name | str \| None | None | Name of the Module |

## Predictive Model

In [None]:
#| exporti
class PredictiveModelConfigs(BaseParser):
    """Configurator of `PredictiveModel`."""

    # Hidden-layer widths of the MLP trunk, one `DenseBlock` per entry.
    sizes: List[int]  # Sequence of layer sizes.
    dropout_rate: float = 0.3  # Dropout rate.


In [None]:
#| export
#| hide
class PredictiveModel(hk.Module):
    def __init__(
        self,
        m_config: Dict[
            str, Any
        ],  # Model configs which contain configs in `PredictiveModelConfigs`.
        name: Optional[str] = None,  # Name of the module.
    ):
        """A basic predictive model for binary classification."""
        super().__init__(name=name)
        self.configs = validate_configs(m_config, PredictiveModelConfigs)

    def __call__(self, x: jnp.ndarray, is_training: bool = True) -> jnp.ndarray:
        # Hidden layers: an MLP of `DenseBlock`s sized by `configs.sizes`.
        cfg = self.configs
        hidden = MLP(sizes=cfg.sizes, dropout_rate=cfg.dropout_rate)(x, is_training)
        # Single-unit output head passed through `sigmoid`.
        logit = hk.Linear(1)(hidden)
        return sigmoid(logit)


In [None]:
# Render the nbdev API documentation for `PredictiveModelConfigs`.
show_doc(PredictiveModelConfigs)

---

[source](https://github.com/birkhoffg/cfnet/tree/master/blob/master/cfnet/nets.py#L53){target="_blank" style="float:right; font-size:smaller"}

### PredictiveModelConfigs

>      PredictiveModelConfigs (sizes:List[int], dropout_rate:float=0.3)

Configurator of `PredictiveModel`.

In [None]:
# Render the nbdev API documentation for the `PredictiveModel` constructor.
show_doc(PredictiveModel.__init__)

---

[source](https://github.com/birkhoffg/cfnet/tree/master/blob/master/cfnet/nets.py#L63){target="_blank" style="float:right; font-size:smaller"}

### PredictiveModel.__init__

>      PredictiveModel.__init__ (m_config:Dict[str,Any],
>                                name:Union[str,NoneType]=None)

A basic predictive model for binary classification.

|    | **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| m_config | Dict[str, Any] |  | Model configs which contain configs in `PredictiveModelConfigs`. |
| name | Optional[str] | None | Name of the module. |

Specify model configurations (via `dict`).

In [None]:
m_configs = dict(
    sizes=[50, 20, 10],  # hidden-layer widths
    dropout_rate=0.3,  # optional; `PredictiveModelConfigs` defaults it to 0.3
)

Use `make_model` to create a `haiku.Transformed` model.

In [None]:
from cfnet.utils import make_model

In [None]:
# Wrap `PredictiveModel` into a `haiku.Transformed` (an `init`/`apply` pair).
net = make_model(m_configs, PredictiveModel)

We generate some random input data.

In [None]:
# Seeded key sequence; each `next(key)` yields a fresh PRNG key.
key = hk.PRNGSequence(42)
# 1000 random samples with 10 features each.
xs = random.normal(next(key), (1000, 10))

We can then initialize the model parameters.

In [None]:
# Initialise the parameters by tracing the model on the sample batch `xs`.
params = net.init(next(key), xs, is_training=True)

We can view the model's structure by mapping `x.shape` over its parameter tree.

In [None]:
# Map every parameter array to its shape to inspect the architecture.
# (`jax.tree_map` is deprecated/removed in recent JAX; use `jax.tree_util`.)
jax.tree_util.tree_map(lambda x: x.shape, params)

{'predictive_model/linear': {'b': (1,), 'w': (10, 1)},
 'predictive_model/mlp/dense_block/linear': {'b': (50,), 'w': (10, 50)},
 'predictive_model/mlp/dense_block_1/linear': {'b': (20,), 'w': (50, 20)},
 'predictive_model/mlp/dense_block_2/linear': {'b': (10,), 'w': (20, 10)}}

Model outputs are produced via the `apply` function.

In [None]:
# Forward pass; `apply` takes an RNG key because every `DenseBlock` draws one for dropout.
y = net.apply(params, next(key), xs, is_training=True)

For more on using `haiku.Module`, please refer to 
[Haiku documentation](https://dm-haiku.readthedocs.io/en/latest/api.html#haiku-fundamentals).

## CounterNet Model

In [None]:
#| exporti
class CounterNetModelConfigs(BaseParser):
    """Configurator of `CounterNetModel`."""

    enc_sizes: List[int]  # Layer sizes of the encoder MLP.
    dec_sizes: List[int]  # Layer sizes of the predictor MLP.
    exp_sizes: List[int]  # Layer sizes of the explainer MLP.
    dropout_rate: float = 0.3  # Dropout rate shared by all three MLPs.


In [None]:
#| export
#| hide
class CounterNetModel(hk.Module):
    def __init__(
        self,
        m_config: Dict[
            str, Any
        ],  # Model configs which contain configs in `CounterNetModelConfigs`.
        name: Optional[str] = None,  # Name of the module.
    ):
        """CounterNet model architecture."""
        super().__init__(name=name)
        self.configs = validate_configs(m_config, CounterNetModelConfigs)

    def __call__(
        self, x: jnp.ndarray, is_training: bool = True
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """Return `(y_hat, cf)`.

        `y_hat` is the sigmoid output of a single-unit predictor head. `cf` is the
        explainer output, whose last dimension equals that of `x`.

        NOTE: the module names "Encoder", "Predictor" and "Explainer" fix the
        parameter paths (e.g. `counter_net_model/Explainer`), which the training
        module matches on to select explainer parameters; do not rename them.
        """
        n_features = x.shape[-1]
        # encoder
        z = MLP(self.configs.enc_sizes, self.configs.dropout_rate, name="Encoder")(
            x, is_training
        )

        # prediction
        pred = MLP(self.configs.dec_sizes, self.configs.dropout_rate, name="Predictor")(
            z, is_training
        )
        # Reusing the name "Predictor" makes haiku name this layer "Predictor_1".
        y_hat = hk.Linear(1, name="Predictor")(pred)
        y_hat = sigmoid(y_hat)

        # explain: condition on both the latent code and the predictor features
        z_exp = jnp.concatenate((z, pred), axis=-1)
        cf = MLP(self.configs.exp_sizes, self.configs.dropout_rate, name="Explainer")(
            z_exp, is_training
        )
        # Project back to input dimensionality (haiku names this "Explainer_1").
        cf = hk.Linear(n_features, name="Explainer")(cf)
        return y_hat, cf


In [None]:
# Render the nbdev API documentation for `CounterNetModelConfigs`.
show_doc(CounterNetModelConfigs)

---

[source](https://github.com/birkhoffg/cfnet/tree/master/blob/master/cfnet/nets.py#L82){target="_blank" style="float:right; font-size:smaller"}

### CounterNetModelConfigs

>      CounterNetModelConfigs (enc_sizes:List[int], dec_sizes:List[int],
>                              exp_sizes:List[int], dropout_rate:float=0.3)

Create a new model by parsing and validating input data from keyword arguments.

Raises ValidationError if the input data cannot be parsed to form a valid model.

In [None]:
# Render the nbdev API documentation for the `CounterNetModel` constructor.
show_doc(CounterNetModel.__init__)

---

[source](https://github.com/birkhoffg/cfnet/tree/master/blob/master/cfnet/nets.py#L95){target="_blank" style="float:right; font-size:smaller"}

### CounterNetModel.__init__

>      CounterNetModel.__init__ (m_config:Dict[str,Any],
>                                name:Union[str,NoneType]=None)

CounterNet model architecture.

|    | **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| m_config | Dict[str, Any] |  | Model configs which contain configs in `CounterNetModelConfigs`. |
| name | Optional[str] | None | Name of the module. |

## Training Modules API

In [None]:
#| hide
# Hidden, non-exported cell: a placeholder `BaseTrainingModule` plus patched
# `logger`/`log` members, presumably so nbdev can document them. The exported
# `BaseTrainingModule` defined in the following cell replaces this class.
class BaseTrainingModule(ABC):
    pass

@patch(as_prop=True)
def logger(
    self:BaseTrainingModule
) -> TensorboardLogger | None:
    """A logger property"""
    pass

@patch
def log(self:BaseTrainingModule, 
        name: str, # Name of the log
        value: Any # Value to log
    ) -> None:
    """Log a single `value` under `name`."""
    pass

In [None]:
#| export
class BaseTrainingModule(ABC):
    """Base class for the training modules passed to `train_model`.

    Subclasses implement `init_net_opt`, `training_step` and `validation_step`;
    this class provides hyperparameter storage and logging helpers.
    """

    hparams: Dict[str, Any]
    # Default to `None` so that logging before `init_logger` raises the intended
    # `ValueError` in `log_dict` instead of an `AttributeError`.
    logger: TensorboardLogger | None = None

    def save_hyperparameters(self, configs: Dict[str, Any]) -> Dict[str, Any]:
        """Store a deep copy of `configs` as `self.hparams` and return it."""
        self.hparams = deepcopy(configs)
        return self.hparams

    def init_logger(self, logger: TensorboardLogger):
        """Attach the logger used by `log` / `log_dict`."""
        self.logger = logger

    def log(self, name: str, value: Any):
        """Log a single `name -> value` entry (shorthand for `log_dict`)."""
        self.log_dict({name: value})

    def log_dict(self, dictionary: Dict[str, Any]):
        """Forward `dictionary` to the attached logger's `log_dict`.

        Raises:
            ValueError: if no logger has been attached via `init_logger`.
        """
        if self.logger:
            self.logger.log_dict(dictionary)
        else:
            raise ValueError("Logger has not been initliazed.")

    @abstractmethod
    def init_net_opt(
        self, data_module: TabularDataModule, key: random.PRNGKey
    ) -> Tuple[hk.Params, optax.OptState]:
        """Initialise network parameters and optimizer state."""
        pass

    @abstractmethod
    def training_step(
        self,
        params: hk.Params,
        opt_state: optax.OptState,
        rng_key: random.PRNGKey,
        batch: Tuple[jnp.array, jnp.array],
    ) -> Tuple[hk.Params, optax.OptState]:
        """Run one optimisation step on `batch`; return updated params and state."""
        pass

    @abstractmethod
    def validation_step(
        self,
        params: hk.Params,
        rng_key: random.PRNGKey,
        batch: Tuple[jnp.array, jnp.array],
    ) -> Dict[str, Any]:
        """Evaluate on `batch` and return a dict of metrics."""
        pass


## Predictive Training Module

In [None]:
#| export
class PredictiveTrainingModuleConfigs(BaseParser):
    """Configurator of `PredictiveTrainingModule`."""
    lr: float  # Learning rate of the Adam optimizer.


In [None]:
#| export
class PredictiveTrainingModule(BaseTrainingModule):
    """Trains a `PredictiveModel` with an L2 loss and the Adam optimizer."""

    def __init__(self, m_configs: Dict[str, Any]):
        self.save_hyperparameters(m_configs)
        self.net = make_model(m_configs, PredictiveModel)
        self.configs = validate_configs(m_configs, PredictiveTrainingModuleConfigs)
        self.opt = optax.adam(learning_rate=self.configs.lr)

    @partial(jax.jit, static_argnames=["self", "is_training"])
    def forward(self, params, rng_key, x, is_training: bool = True):
        """Jitted forward pass of the predictive network."""
        return self.net.apply(params, rng_key, x, is_training=is_training)

    def init_net_opt(self, data_module, key):
        """Initialise params and Adam state from a sample of `data_module`."""
        params, opt_state = init_net_opt(
            self.net, self.opt, X=data_module.get_sample_X(), key=key
        )
        return params, opt_state

    def loss_fn(self, params, rng_key, batch, is_training: bool = True):
        """Mean of `optax.l2_loss` between predictions and labels `y`."""
        x, y = batch
        y_pred = self.net.apply(params, rng_key, x, is_training=is_training)
        return jnp.mean(vmap(optax.l2_loss)(y_pred, y))

    @partial(jax.jit, static_argnames=["self"])
    def _training_step(self, params, opt_state, rng_key, batch):
        """One jitted gradient step on `loss_fn`."""
        grads = jax.grad(self.loss_fn)(params, rng_key, batch)
        upt_params, opt_state = grad_update(grads, params, opt_state, self.opt)
        return upt_params, opt_state

    def training_step(self, params, opt_state, rng_key, batch):
        """Update params on `batch` and log the training loss."""
        params, opt_state = self._training_step(params, opt_state, rng_key, batch)

        # Logged loss uses the updated params with `is_training=True` (dropout on).
        loss = self.loss_fn(params, rng_key, batch)
        self.log_dict({"train/train_loss_1": loss.item()})
        return params, opt_state

    def validation_step(self, params, rng_key, batch):
        """Compute, log and return validation loss and accuracy."""
        x, y = batch
        y_pred = self.net.apply(params, rng_key, x, is_training=False)
        loss = self.loss_fn(params, rng_key, batch, is_training=False)
        logs = {"val/val_loss": loss.item(), "val/val_accuracy": accuracy(y, y_pred)}
        self.log_dict(logs)
        # Return the metrics, as `BaseTrainingModule.validation_step` declares
        # and as `CounterNetTrainingModule.validation_step` does.
        return logs


## CounterNet Training Module

In [None]:
#| export 
def partition_trainable_params(params: hk.Params, trainable_name: str):
    """Split `params` into `(trainable, non_trainable)` by module name.

    A module counts as trainable when `trainable_name` occurs anywhere in its
    module path (plain substring test). For example,
    "counter_net_model/Explainer" also selects "counter_net_model/Explainer_1".
    """

    def _is_trainable(module_name, param_name, value):
        return trainable_name in module_name

    selected, rest = hk.data_structures.partition(_is_trainable, params)
    return selected, rest


In [None]:
#| export
class CounterNetTrainingModuleConfigs(BaseParser):
    """Configurator of `CounterNetTrainingModule`."""
    lr: float  # Learning rate shared by both Adam optimizers.
    lambda_1: float  # Weight of the prediction loss (`loss_fn_1`).
    lambda_2: float  # Weight of the counterfactual validity loss (`loss_fn_2`).
    lambda_3: float  # Weight of the proximity loss between x and cf (`loss_fn_3`).
    use_immutable: bool = True  # Copy immutable features of x into the counterfactuals.


In [None]:
#| export
def project_immutable_features(x, cf: jnp.ndarray, imutable_idx_list: List[int]):
    """Overwrite the immutable feature columns of `cf` with those of `x`.

    Features are indexed on the last axis, so both a batch `(n, d)` and a single
    instance `(d,)` are supported. A new array is returned; inputs are unchanged.

    Args:
        x: original inputs (assumed to have the same shape as `cf`).
        cf: counterfactuals to project.
        imutable_idx_list: indices of features that must not change.
    """
    if len(imutable_idx_list) == 0:
        # Nothing to project; also avoids indexing with an empty (non-integer) list.
        return cf
    return cf.at[..., imutable_idx_list].set(x[..., imutable_idx_list])


class CounterNetTrainingModule(BaseTrainingModule, BaseCFExplanationModule):
    """Trains CounterNet, which jointly learns a predictor and a CF explainer.

    Each training step performs two updates:
    1. `opt_1` minimises `lambda_1 * loss_fn_1` (prediction loss) over all params.
    2. `opt_2` minimises `lambda_2 * loss_fn_2 + lambda_3 * loss_fn_3`
       (validity + proximity) over the explainer params only.
    """

    name = "CounterNet"
    # Parameters whose module path contains this string are updated by `opt_2`.
    _EXPLAINER_NAME = "counter_net_model/Explainer"

    def __init__(self, m_configs: Dict[str, Any]):
        self.save_hyperparameters(m_configs)
        self.net = make_model(m_configs, CounterNetModel)
        self.configs = validate_configs(m_configs, CounterNetTrainingModuleConfigs)
        # Two independent Adam optimizers (same lr): one for the prediction
        # update and one for the explainer-only update.
        self.opt_1 = optax.adam(learning_rate=self.configs.lr)
        self.opt_2 = optax.adam(learning_rate=self.configs.lr)

    def init_net_opt(self, data_module, key):
        """Initialise params and both optimizer states.

        Returns `(params, (opt_1_state, opt_2_state))`; `opt_2_state` covers only
        the explainer parameters.
        """
        # NOTE(review): `update_cat_info` is inherited from `BaseCFExplanationModule`;
        # presumably it sets `cat_arrays`, `cat_idx` and `imutable_idx_list`, which
        # `forward` / `generate_cfs` read — confirm.
        self.update_cat_info(data_module)
        # manually init multiple opts
        params, opt_1_state = init_net_opt(
            self.net, self.opt_1, X=data_module.get_sample_X(), key=key
        )
        trainable_params, _ = partition_trainable_params(
            params, trainable_name=self._EXPLAINER_NAME
        )
        opt_2_state = self.opt_2.init(trainable_params)
        return params, (opt_1_state, opt_2_state)

    @partial(jax.jit, static_argnames=["self", "is_training"])
    def forward(self, params, rng_key, x, is_training: bool = True):
        """Return `(y_pred, cf, cf_y)`.

        `y_pred` is the prediction for `x`, `cf` is the normalized (and optionally
        immutable-projected) counterfactual, and `cf_y` is the prediction for `cf`.
        """
        # first forward to get y_pred and the raw cf
        y_pred, cf = self.net.apply(params, rng_key, x, is_training=is_training)
        # categorical normalization; `hard=True` only at inference
        cf = cat_normalize(cf, self.cat_arrays, self.cat_idx, hard=not is_training)
        # project immutable features
        if self.configs.use_immutable:
            cf = project_immutable_features(x, cf, self.imutable_idx_list)
        # second forward to calculate cf_y
        cf_y, _ = self.net.apply(params, rng_key, cf, is_training=is_training)
        return y_pred, cf, cf_y

    def predict(self, params, rng_key, x):
        """Return the predictor output for `x` in evaluation mode."""
        y_pred, _ = self.net.apply(params, rng_key, x, is_training=False)
        return y_pred

    def generate_cfs(self, X: chex.ArrayBatched, params, rng_key) -> chex.ArrayBatched:
        """Return counterfactuals for `X` (hard categorical normalization)."""
        _, cfs = self.net.apply(params, rng_key, X, is_training=False)
        cfs = cat_normalize(cfs, self.cat_arrays, self.cat_idx, hard=True)
        if self.configs.use_immutable:
            cfs = project_immutable_features(X, cfs, self.imutable_idx_list)
        return cfs

    def loss_fn_1(self, y_pred, y):
        """Prediction loss: mean `optax.l2_loss` between `y_pred` and labels `y`."""
        return jnp.mean(vmap(optax.l2_loss)(y_pred, y))

    def loss_fn_2(self, cf_y, y_prime):
        """Validity loss: mean `optax.l2_loss` between `cf_y` and flipped labels."""
        return jnp.mean(vmap(optax.l2_loss)(cf_y, y_prime))

    def loss_fn_3(self, x, cf):
        """Proximity loss: mean `optax.l2_loss` between `x` and `cf`."""
        return jnp.mean(vmap(optax.l2_loss)(x, cf))

    def pred_loss_fn(self, params, rng_key, batch, is_training: bool = True):
        """Weighted prediction loss `lambda_1 * loss_fn_1`."""
        x, y = batch
        y_pred, _ = self.net.apply(params, rng_key, x, is_training=is_training)
        return self.configs.lambda_1 * self.loss_fn_1(y_pred, y)

    def exp_loss_fn(
        self,
        trainable_params,
        non_trainable_params,
        rng_key,
        batch,
        is_training: bool = True,
    ):
        """Weighted explainer loss `lambda_2 * validity + lambda_3 * proximity`.

        Takes the param split as two arguments so gradients flow only to
        `trainable_params`.
        """
        # merge trainable and non_trainable params
        params = hk.data_structures.merge(trainable_params, non_trainable_params)
        x, _ = batch
        y_pred, cf, cf_y = self.forward(params, rng_key, x, is_training=is_training)
        # target for the counterfactual: the opposite of the rounded prediction
        y_prime = 1 - jnp.round(y_pred)
        loss_2, loss_3 = self.loss_fn_2(cf_y, y_prime), self.loss_fn_3(x, cf)
        return self.configs.lambda_2 * loss_2 + self.configs.lambda_3 * loss_3

    def _predictor_step(self, params, opt_state, rng_key, batch):
        """Apply one `opt_1` update of `pred_loss_fn` over all params."""
        grads = jax.grad(self.pred_loss_fn)(params, rng_key, batch)
        upt_params, opt_state = grad_update(grads, params, opt_state, self.opt_1)
        return upt_params, opt_state

    def _explainer_step(self, params, opt_state, rng_key, batch):
        """Apply one `opt_2` update of `exp_loss_fn` over the explainer params only."""
        trainable_params, non_trainable_params = partition_trainable_params(
            params, trainable_name=self._EXPLAINER_NAME
        )
        grads = jax.grad(self.exp_loss_fn)(
            trainable_params, non_trainable_params, rng_key, batch
        )
        upt_trainable_params, opt_state = grad_update(
            grads, trainable_params, opt_state, self.opt_2
        )
        upt_params = hk.data_structures.merge(
            upt_trainable_params, non_trainable_params
        )
        return upt_params, opt_state

    @partial(jax.jit, static_argnames=["self"])
    def _training_step(
        self,
        params: hk.Params,
        opts_state: Tuple[optax.OptState, optax.OptState],
        rng_key: random.PRNGKey,
        batch: Tuple[jnp.array, jnp.array],
    ):
        """Jitted predictor update followed by explainer update."""
        opt_1_state, opt_2_state = opts_state
        params, opt_1_state = self._predictor_step(params, opt_1_state, rng_key, batch)
        upt_params, opt_2_state = self._explainer_step(
            params, opt_2_state, rng_key, batch
        )
        return upt_params, (opt_1_state, opt_2_state)

    def _training_step_logs(self, params, rng_key, batch):
        """Compute the three unweighted losses in eval mode for logging."""
        x, y = batch
        y_pred, cf, cf_y = self.forward(params, rng_key, x, is_training=False)
        y_prime = 1 - jnp.round(y_pred)

        loss_1, loss_2, loss_3 = (
            self.loss_fn_1(y_pred, y),
            self.loss_fn_2(cf_y, y_prime),
            self.loss_fn_3(x, cf),
        )
        logs = {
            "train/train_loss_1": loss_1.item(),
            "train/train_loss_2": loss_2.item(),
            "train/train_loss_3": loss_3.item(),
        }
        return logs

    def training_step(
        self,
        params: hk.Params,
        opts_state: Tuple[optax.OptState, optax.OptState],
        rng_key: random.PRNGKey,
        batch: Tuple[jnp.array, jnp.array],
    ) -> Tuple[hk.Params, Tuple[optax.OptState, optax.OptState]]:
        """Run one training step on `batch` and log the per-term losses."""
        upt_params, (opt_1_state, opt_2_state) = self._training_step(
            params, opts_state, rng_key, batch
        )

        logs = self._training_step_logs(upt_params, rng_key, batch)
        self.log_dict(logs)
        return upt_params, (opt_1_state, opt_2_state)

    def validation_step(self, params, rng_key, batch):
        """Compute, log and return validation metrics and losses."""
        x, y = batch
        y_pred, cf, cf_y = self.forward(params, rng_key, x, is_training=False)
        y_prime = 1 - jnp.round(y_pred)

        loss_1, loss_2, loss_3 = (
            self.loss_fn_1(y_pred, y),
            self.loss_fn_2(cf_y, y_prime),
            self.loss_fn_3(x, cf),
        )
        loss_1, loss_2, loss_3 = map(np.asarray, (loss_1, loss_2, loss_3))
        logs = {
            "val/accuracy": accuracy(y, y_pred),
            "val/validity": accuracy(cf_y, y_prime),
            "val/proximity": proximity(x, cf),
            "val/val_loss_1": loss_1,
            "val/val_loss_2": loss_2,
            "val/val_loss_3": loss_3,
            # NOTE: unweighted sum (the lambda_* weights are not applied here).
            "val/val_loss": loss_1 + loss_2 + loss_3,
        }
        self.log_dict(logs)
        return logs

In [None]:
#| hide
from cfnet.train import train_model, TensorboardLogger
from cfnet.datasets import TabularDataModule

data_configs = {
    "data_dir": "assets/data/s_adult.csv",
    "data_name": "adult",
    "batch_size": 128,
    'sample_frac': 0.1,
    # NOTE(review): key spellings "continous_cols" / "discret_cols" are kept as-is;
    # presumably they must match what `TabularDataModule` reads — confirm before renaming.
    "continous_cols": [
        "age",
        "hours_per_week"
    ],
    "discret_cols": [
        "workclass",
        "education",
        "marital_status",
        "occupation",
        "race",
        "gender"
    ],
}
# Model + training-module configs: network sizes, dropout, learning rate and loss
# weights (lambda_1: prediction, lambda_2: validity, lambda_3: proximity).
m_configs = {
    "enc_sizes": [50,10],
    "dec_sizes": [10],
    "exp_sizes": [50, 50],
    "dropout_rate": 0.3,
    "lr": 0.003,
    "lambda_1": 1.0,
    "lambda_3": 0.1,
    "lambda_2": 0.2,
}

# NOTE(review): `batch_size` also appears in `data_configs` (128); which value
# takes effect depends on `train_model` / `TabularDataModule` — confirm.
t_configs = {
    'n_epochs': 1,
    'monitor_metrics': 'val/val_loss',
    'seed': 42,
    "batch_size": 256
}


In [None]:
#| hide
# Train CounterNet with the configs above; the result is unpacked as
# `(params, opts)` — presumably trained params and optimizer states.
params, opts = train_model(
    CounterNetTrainingModule(m_configs),
    TabularDataModule(data_configs),
    t_configs
)

  leaves, treedef = jax.tree_flatten(tree)
  return jax.tree_unflatten(treedef, leaves)
Epoch 0: 100%|██████████| 20/20 [00:09<00:00,  2.18batch/s, train/train_loss_1=0.0532, train/train_loss_2=0.211, train/train_loss_3=0.149]
  for x in jax.tree_leaves(state):


In [None]:
#| hide
# Show the shape of every trained CounterNet parameter.
# (`jax.tree_map` is deprecated/removed in recent JAX; use `jax.tree_util`.)
jax.tree_util.tree_map(lambda x: x.shape, params)

{'counter_net_model/Encoder/dense_block/linear': {'b': (50,), 'w': (29, 50)},
 'counter_net_model/Encoder/dense_block_1/linear': {'b': (10,), 'w': (50, 10)},
 'counter_net_model/Explainer/dense_block/linear': {'b': (50,), 'w': (20, 50)},
 'counter_net_model/Explainer/dense_block_1/linear': {'b': (50,),
  'w': (50, 50)},
 'counter_net_model/Explainer_1': {'b': (29,), 'w': (50, 29)},
 'counter_net_model/Predictor/dense_block/linear': {'b': (10,), 'w': (10, 10)},
 'counter_net_model/Predictor_1': {'b': (1,), 'w': (10, 1)}}

In [None]:
#| hide
# partition_trainable_params(params, trainable_name='counter_net_model/Explainer')

({'counter_net_model/Explainer/dense_block/linear': {'b': DeviceArray([ -6.0861936 ,  11.848296  ,  -0.35694602, -11.134405  ,
                 -7.616356  ,  -9.1261015 ,  -9.20692   ,  -6.5062113 ,
                 -8.090335  , -10.452749  ,   0.06236552,  -9.428289  ,
                 -5.8404965 ,  -9.6401205 ,  -9.723112  ,  -9.664204  ,
                 -7.1477385 , -10.638865  ,  -5.0094237 , -13.039403  ,
                  0.8207084 ,  -8.677328  ,  -8.581181  ,  -8.168017  ,
                -13.538818  ,  -7.554711  , -10.01346   ,  -6.9791183 ,
                 -8.21111   ,  -9.719664  ,   8.602022  ,  -8.184092  ,
                -10.69423   ,  -7.102236  , -12.054253  ,  -9.991073  ,
                 -9.018837  ,  -0.1831892 , -11.176291  ,  -9.001131  ,
                 -7.3955545 ,  -2.1222026 ,  -8.371564  ,  -7.8714285 ,
                 -0.22489327,  -9.471412  ,  -9.532182  ,  -3.8748262 ,
                 -9.962095  ,  -0.8532969 ], dtype=float32),
   'w': DeviceArray(