# Staged models

TODO: Write this as a deconstruction of an actual model, e.g. the original `SimpleFeedback`, into state operations

Equinox modules collect the parameters of a model into tree structures (PyTrees).

- However, the operations that a module actually performs may be complex
    - they are defined in its `__call__` method,
    - and the module's tree structure need not reflect exactly which computations are performed, or in what order.
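To make this concrete, here is a framework-free sketch. Equinox modules are dataclasses, so a frozen `dataclasses.dataclass` can stand in for one; the class `TinyLayer` is hypothetical. Its fields (its "tree structure") are only `weight` and `bias`, and nothing in them reveals that `__call__` also applies a `tanh` and adds a skip connection.

```python
import dataclasses
import math


@dataclasses.dataclass(frozen=True)
class TinyLayer:
    """Hypothetical stand-in for an `eqx.Module` (Equinox modules are dataclasses)."""

    weight: float
    bias: float

    def __call__(self, x: float) -> float:
        # The computation: an affine map, a nonlinearity, and a skip connection.
        # None of this is visible from the fields alone.
        return x + math.tanh(self.weight * x + self.bias)


layer = TinyLayer(weight=0.5, bias=0.0)

# The "tree structure" only records which parameters exist.
print([f.name for f in dataclasses.fields(layer)])  # ['weight', 'bias']
print(layer(0.0))  # 0.0
```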

Models can be seen as a series of operations on the state of a system.

- `AbstractStagedModel` is a special kind of `eqx.Module` that requires us to specify its computations as a series of *stages*
    - that is, operations that modify the state associated with that type of model.
- Each type of `AbstractStagedModel` is associated with a particular type of `AbstractState`.
    - As a result, states tend to have a tree structure that mirrors the model hierarchy, and that reflects the state operations that can be performed at each level of the model.
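The idea of stages can be sketched without any framework. In this toy version (all names are hypothetical, not this library's API), a staged model exposes an ordered mapping from stage labels to functions `(model, input, state) -> state`, and `__call__` simply applies them in order. Each model type is paired with its own immutable state type, so every stage returns a new state instead of mutating the old one, in keeping with JAX's functional style.

```python
import dataclasses
from collections import OrderedDict


@dataclasses.dataclass(frozen=True)
class CounterState:
    """Hypothetical state type paired with `CounterModel`."""

    value: float
    history: tuple = ()


class ToyStagedModel:
    """Hypothetical base class: subclasses define `stages`, and `__call__` runs them in order."""

    @property
    def stages(self):
        raise NotImplementedError

    def __call__(self, input, state):
        for stage in self.stages.values():
            # Each stage maps (model, input, state) to a new state.
            state = stage(self, input, state)
        return state


@dataclasses.dataclass(frozen=True)
class CounterModel(ToyStagedModel):
    gain: float

    @property
    def stages(self):
        return OrderedDict({
            # Add the scaled input to the current value.
            "accumulate": lambda self, input, state: dataclasses.replace(
                state, value=state.value + self.gain * input
            ),
            # Append the updated value to the history.
            "record": lambda self, input, state: dataclasses.replace(
                state, history=state.history + (state.value,)
            ),
        })


model = CounterModel(gain=2.0)
state = CounterState(value=0.0)
state = model(1.0, state)
state = model(0.5, state)
print(state)  # CounterState(value=3.0, history=(2.0, 3.0))
```

Because the stages are named and ordered, the model's computation can be inspected (or modified) stage by stage, rather than being hidden inside a single `__call__`.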

For example, in our examples so far,

- `model.step` is an instance of `SimpleFeedback`, which is a type of `AbstractStagedModel`. 
- When `model.step` is called,
    - it is passed an instance of `SimpleFeedbackState`, and
    - it executes several stages of modifications to `SimpleFeedbackState`, and returns the result.
    - One of the stages of `SimpleFeedback` is labeled `"nn_step"`. This stage calls the instance of `SimpleStagedNetwork` that belongs to `SimpleFeedback`, which in our model is `model.step.net`.
- `SimpleStagedNetwork` is also a type of `AbstractStagedModel`
- When an instance of `SimpleStagedNetwork` is called, 
    - it is passed an instance of `NetworkState`. When `model.step` calls `model.step.net` during the `"nn_step"` stage, it passes `states.net`.
    - it executes several stages of modifications to `NetworkState`, and returns the result.
    
Note that `SimpleFeedback` (`model.step`) possesses a `SimpleStagedNetwork` (`model.step.net`), and that `SimpleFeedbackState` (`states`) possesses a `NetworkState` (`states.net`). 
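The nesting described above can also be sketched in plain Python. Here a hypothetical `ToyFeedback` owns a hypothetical `ToyNet`, and `ToyFeedbackState` owns a `ToyNetState` in its `net` field. During its `"nn_step"` stage, the parent model passes only `state.net` to `self.net` and stores the returned substate back in place. The other stage label (`"plant_step"`) and all of the arithmetic are made up for illustration.

```python
import dataclasses
from collections import OrderedDict


@dataclasses.dataclass(frozen=True)
class ToyNetState:
    hidden: float = 0.0
    output: float = 0.0


@dataclasses.dataclass(frozen=True)
class ToyFeedbackState:
    position: float = 0.0
    # The child model's state lives in a field of the parent model's state.
    net: ToyNetState = dataclasses.field(default_factory=ToyNetState)


def run_stages(model, input, state):
    """Apply a model's stages in order; each maps (model, input, state) -> new state."""
    for stage in model.stages.values():
        state = stage(model, input, state)
    return state


@dataclasses.dataclass(frozen=True)
class ToyNet:
    decay: float

    @property
    def stages(self):
        return OrderedDict({
            "hidden": lambda self, input, state: dataclasses.replace(
                state, hidden=self.decay * state.hidden + input
            ),
            "readout": lambda self, input, state: dataclasses.replace(
                state, output=state.hidden
            ),
        })

    def __call__(self, input, state):
        return run_stages(self, input, state)


@dataclasses.dataclass(frozen=True)
class ToyFeedback:
    # The parent model owns the child model as a field.
    net: ToyNet

    @property
    def stages(self):
        return OrderedDict({
            # Call the child model on the child state only, then store the result.
            "nn_step": lambda self, input, state: dataclasses.replace(
                state, net=self.net(input - state.position, state.net)
            ),
            # Move the toy "plant" according to the network output.
            "plant_step": lambda self, input, state: dataclasses.replace(
                state, position=state.position + 0.25 * state.net.output
            ),
        })

    def __call__(self, input, state):
        return run_stages(self, input, state)


model = ToyFeedback(net=ToyNet(decay=0.5))
state = ToyFeedbackState()
for _ in range(2):
    state = model(1.0, state)  # input: a target position of 1.0
print(state)
# ToyFeedbackState(position=0.5625, net=ToyNetState(hidden=1.25, output=1.25))
```

The parent never needs to know how the child's stages work; it only needs to know which part of its own state belongs to the child.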

Using `format_model_spec` to visualize the stages of a model

Inspecting `model_spec`

## Inspecting staged models 

Model spec

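For example, here is the `model_spec` property of a `SimpleStagedNetwork` that has an encoder, a hidden layer with noise, and a readout. Each stage names the callable to run, the part of the input or state it takes as input (`where_input`), and the part of the state it updates (`where_state`).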

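```python
from collections import OrderedDict

from jax.flatten_util import ravel_pytree

# `ModelStageSpec` is defined by the library; its import is omitted here.
# The method below is defined inside the class body of `SimpleStagedNetwork`.

    @property
    def model_spec(self):
        """Specifies the stages of the model in terms of state operations."""
        return OrderedDict({
            # Flatten the input PyTree into a single array, then encode it.
            'encoder': ModelStageSpec(
                callable=lambda self: self._encode,
                where_input=lambda input, state: ravel_pytree(input)[0],
                where_state=lambda state: state.encoding,
            ),
            # Update the hidden state, driven by the encoding.
            'hidden': ModelStageSpec(
                callable=lambda self: self.hidden,
                where_input=lambda input, state: state.encoding,
                where_state=lambda state: state.hidden,
            ),
            # Add noise to the hidden state (reads and writes `state.hidden`).
            'hidden_noise': ModelStageSpec(
                callable=lambda self: self._add_hidden_noise,
                where_input=lambda input, state: state.hidden,
                where_state=lambda state: state.hidden,
            ),
            # Compute the network output from the (noisy) hidden state.
            'readout': ModelStageSpec(
                callable=lambda self: self._output,
                where_input=lambda input, state: state.hidden,
                where_state=lambda state: state.output,
            ),
        })
```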

## Using simple functions 

Wrappers.

## Using non-staged components 

- Using existing neural networks as controllers 
    - And the downside (intervenors)

## Writing a staged model

Example: similar to `SimpleFeedback`, but with two neural networks with a `Channel` between them?