# Configuration Showcase

This notebook shows how the different configurations available in `Trainax` can
be depicted schematically.

In [6]:
import jax.numpy as jnp

import trainax as tx

## Supervised

A supervised configuration is special because all data can be pre-computed.
Neither a `ref_stepper` nor a `residuum_fn` is needed on the fly, and hence
neither has to be differentiable.
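
Pre-computing the data could look like the following minimal sketch. It uses plain JAX rather than any `Trainax` utility. The `ref_stepper` here is a hypothetical linear map, and `rollout` and `stack_windows` are illustrative helper names, not part of the library.

```python
import jax
import jax.numpy as jnp


def ref_stepper(u):
    # Hypothetical reference simulator: simple linear decay
    return 0.9 * u


def rollout(stepper, u_0, num_steps):
    # Returns the trajectory [u_0, u_1, ..., u_num_steps]
    def scan_fn(u, _):
        u_next = stepper(u)
        return u_next, u_next

    _, trj = jax.lax.scan(scan_fn, u_0, None, length=num_steps)
    return jnp.concatenate([u_0[None], trj], axis=0)


def stack_windows(trj, window_length):
    # Slice one long trajectory into all overlapping windows
    num_windows = trj.shape[0] - window_length + 1
    return jnp.stack([trj[i : i + window_length] for i in range(num_windows)])


u_0 = jnp.ones((4,))  # state with 4 degrees of freedom
trj = rollout(ref_stepper, u_0, num_steps=5)  # shape (6, 4)

# Two-step supervised training needs an initial condition plus two targets
# per sample, i.e., windows of length 3
windows = stack_windows(trj, window_length=3)  # shape (4, 3, 4)
```

Since the windows are built once before training, the reference simulator in this sketch is never differentiated.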

### One-Step supervised

<p align="center">
    <img src="https://ceyron.github.io/predictor-learning-setups/sup-1-none-true-primal.svg" width="400">
</p>

In [3]:
# The default is one-step supervised learning
tx.configuration.Supervised()

Supervised(
  num_rollout_steps=1,
  time_level_loss=MSELoss(batch_reduction=<function mean>),
  cut_bptt=False,
  cut_bptt_every=1,
  time_level_weights=f32[1]
)

### Two-Step supervised (rollout) Training

<p align="center">
    <img src="https://ceyron.github.io/predictor-learning-setups/sup-2-none-true-primal.svg" width="400">
</p>

We roll out the neural emulator for two autoregressive steps. Its parameters are
shared between the two predictions. Likewise, the `ref_stepper` is applied
autoregressively to create the reference trajectory; the loss is aggregated as a
sum over the two time levels.

In [4]:
tx.configuration.Supervised(num_rollout_steps=2)

Supervised(
  num_rollout_steps=2,
  time_level_loss=MSELoss(batch_reduction=<function mean>),
  cut_bptt=False,
  cut_bptt_every=1,
  time_level_weights=f32[2]
)
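
The rollout loss described above can be sketched in a few lines of plain JAX. This is an illustration under stated assumptions, not the library's implementation: `emulator` is a hypothetical one-parameter stand-in for a neural network, and `supervised_rollout_loss` is an illustrative name.

```python
import jax.numpy as jnp


def emulator(a, u):
    # Hypothetical one-parameter emulator standing in for a neural network
    return a * u


def supervised_rollout_loss(a, window):
    # window[0] is the initial condition, window[1:] are pre-computed targets
    pred = window[0]
    loss = 0.0
    for target in window[1:]:
        pred = emulator(a, pred)  # the same parameter `a` is used in every step
        loss = loss + jnp.mean((pred - target) ** 2)
    return loss


# Reference data produced by u -> 0.9 * u starting from u_0 = 1
window = jnp.array([[1.0], [0.9], [0.81]])

loss_true = supervised_rollout_loss(0.9, window)  # close to zero
loss_off = supervised_rollout_loss(1.0, window)  # 0.1^2 + 0.19^2 = 0.0461
```

With the mismatched parameter, the error of the first step is carried into the second step, so the second time level contributes more to the sum than the first.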

### Three-Step supervised (rollout) Training

<p align="center">
    <img src="https://ceyron.github.io/predictor-learning-setups/sup-3-none-true-primal.svg" width="400">
</p>

Same idea as above, but with an additional rollout step.

In [5]:
tx.configuration.Supervised(num_rollout_steps=3)

Supervised(
  num_rollout_steps=3,
  time_level_loss=MSELoss(batch_reduction=<function mean>),
  cut_bptt=False,
  cut_bptt_every=1,
  time_level_weights=f32[3]
)

### Three-Step supervised (rollout) Training with loss only at final state

<p align="center">
    <img src="https://ceyron.github.io/predictor-learning-setups/sup-3-none-false-primal.svg" width="400">
</p>

The loss is only taken from the last step. Essentially, this corresponds to
weighting the time levels with $[0, 0, 1]$, respectively. (More weighting
options are possible, of course.)

In [7]:
tx.configuration.Supervised(
    num_rollout_steps=3, time_level_weights=jnp.array([0.0, 0.0, 1.0])
)

Supervised(
  num_rollout_steps=3,
  time_level_loss=MSELoss(batch_reduction=<function mean>),
  cut_bptt=False,
  cut_bptt_every=1,
  time_level_weights=f32[3]
)
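
The effect of the weights can be illustrated on made-up per-time-level losses. The numbers below are hypothetical. The uniform vector assumes that the default weights are all ones, which the printed output does not confirm since it only shows the shape `f32[3]`.

```python
import jax.numpy as jnp

# Hypothetical per-time-level losses of a three-step rollout
time_level_losses = jnp.array([0.01, 0.04, 0.09])

uniform_weights = jnp.ones(3)  # assumed default: every level counts equally
final_only_weights = jnp.array([0.0, 0.0, 1.0])  # loss only at the final state
decaying_weights = 0.5 ** jnp.arange(3)  # [1.0, 0.5, 0.25], one other option

loss_uniform = jnp.sum(uniform_weights * time_level_losses)  # 0.14
loss_final = jnp.sum(final_only_weights * time_level_losses)  # 0.09
loss_decay = jnp.sum(decaying_weights * time_level_losses)  # 0.0525
```

The intermediate states are still computed with final-only weighting, since the last prediction depends on them; they just do not contribute a loss term of their own.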

### Three-Step supervised (rollout) Training with no backpropagation through time

(Displays the primal evaluation together with the cotangent flow; the grey
dashed line indicates a cut gradient.)

<p align="center">
    <img src="https://fkoehler.site/predictor-learning-setups/sup-3-none-true-no_net_bptt.svg" width="400">
</p>

This prevents gradients from flowing backward through the autoregressive chain
of network evaluations. Gradients can still flow into the parameter space.

In [8]:
tx.configuration.Supervised(num_rollout_steps=3, cut_bptt=True)

Supervised(
  num_rollout_steps=3,
  time_level_loss=MSELoss(batch_reduction=<function mean>),
  cut_bptt=True,
  cut_bptt_every=1,
  time_level_weights=f32[3]
)
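
A small sketch in plain JAX makes the difference measurable. It assumes that the cut can be modeled by `jax.lax.stop_gradient` applied to the state that is fed into the next step; the one-parameter `emulator` and the function name `rollout_loss` are hypothetical.

```python
import jax
import jax.numpy as jnp


def emulator(a, u):
    # Hypothetical one-parameter emulator
    return a * u


def rollout_loss(a, u_0, targets, cut_bptt):
    pred = u_0
    loss = 0.0
    for target in targets:
        pred = emulator(a, pred)
        loss = loss + (pred - target) ** 2
        if cut_bptt:
            # Block the cotangent from flowing back into earlier network calls
            pred = jax.lax.stop_gradient(pred)
    return loss


u_0 = 1.0
targets = jnp.zeros(3)  # zero targets keep the arithmetic simple

# Full BPTT: d/da (a^2 + a^4 + a^6) at a = 2 is 4 + 32 + 192 = 228
grad_full = jax.grad(rollout_loss)(2.0, u_0, targets, False)
# Cut BPTT: each step only sees its own network call, 4 + 16 + 64 = 84
grad_cut = jax.grad(rollout_loss)(2.0, u_0, targets, True)
```

The cut gradient is nonzero because every network evaluation still contributes its direct parameter derivative. Only the contributions that travel through earlier evaluations are dropped.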

### Four-Step supervised (rollout) Training with sparse backpropagation through time

<p align="center">
    <img src="https://fkoehler.site/predictor-learning-setups/sup-4-none-true-cut_every_2_net_bptt.svg" width="700">
</p>

Only every second backpropagation step is allowed to flow through the network.

In [9]:
tx.configuration.Supervised(num_rollout_steps=4, cut_bptt=True, cut_bptt_every=2)

Supervised(
  num_rollout_steps=4,
  time_level_loss=MSELoss(batch_reduction=<function mean>),
  cut_bptt=True,
  cut_bptt_every=2,
  time_level_weights=f32[4]
)
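
The sparse variant can be sketched in the same spirit. The placement of the cuts is an assumption here: the state is detached after every `cut_bptt_every`-th step. The one-parameter emulator `a * pred` and the name `rollout_loss` are hypothetical.

```python
import jax
import jax.numpy as jnp


def rollout_loss(a, u_0, targets, cut_bptt_every):
    pred = u_0
    loss = 0.0
    for t, target in enumerate(targets):
        pred = a * pred  # hypothetical one-parameter emulator
        loss = loss + (pred - target) ** 2
        # Assumption: the cut is placed after every `cut_bptt_every`-th step
        if (t + 1) % cut_bptt_every == 0:
            pred = jax.lax.stop_gradient(pred)
    return loss


targets = jnp.zeros(4)  # four rollout steps with zero targets

grad_every_1 = jax.grad(rollout_loss)(2.0, 1.0, targets, 1)  # 340
grad_every_2 = jax.grad(rollout_loss)(2.0, 1.0, targets, 2)  # 612
grad_no_cut = jax.grad(rollout_loss)(2.0, 1.0, targets, 5)  # 1252, never cut
```

In this toy setting, cutting every second step retains more of the full gradient than cutting every step, while still limiting how far a cotangent can travel back through the chain.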

## Diverted Chain

### Two steps with branch length one

<p align="center">
    <img src="https://fkoehler.site/predictor-learning-setups/div-2-1-true-primal.svg" width="500">
</p>

The `ref_stepper` is not run autoregressively for two steps from the initial
condition but rather for one step, branching off from the main chain created by
the emulator.

In [10]:
# `num_rollout_steps` refers to the number of autoregressive steps performed by
# the neural emulator
tx.configuration.DivertedChainBranchOne(num_rollout_steps=2)

DivertedChainBranchOne(
  num_rollout_steps=2,
  time_level_loss=MSELoss(batch_reduction=<function mean>),
  cut_bptt=False,
  cut_bptt_every=1,
  cut_div_chain=False,
  time_level_weights=f32[2]
)

In [11]:
# Alternatively, the general interface can be used
tx.configuration.DivertedChain(num_rollout_steps=2, num_branch_steps=1)

DivertedChain(
  num_rollout_steps=2,
  num_branch_steps=1,
  time_level_loss=MSELoss(batch_reduction=<function mean>),
  cut_bptt=False,
  cut_bptt_every=1,
  cut_div_chain=False,
  time_level_weights=f32[2],
  branch_level_weights=f32[1]
)
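
A diverted chain with branch length one can be sketched as follows. This is a minimal illustration with uniform weights, not the library's implementation; `ref_stepper` and `emulator` are hypothetical linear maps, and `diverted_chain_branch_one_loss` is an illustrative name.

```python
import jax.numpy as jnp


def ref_stepper(u):
    # Hypothetical reference simulator
    return 0.9 * u


def emulator(a, u):
    # Hypothetical one-parameter emulator
    return a * u


def diverted_chain_branch_one_loss(a, u_0, num_rollout_steps):
    main = u_0
    loss = 0.0
    for _ in range(num_rollout_steps):
        # Branch: one reference step starting from the emulator's current state
        ref = ref_stepper(main)
        # Main chain: advance the emulator by one step
        main = emulator(a, main)
        loss = loss + jnp.mean((main - ref) ** 2)
    return loss


u_0 = jnp.array([1.0])
# With a = 1 the main chain stays at 1 and every branch yields 0.9,
# so each step contributes 0.1^2 = 0.01
loss = diverted_chain_branch_one_loss(1.0, u_0, num_rollout_steps=2)  # 0.02
```

Because the reference step starts from a state produced by the emulator, it cannot be pre-computed in this sketch, in contrast to the supervised setting.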

### Three steps with branch length one

<p align="center">
    <img src="https://fkoehler.site/predictor-learning-setups/div-3-1-true-primal.svg" width="600">
</p>

In [12]:
tx.configuration.DivertedChainBranchOne(num_rollout_steps=3)

DivertedChainBranchOne(
  num_rollout_steps=3,
  time_level_loss=MSELoss(batch_reduction=<function mean>),
  cut_bptt=False,
  cut_bptt_every=1,
  cut_div_chain=False,
  time_level_weights=f32[3]
)

### Four steps with branch length one

<p align="center">
    <img src="https://fkoehler.site/predictor-learning-setups/div-4-1-true-primal.svg" width="700">
</p>

In [13]:
tx.configuration.DivertedChainBranchOne(num_rollout_steps=4)

DivertedChainBranchOne(
  num_rollout_steps=4,
  time_level_loss=MSELoss(batch_reduction=<function mean>),
  cut_bptt=False,
  cut_bptt_every=1,
  cut_div_chain=False,
  time_level_weights=f32[4]
)

### Three steps with branch length two

<p align="center">
    <img src="https://fkoehler.site/predictor-learning-setups/div-3-2-true-primal.svg" width="600">
</p>

In [14]:
# Can only be done with the general interface
tx.configuration.DivertedChain(num_rollout_steps=3, num_branch_steps=2)

DivertedChain(
  num_rollout_steps=3,
  num_branch_steps=2,
  time_level_loss=MSELoss(batch_reduction=<function mean>),
  cut_bptt=False,
  cut_bptt_every=1,
  cut_div_chain=False,
  time_level_weights=f32[3],
  branch_level_weights=f32[2]
)
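
One plausible reading of a longer branch is sketched below in plain JAX. The assumptions are: a branch starts at every main-chain state from which `num_branch_steps` steps still fit into the rollout, emulator and reference are advanced side by side along the branch, and all weights are uniform. `ref_stepper`, `emulator`, and `diverted_chain_loss` are hypothetical names for this illustration.

```python
import jax.numpy as jnp


def ref_stepper(u):
    # Hypothetical reference simulator
    return 0.9 * u


def emulator(a, u):
    # Hypothetical one-parameter emulator
    return a * u


def diverted_chain_loss(a, u_0, num_rollout_steps, num_branch_steps):
    main = u_0
    loss = 0.0
    # Assumption: branches start at main-chain steps 0 .. T - B
    for _start in range(num_rollout_steps - num_branch_steps + 1):
        branch_pred, branch_ref = main, main
        for _step in range(num_branch_steps):
            branch_pred = emulator(a, branch_pred)
            branch_ref = ref_stepper(branch_ref)
            loss = loss + jnp.mean((branch_pred - branch_ref) ** 2)
        main = emulator(a, main)  # advance the main chain by one step
    return loss


u_0 = jnp.array([1.0])
# Two branches, each contributing 0.1^2 + 0.19^2 = 0.0461
loss_3_2 = diverted_chain_loss(1.0, u_0, num_rollout_steps=3, num_branch_steps=2)
# Branch length one with two steps: 0.01 + 0.01 = 0.02
loss_2_1 = diverted_chain_loss(1.0, u_0, num_rollout_steps=2, num_branch_steps=1)
```

Setting `num_branch_steps=1` in this sketch gives one single-step branch per rollout step.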

In [None]:
# TODO: All the cut versions ...