# Network Training

@[Chaoming Wang](mailto:adaduo@outlook.com)

To make your model powerful, you need to train the network you have created. This section explains how to train network models.

In [1]:
import brainpy as bp
import brainpy.math as bm

bp.math.set_platform('cpu')

## Set up a ``RNNTrainer``

Once you have created a model, setting up a structural trainer only requires instantiating a ``RNNTrainer``.

In [2]:
model = (
    bp.nn.Input(1)
    >>
    bp.nn.VanillaRNN(100)
    >>
    bp.nn.Dense(1)
)
model.initialize(1)

# set up a back-propagation through time (BPTT) trainer
trainer = bp.nn.BPTT(model, loss='mean_squared_error')

Next, all you need to do is provide your training data to the ``.fit()`` function.

The **training data** fed into the ``.fit()`` function can be a tuple or a list holding an ``(X, Y)`` pair, or a callable function which generates ``(X, Y)`` data pairs.

- If the training data is an ``(X, Y)`` data pair, ``X`` should be the input data with the shape of `(num_sample, num_time, num_feature)`. ``Y`` should be the target data, with the shape of `(num_sample, num_time, num_feature)` for a ``many-to-many`` mapping, or with the shape of `(num_sample, num_feature)` for a ``many-to-final`` mapping.

![](../_static/rnn_training_mapping.png)


- If the training data is a callable function, calling it should return a Python generator which yields ``(X, Y)`` data pairs for training. For example,

```python
import brainpy.math as bm

# Calling this function creates a Python generator.
def train_data():
  num_data = 10
  for _ in range(num_data):
    # Each (X, Y) data pair should satisfy:
    # - "X" is a tensor with the shape
    #   "(num_batch, num_time, num_feature)"
    # - "Y" is a tensor with the shape
    #   "(num_batch, num_time, num_feature)"
    #   or "(num_batch, num_feature)"
    xs = bm.random.rand(1, 20, 2)
    ys = bm.random.random((1, 20, 2))
    yield xs, ys
```
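
The two target layouts described above can be checked with plain NumPy arrays. The following is only a sketch of the shape conventions: it does not call BrainPy, and `describe_mapping` is a hypothetical helper introduced here for illustration, not part of any library.

```python
import numpy as np

num_sample, num_time, num_feature = 8, 20, 2
rng = np.random.default_rng(0)

# Inputs always carry a time axis: (num_sample, num_time, num_feature).
X = rng.random((num_sample, num_time, num_feature))

# many-to-many: one target per time step, so Y keeps the time axis.
Y_many_to_many = rng.random((num_sample, num_time, num_feature))

# many-to-final: one target per sequence, so Y drops the time axis.
Y_many_to_final = rng.random((num_sample, num_feature))


def describe_mapping(x, y):
    """Name the mapping an (X, Y) pair corresponds to (hypothetical helper)."""
    if x.ndim != 3:
        raise ValueError('X must have shape (num_sample, num_time, num_feature)')
    if y.ndim == 3 and y.shape[:2] == x.shape[:2]:
        return 'many-to-many'
    if y.ndim == 2 and y.shape[0] == x.shape[0]:
        return 'many-to-final'
    raise ValueError(f'incompatible shapes: X {x.shape}, Y {y.shape}')


print(describe_mapping(X, Y_many_to_many))
print(describe_mapping(X, Y_many_to_final))
```

A check of this kind, run before training, can catch a missing or misplaced time axis early.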


However, all these data constraints can be relaxed when you customize your training procedure. Please see XXXXXX.

Note that before fitting your data by calling the ``.fit()`` function, you need to **initialize the model** by specifying the batch size of your data. Otherwise, an error will be raised.

## Supported training algorithms

Currently, BrainPy provides several ways to train recurrent neural networks, including ridge regression, FORCE learning, and back-propagation through time. For the full list of supported training algorithms, please see the [API documentation](../apis/auto/nn/runners.rst). Here we only discuss a few of them.

### Ridge regression
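
As general background, ridge regression fits a linear readout in closed form by solving a regularized least-squares problem. The sketch below uses the textbook formula in plain NumPy and is not necessarily how BrainPy implements its trainer. The function name `ridge_fit`, the regularization strength `beta`, and the random stand-in data are all assumptions made for illustration.

```python
import numpy as np


def ridge_fit(states, targets, beta=1e-6):
    """Closed-form ridge regression: W = (S^T S + beta * I)^-1 S^T Y."""
    num_state = states.shape[1]
    gram = states.T @ states + beta * np.eye(num_state)
    # Solve the linear system instead of forming an explicit inverse.
    return np.linalg.solve(gram, states.T @ targets)


rng = np.random.default_rng(0)

# Stand-in for recorded network states, flattened over samples and time:
# 200 rows, each with 5 state variables.
states = rng.normal(size=(200, 5))
true_w = rng.normal(size=(5, 1))
targets = states @ true_w  # noiseless targets from a known readout

w = ridge_fit(states, targets)
print(np.allclose(w, true_w, atol=1e-4))
```

Because the solution is obtained in a single linear solve, no gradient steps are involved. A larger `beta` shrinks the weights more strongly.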

## Shared parameters

Sometimes there are global parameters which are shared across all nodes, for example the parameter ``train=True/False`` that controls the training or testing phase. Here, we use a simple model to demonstrate how to provide shared parameters when calling a model.

In [3]:
model = (
    bp.nn.Input(1)
    >>
    bp.nn.VanillaRNN(100)
    >>
    bp.nn.Dropout(0.3)
    >>
    bp.nn.Dense(1)
)
model.initialize(3)

These shared parameters can be provided in two ways:

- When you use the instantiated model directly, you can provide them when calling the model.

In [4]:
model(bm.random.rand(3, 1), train=True)

JaxArray([[-2.1306934],
          [ 1.4046229],
          [ 1.2039466]], dtype=float32)

In [5]:
model(bm.random.rand(3, 1), train=False)

JaxArray([[-0.18169183],
          [-0.09682302],
          [-0.09607743]], dtype=float32)

- When you use a structural runner like ``brainpy.nn.RNNRunner`` or a trainer like ``brainpy.nn.BPTT``, you can wrap all shared parameters in the argument ``shared_kwargs``.

In [6]:
runner = bp.nn.RNNRunner(model)

In [7]:
runner.predict(bm.random.random((3, 10, 1)),
               shared_kwargs={'train': True})

JaxArray([[[-0.34313297],
           [-0.23117486],
           [-0.30522805],
           [-0.32745296],
           [ 0.34081894],
           [ 0.15859577],
           [-0.12556326],
           [-0.08884146],
           [-0.11673015],
           [ 0.07351794]],

          [[ 0.58661526],
           [-0.01611713],
           [-0.28982073],
           [-0.06223251],
           [-0.02380377],
           [-0.25092548],
           [ 0.1635698 ],
           [-0.21967886],
           [ 0.20144288],
           [-0.37496752]],

          [[-0.30316865],
           [-0.30312598],
           [-0.19248168],
           [-0.0053116 ],
           [-0.19092566],
           [-0.10831024],
           [-0.04929686],
           [ 0.00246187],
           [-0.00367273],
           [-0.03972476]]], dtype=float32)

In [8]:
runner.predict(bm.random.random((3, 10, 1)),
               shared_kwargs={'train': False})

JaxArray([[[-0.04736265],
           [-0.03275637],
           [ 0.07194265],
           [-0.07143855],
           [-0.07852352],
           [ 0.0554626 ],
           [-0.07382635],
           [-0.00669023],
           [-0.04883794],
           [-0.0540965 ]],

          [[-0.06040927],
           [ 0.10728423],
           [-0.09903762],
           [-0.04275577],
           [ 0.03387596],
           [-0.07266021],
           [-0.01200375],
           [ 0.00651347],
           [-0.01666619],
           [-0.08301818]],

          [[-0.02191623],
           [-0.10554922],
           [ 0.01432531],
           [-0.03919244],
           [-0.05116755],
           [ 0.10237596],
           [-0.13605613],
           [ 0.02815703],
           [-0.02858308],
           [-0.08084894]]], dtype=float32)

However, note that ``shared_kwargs`` should only take a few distinct values, because each different value of ``shared_kwargs`` triggers recompilation. If a parameter changes significantly and frequently, you’d better not pass it through ``shared_kwargs``.
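
The cost of frequently changing shared values can be pictured with a toy cache in plain Python. This is only an analogy under stated assumptions: `fake_compile`, `call_with_shared`, and the counter are hypothetical names, and the sketch does not reproduce BrainPy's actual compilation machinery. It only shows that one "compiled" version is kept per distinct set of shared values.

```python
compile_count = 0   # how many times the pretend compilation ran
_compiled = {}      # one compiled function per distinct shared-kwargs value


def fake_compile(shared_kwargs):
    """Pretend to compile a function specialized on the shared values."""
    global compile_count
    compile_count += 1
    train = shared_kwargs.get('train', False)

    def run(xs):
        # The branch on `train` is fixed at "compile" time.
        return [x * (2.0 if train else 1.0) for x in xs]

    return run


def call_with_shared(xs, shared_kwargs):
    """Reuse a compiled function when the shared values were seen before."""
    key = tuple(sorted(shared_kwargs.items()))
    if key not in _compiled:
        _compiled[key] = fake_compile(shared_kwargs)
    return _compiled[key](xs)


call_with_shared([1.0], {'train': True})
call_with_shared([2.0], {'train': True})    # same value: cache hit
call_with_shared([3.0], {'train': False})   # new value: compiles again
print(compile_count)

# A value that changes on every call defeats the cache entirely.
for step in range(5):
    call_with_shared([1.0], {'step': step})
print(compile_count)
```

A flag with two values, such as ``train``, costs at most two compilations, whereas a step counter passed the same way would cost one compilation per call.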