# Configuration Showcase

This notebook showcases how the different configurations available in
`Trainax` can be depicted schematically.

In [1]:
import jax.numpy as jnp

import trainax as tx

from tqdm.autonotebook import tqdm


## Supervised

A supervised configuration is special because all data can be pre-computed.
Neither a `ref_stepper` nor a `residuum_fn` is needed on the fly (and hence
neither has to be differentiable).
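To illustrate why no differentiability is required, here is a minimal sketch of pre-computing reference trajectories once, before training. The toy `ref_stepper` (an explicit Euler step of $du/dt = -u$) and the helper `make_trajectories` are hypothetical and not part of the `Trainax` API.

```python
import jax
import jax.numpy as jnp


def ref_stepper(u):
    # Hypothetical reference solver: one explicit Euler step of du/dt = -u, dt = 0.1
    return u - 0.1 * u


def make_trajectories(stepper, ics, num_steps):
    """Unroll `stepper` autoregressively from a batch of initial conditions.

    `ics` has shape (batch, channels, points); the result has shape
    (batch, num_steps + 1, channels, points), with the ICs at time index 0.
    """
    batched_stepper = jax.vmap(stepper)

    def step(u, _):
        u_next = batched_stepper(u)
        return u_next, u_next

    _, rest = jax.lax.scan(step, ics, None, length=num_steps)
    trajectories = jnp.concatenate([ics[None], rest], axis=0)
    return jnp.swapaxes(trajectories, 0, 1)


# Done once, outside the training loop: the training step only ever sees
# these arrays, so `ref_stepper` is never differentiated.
ics = jnp.ones((4, 1, 8))
trajectories = make_trajectories(ref_stepper, ics, num_steps=3)
print(trajectories.shape)  # (4, 4, 1, 8)
```

Any solver, even one outside of JAX, could produce these arrays, since they are treated as plain data afterwards.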

### One-Step supervised

<p align="center">
    <img src="https://ceyron.github.io/predictor-learning-setups/sup-1-none-true-primal.svg" width="400">
</p>

In [2]:
# The default is one-step supervised learning
tx.configuration.Supervised()

Supervised(
  num_rollout_steps=1,
  time_level_loss=MSELoss(batch_reduction=<function mean>),
  cut_bptt=False,
  cut_bptt_every=1,
  time_level_weights=f32[1]
)

### Two-Step supervised (rollout) Training

<p align="center">
    <img src="https://ceyron.github.io/predictor-learning-setups/sup-2-none-true-primal.svg" width="400">
</p>

We roll out the neural emulator for two autoregressive steps. Its parameters are
shared between the two predictions. Similarly, the `ref_stepper` is unrolled for
two steps to create the reference trajectory; the loss is aggregated as a
(weighted) sum over the two time levels.
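A minimal sketch of this rolled-out supervised loss, assuming a hypothetical one-parameter emulator (`gain * u`) and a plain MSE. The same `params` are reused at every step, and the per-step losses are combined as a weighted sum with `time_level_weights` (all ones here). It mirrors the description above, not necessarily `Trainax`'s internal code.

```python
import jax
import jax.numpy as jnp


def emulator(params, u):
    # Hypothetical "network": a single learnable scalar gain
    return params["gain"] * u


def mse(pred, ref):
    return jnp.mean((pred - ref) ** 2)


def supervised_rollout_loss(params, trajectory, num_rollout_steps, time_level_weights):
    """`trajectory[0]` is the initial condition, `trajectory[t]` the reference at time level t."""
    pred = trajectory[0]
    loss = 0.0
    for t in range(num_rollout_steps):
        pred = emulator(params, pred)  # identical parameters at every step
        loss = loss + time_level_weights[t] * mse(pred, trajectory[t + 1])
    return loss


# Reference data decaying by a factor of 0.9 per step (e.g. pre-computed with a ref_stepper)
trajectory = jnp.array([1.0, 0.9, 0.81])[:, None] * jnp.ones((3, 5))
weights = jnp.ones(2)
print(supervised_rollout_loss({"gain": 1.0}, trajectory, 2, weights))  # ≈ (1 - 0.9)^2 + (1 - 0.81)^2 = 0.0461
grads = jax.grad(supervised_rollout_loss)({"gain": 1.0}, trajectory, 2, weights)
```

Because `params` enters both steps, `jax.grad` accumulates the contributions of both predictions into the single shared parameter.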

In [3]:
tx.configuration.Supervised(num_rollout_steps=2)

Supervised(
  num_rollout_steps=2,
  time_level_loss=MSELoss(batch_reduction=<function mean>),
  cut_bptt=False,
  cut_bptt_every=1,
  time_level_weights=f32[2]
)

### Three-Step supervised (rollout) Training

<p align="center">
    <img src="https://ceyron.github.io/predictor-learning-setups/sup-3-none-true-primal.svg" width="400">
</p>

Same idea as above, but with an additional rollout step.

In [4]:
tx.configuration.Supervised(num_rollout_steps=3)

Supervised(
  num_rollout_steps=3,
  time_level_loss=MSELoss(batch_reduction=<function mean>),
  cut_bptt=False,
  cut_bptt_every=1,
  time_level_weights=f32[3]
)

### Three-Step supervised (rollout) Training with loss only at final state

<p align="center">
    <img src="https://ceyron.github.io/predictor-learning-setups/sup-3-none-false-primal.svg" width="400">
</p>

The loss is only taken at the last step. This corresponds to weighting the three
time levels with $[0, 0, 1]$. (Other weightings are, of course, possible.)
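Written out, with notation of our own choosing ($\hat{u}^{[t]}$ for the emulator prediction and $u^{[t]}$ for the reference at time level $t$, $\ell$ for the `time_level_loss`, and $w_t$ for the entries of `time_level_weights`), the aggregated three-step loss and its final-state-only special case read:

$$
L = \sum_{t=1}^{3} w_t \, \ell\big(\hat{u}^{[t]}, u^{[t]}\big), \qquad w = [0, 0, 1] \;\Rightarrow\; L = \ell\big(\hat{u}^{[3]}, u^{[3]}\big).
$$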

In [5]:
tx.configuration.Supervised(
    num_rollout_steps=3, time_level_weights=jnp.array([0.0, 0.0, 1.0])
)

Supervised(
  num_rollout_steps=3,
  time_level_loss=MSELoss(batch_reduction=<function mean>),
  cut_bptt=False,
  cut_bptt_every=1,
  time_level_weights=f32[3]
)

### Three-Step supervised (rollout) Training with no backpropagation through time

(Displays the primal evaluation together with the cotangent flow; a grey dashed
line indicates a cut gradient.)

<p align="center">
    <img src="https://fkoehler.site/predictor-learning-setups/sup-3-none-true-no_net_bptt.svg" width="400">
</p>

This interrupts the backward gradient flow through the autoregressive network
executions. Gradients can still flow into the parameter space.
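A minimal JAX sketch of what cutting the backpropagation through time means, assuming the cut is realized with `jax.lax.stop_gradient` on the prediction that is fed into the next network call (an assumption about the mechanism; the scalar-gain emulator is hypothetical). With zero targets and `gain = 1`, full BPTT differentiates $g^2 + g^4 + g^6$, giving $2 + 4 + 6 = 12$, while the cut version only keeps each step's direct dependence on the parameter, giving $2 + 2 + 2 = 6$: smaller, but not zero.

```python
import jax
import jax.numpy as jnp


def rollout_loss(gain, u0, refs, cut_bptt):
    pred = u0
    loss = 0.0
    for ref in refs:
        pred = gain * pred  # hypothetical scalar-gain emulator
        loss = loss + jnp.mean((pred - ref) ** 2)
        if cut_bptt:
            # No cotangent flows back into earlier network calls from here on,
            # but the next call still depends on `gain` directly.
            pred = jax.lax.stop_gradient(pred)
    return loss


u0 = jnp.ones(4)
refs = jnp.zeros((3, 4))
grad_fn = jax.grad(rollout_loss)
print(grad_fn(1.0, u0, refs, False))  # 12.0
print(grad_fn(1.0, u0, refs, True))   # 6.0
```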

In [6]:
tx.configuration.Supervised(num_rollout_steps=3, cut_bptt=True)

Supervised(
  num_rollout_steps=3,
  time_level_loss=MSELoss(batch_reduction=<function mean>),
  cut_bptt=True,
  cut_bptt_every=1,
  time_level_weights=f32[3]
)

### Four-Step supervised (rollout) Training with sparse backpropagation through time

<p align="center">
    <img src="https://fkoehler.site/predictor-learning-setups/sup-4-none-true-cut_every_2_net_bptt.svg" width="700">
</p>

Only every second backpropagation step is allowed to flow through the network.
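A sketch of the sparse variant with a hypothetical scalar-gain emulator, assuming the cut is a `jax.lax.stop_gradient` applied after every `cut_bptt_every`-th network call (so `cut_bptt_every=1` cuts after every call). This convention is our assumption, not taken from `Trainax`. With four steps, zero targets and `gain = 1`, full BPTT yields $2 + 4 + 6 + 8 = 20$, cutting after every call yields $8$, and cutting after every second call yields $2 + 4 + 2 + 4 = 12$.

```python
import jax
import jax.numpy as jnp


def rollout_loss(gain, u0, refs, cut_bptt=False, cut_bptt_every=1):
    pred = u0
    loss = 0.0
    for t, ref in enumerate(refs):
        pred = gain * pred  # hypothetical scalar-gain emulator
        loss = loss + jnp.mean((pred - ref) ** 2)
        # Assumed convention: stop the gradient after every `cut_bptt_every`-th call
        if cut_bptt and (t + 1) % cut_bptt_every == 0:
            pred = jax.lax.stop_gradient(pred)
    return loss


u0 = jnp.ones(4)
refs = jnp.zeros((4, 4))
grad_fn = jax.grad(rollout_loss)
print(grad_fn(1.0, u0, refs))           # full BPTT: 20.0
print(grad_fn(1.0, u0, refs, True, 1))  # cut after every call: 8.0
print(grad_fn(1.0, u0, refs, True, 2))  # cut after every second call: 12.0
```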

In [7]:
tx.configuration.Supervised(num_rollout_steps=4, cut_bptt=True, cut_bptt_every=2)

Supervised(
  num_rollout_steps=4,
  time_level_loss=MSELoss(batch_reduction=<function mean>),
  cut_bptt=True,
  cut_bptt_every=2,
  time_level_weights=f32[4]
)

## Diverted Chain

### Two-Steps with branch length one

<p align="center">
    <img src="https://fkoehler.site/predictor-learning-setups/div-2-1-true-primal.svg" width="500">
</p>

The `ref_stepper` is not run autoregressively for two steps from the initial
condition but rather for one step, branching off from the main chain created by
the emulator.
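A minimal sketch of the diverted chain with branch length one: at every step, the `ref_stepper` performs a single step starting from the emulator's current state, and the emulator's next state is compared against it. The toy `ref_stepper` and the scalar-gain emulator are hypothetical. Realizing `cut_div_chain` by applying `jax.lax.stop_gradient` to the branch output is our assumption about the mechanism; it keeps gradients out of the `ref_stepper`, which corresponds to the "no differentiable physics" setting.

```python
import jax
import jax.numpy as jnp


def ref_stepper(u):
    # Hypothetical differentiable reference solver: explicit Euler for du/dt = -u, dt = 0.1
    return u - 0.1 * u


def emulator(gain, u):
    # Hypothetical scalar-gain "network"
    return gain * u


def diverted_chain_branch_one_loss(gain, u0, num_rollout_steps, cut_div_chain=False):
    pred = u0
    loss = 0.0
    for _ in range(num_rollout_steps):
        # Branch off the main chain: one ref_stepper step from the emulator state
        ref = ref_stepper(pred)
        if cut_div_chain:
            # No gradient through the physics branch
            ref = jax.lax.stop_gradient(ref)
        pred = emulator(gain, pred)
        loss = loss + jnp.mean((pred - ref) ** 2)
    return loss


u0 = jnp.ones(4)
grad_fn = jax.grad(diverted_chain_branch_one_loss)
print(diverted_chain_branch_one_loss(1.0, u0, 2))  # ≈ 0.02
print(grad_fn(1.0, u0, 2))                         # ≈ 0.42 (gradient also flows through ref_stepper)
print(grad_fn(1.0, u0, 2, True))                   # ≈ 0.6 (branch treated as a constant target)
```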

In [8]:
# `num_rollout_steps` refers to the number of autoregressive steps performed by
# the neural emulator
tx.configuration.DivertedChainBranchOne(num_rollout_steps=2)

DivertedChainBranchOne(
  num_rollout_steps=2,
  time_level_loss=MSELoss(batch_reduction=<function mean>),
  cut_bptt=False,
  cut_bptt_every=1,
  cut_div_chain=False,
  time_level_weights=f32[2]
)

In [9]:
# Alternatively, the general interface can be used
tx.configuration.DivertedChain(num_rollout_steps=2, num_branch_steps=1)

DivertedChain(
  num_rollout_steps=2,
  num_branch_steps=1,
  time_level_loss=MSELoss(batch_reduction=<function mean>),
  cut_bptt=False,
  cut_bptt_every=1,
  cut_div_chain=False,
  time_level_weights=f32[2],
  branch_level_weights=f32[1]
)

### Three-steps with branch length one

<p align="center">
    <img src="https://fkoehler.site/predictor-learning-setups/div-3-1-true-primal.svg" width="600">
</p>

In [10]:
tx.configuration.DivertedChainBranchOne(num_rollout_steps=3)

DivertedChainBranchOne(
  num_rollout_steps=3,
  time_level_loss=MSELoss(batch_reduction=<function mean>),
  cut_bptt=False,
  cut_bptt_every=1,
  cut_div_chain=False,
  time_level_weights=f32[3]
)

### Four-steps with branch length one

<p align="center">
    <img src="https://fkoehler.site/predictor-learning-setups/div-4-1-true-primal.svg" width="700">
</p>

In [11]:
tx.configuration.DivertedChainBranchOne(num_rollout_steps=4)

DivertedChainBranchOne(
  num_rollout_steps=4,
  time_level_loss=MSELoss(batch_reduction=<function mean>),
  cut_bptt=False,
  cut_bptt_every=1,
  cut_div_chain=False,
  time_level_weights=f32[4]
)

### Three-steps with branch length two

<p align="center">
    <img src="https://fkoehler.site/predictor-learning-setups/div-3-2-true-primal.svg" width="600">
</p>

In [12]:
# Can only be done with the general interface
tx.configuration.DivertedChain(num_rollout_steps=3, num_branch_steps=2)

DivertedChain(
  num_rollout_steps=3,
  num_branch_steps=2,
  time_level_loss=MSELoss(batch_reduction=<function mean>),
  cut_bptt=False,
  cut_bptt_every=1,
  cut_div_chain=False,
  time_level_weights=f32[3],
  branch_level_weights=f32[2]
)

### Two-Steps with no differentiable physics

<p align="center">
    <img src="https://fkoehler.site/predictor-learning-setups/div-2-1-true-no_dp.svg" width="500">
</p>

In [13]:
tx.configuration.DivertedChainBranchOne(
    num_rollout_steps=2,
    cut_div_chain=True,
)

DivertedChainBranchOne(
  num_rollout_steps=2,
  time_level_loss=MSELoss(batch_reduction=<function mean>),
  cut_bptt=False,
  cut_bptt_every=1,
  cut_div_chain=True,
  time_level_weights=f32[2]
)

### Two-Steps with no backpropagation through time

<p align="center">
    <img src="https://fkoehler.site/predictor-learning-setups/div-2-1-true-no_net_bptt.svg" width="500">
</p>

In [14]:
tx.configuration.DivertedChainBranchOne(
    num_rollout_steps=2,
    cut_bptt=True,
)

DivertedChainBranchOne(
  num_rollout_steps=2,
  time_level_loss=MSELoss(batch_reduction=<function mean>),
  cut_bptt=True,
  cut_bptt_every=1,
  cut_div_chain=False,
  time_level_weights=f32[2]
)

### Two-Steps with no backpropagation through time and no differentiable physics

<p align="center">
    <img src="https://fkoehler.site/predictor-learning-setups/div-2-1-true-no_dp-no_net_bptt.svg" width="500">
</p>

In [15]:
tx.configuration.DivertedChainBranchOne(
    num_rollout_steps=2,
    cut_bptt=True,
    cut_div_chain=True,
)

DivertedChainBranchOne(
  num_rollout_steps=2,
  time_level_loss=MSELoss(batch_reduction=<function mean>),
  cut_bptt=True,
  cut_bptt_every=1,
  cut_div_chain=True,
  time_level_weights=f32[2]
)

## Mix-Chain

So far, `Trainax` only supports "post-physics" mixing, meaning that the main
chain is built by first performing a specified number of autoregressive network
steps, and then a specified number of `ref_stepper` steps.

The reference trajectory is always built by autoregressively unrolling the
`ref_stepper`.
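A minimal sketch of the post-physics mix chain as described above: the main chain applies the emulator for `num_rollout_steps` steps and then the `ref_stepper` for `num_post_physics_steps` steps, while the reference chain always unrolls the `ref_stepper`. One weight per time level of the whole main chain is assumed, consistent with the `time_level_weights` shapes printed in this section (`f32[2]` for two main-chain steps, `f32[3]` for three). The toy components are hypothetical.

```python
import jax.numpy as jnp


def ref_stepper(u):
    # Hypothetical reference solver: explicit Euler for du/dt = -u, dt = 0.1
    return u - 0.1 * u


def emulator(gain, u):
    # Hypothetical scalar-gain "network"
    return gain * u


def mix_chain_post_physics_loss(
    gain, u0, num_rollout_steps, num_post_physics_steps, time_level_weights
):
    total_steps = num_rollout_steps + num_post_physics_steps
    assert time_level_weights.shape == (total_steps,)
    pred, ref = u0, u0
    loss = 0.0
    for t in range(total_steps):
        # Main chain: network steps first, then ref_stepper ("post-physics") steps
        pred = emulator(gain, pred) if t < num_rollout_steps else ref_stepper(pred)
        # Reference chain: always the ref_stepper, unrolled autoregressively
        ref = ref_stepper(ref)
        loss = loss + time_level_weights[t] * jnp.mean((pred - ref) ** 2)
    return loss


u0 = jnp.ones(4)
print(mix_chain_post_physics_loss(1.0, u0, 1, 1, jnp.ones(2)))            # ≈ 0.01 + 0.0081 = 0.0181
print(mix_chain_post_physics_loss(1.0, u0, 1, 1, jnp.array([0.0, 1.0])))  # ≈ 0.0081 (final state only)
```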

### One-Step Network with one step physics

<p align="center">
    <img src="https://fkoehler.site/predictor-learning-setups/mix-1-1-true-primal.svg" width="500">
</p>


In [16]:
tx.configuration.MixChainPostPhysics(
    num_rollout_steps=1,
    num_post_physics_steps=1,
)

MixChainPostPhysics(
  num_rollout_steps=1,
  time_level_loss=MSELoss(batch_reduction=<function mean>),
  num_post_physics_steps=1,
  cut_bptt=False,
  cut_bptt_every=1,
  time_level_weights=f32[2]
)

### One-Step Network with one step physics and loss only at final state

<p align="center">
    <img src="https://fkoehler.site/predictor-learning-setups/mix-1-1-false-primal.svg" width="500">
</p>

As in the supervised setting, this is achieved by choosing appropriate
`time_level_weights`. For `MixChainPostPhysics`, the `time_level_weights` refer
to the entire main chain, i.e., the trajectory created by the network steps
followed by the physics steps; hence, there are
`num_rollout_steps + num_post_physics_steps` of them.

In [17]:
tx.configuration.MixChainPostPhysics(
    num_rollout_steps=1,
    num_post_physics_steps=1,
    time_level_weights=jnp.array([0.0, 1.0]),
)

MixChainPostPhysics(
  num_rollout_steps=1,
  time_level_loss=MSELoss(batch_reduction=<function mean>),
  num_post_physics_steps=1,
  cut_bptt=False,
  cut_bptt_every=1,
  time_level_weights=f32[2]
)

### Two-Step Network with one step physics

<p align="center">
    <img src="https://fkoehler.site/predictor-learning-setups/mix-2-1-true-primal.svg" width="600">
</p>

In [18]:
tx.configuration.MixChainPostPhysics(
    num_rollout_steps=1,
    num_post_physics_steps=2,
)

MixChainPostPhysics(
  num_rollout_steps=1,
  time_level_loss=MSELoss(batch_reduction=<function mean>),
  num_post_physics_steps=2,
  cut_bptt=False,
  cut_bptt_every=1,
  time_level_weights=f32[3]
)

### Two-Step Network with one step physics and no backpropagation through time

<p align="center">
    <img src="https://fkoehler.site/predictor-learning-setups/mix-2-1-true-no_net_bptt.svg" width="600">
</p>

In [19]:
tx.configuration.MixChainPostPhysics(
    num_rollout_steps=2,
    num_post_physics_steps=1,
    cut_bptt=True,
)

MixChainPostPhysics(
  num_rollout_steps=2,
  time_level_loss=MSELoss(batch_reduction=<function mean>),
  num_post_physics_steps=1,
  cut_bptt=True,
  cut_bptt_every=1,
  time_level_weights=f32[3]
)

## Residuum

Instead of having a `ref_stepper` that can be unrolled autoregressively, these
configurations rely on a `residuum_fn` that defines a condition based on two
consecutive time levels.
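A minimal sketch of a residuum-based loss. The `residuum_fn` here (the residual of a backward Euler step of $du/dt = -u$) is a hypothetical example, as is its argument order `(u_next, u_prev)`. The loss drives the residual between each pair of consecutive emulator states toward zero, so no reference trajectory is needed.

```python
import jax.numpy as jnp

DT = 0.1  # hypothetical time step size


def residuum_fn(u_next, u_prev):
    # Hypothetical residuum of backward Euler for du/dt = -u:
    # zero exactly when u_next == u_prev / (1 + DT)
    return (u_next - u_prev) / DT + u_next


def emulator(gain, u):
    # Hypothetical scalar-gain "network"
    return gain * u


def residuum_loss(gain, u0, num_rollout_steps):
    prev = u0
    loss = 0.0
    for _ in range(num_rollout_steps):
        nxt = emulator(gain, prev)
        loss = loss + jnp.mean(residuum_fn(nxt, prev) ** 2)
        prev = nxt
    return loss


u0 = jnp.ones(4)
print(residuum_loss(1.0, u0, 3))               # residual of 1 at each step: 3.0
print(residuum_loss(1.0 / (1.0 + DT), u0, 3))  # ≈ 0 (exact backward Euler gain)
```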

### One-Step Residuum

<p align="center">
    <img src="https://fkoehler.site/predictor-learning-setups/res-1-none-false-primal.svg" width="350">
</p>

In [20]:
tx.configuration.Residuum(
    num_rollout_steps=1,
)

Residuum(
  num_rollout_steps=1,
  time_level_loss=MSELoss(batch_reduction=<function mean>),
  cut_bptt=False,
  cut_bptt_every=1,
  cut_prev=False,
  cut_next=False,
  time_level_weights=f32[1]
)

### Two Steps Residuum Training

<p align="center">
    <img src="https://fkoehler.site/predictor-learning-setups/res-2-none-true-primal.svg" width="450">
</p>

In [21]:
tx.configuration.Residuum(
    num_rollout_steps=2,
)

Residuum(
  num_rollout_steps=2,
  time_level_loss=MSELoss(batch_reduction=<function mean>),
  cut_bptt=False,
  cut_bptt_every=1,
  cut_prev=False,
  cut_next=False,
  time_level_weights=f32[2]
)

### Three Steps Residuum Training

<p align="center">
    <img src="https://fkoehler.site/predictor-learning-setups/res-3-none-true-primal.svg" width="550">
</p>

In [22]:
tx.configuration.Residuum(
    num_rollout_steps=3,
)

Residuum(
  num_rollout_steps=3,
  time_level_loss=MSELoss(batch_reduction=<function mean>),
  cut_bptt=False,
  cut_bptt_every=1,
  cut_prev=False,
  cut_next=False,
  time_level_weights=f32[3]
)

### Three Steps Residuum with no backpropagation through time

<p align="center">
    <img src="https://fkoehler.site/predictor-learning-setups/res-3-none-true-no_net_bptt.svg" width="550">
</p>

In [23]:
tx.configuration.Residuum(num_rollout_steps=3, cut_bptt=True)

Residuum(
  num_rollout_steps=3,
  time_level_loss=MSELoss(batch_reduction=<function mean>),
  cut_bptt=True,
  cut_bptt_every=1,
  cut_prev=False,
  cut_next=False,
  time_level_weights=f32[3]
)

### Other Residuum Options

It is also possible to cut the gradient contributions of the `prev` and `next`
arguments to the `residuum_fn` via the `cut_prev` and `cut_next` options.
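A sketch of these options, assuming they are realized by applying `jax.lax.stop_gradient` only to the corresponding argument of the `residuum_fn` while the emulator chain itself stays differentiable (an assumption about the mechanism). The `residuum_fn`, its argument order, and the scalar-gain emulator are hypothetical. For two steps with `gain = 1`, the three variants yield different gradients.

```python
import jax
import jax.numpy as jnp

DT = 0.1  # hypothetical time step size


def residuum_fn(u_next, u_prev):
    # Hypothetical residuum of backward Euler for du/dt = -u
    return (u_next - u_prev) / DT + u_next


def residuum_loss(gain, u0, num_rollout_steps, cut_prev=False, cut_next=False):
    prev = u0
    loss = 0.0
    for _ in range(num_rollout_steps):
        nxt = gain * prev  # hypothetical scalar-gain emulator
        # Only the inputs to the residuum are detached, not the chain itself
        r_prev = jax.lax.stop_gradient(prev) if cut_prev else prev
        r_next = jax.lax.stop_gradient(nxt) if cut_next else nxt
        loss = loss + jnp.mean(residuum_fn(r_next, r_prev) ** 2)
        prev = nxt
    return loss


u0 = jnp.ones(4)
grad_fn = jax.grad(residuum_loss)
print(grad_fn(1.0, u0, 2))                 # ≈ 46
print(grad_fn(1.0, u0, 2, cut_prev=True))  # ≈ 66
print(grad_fn(1.0, u0, 2, cut_next=True))  # ≈ -20
```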

## Teacher Forcing

Teacher forcing resets the main chain with states from the autoregressive
reference chain. It is essentially the opposite of diverted-chain learning.

It resembles selecting minibatches over the entire trajectories. However, this
setup guarantees that a single gradient update considers multiple consecutive
time levels without the network having to roll out autoregressively over all of
them.
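A minimal sketch of teacher forcing, assuming the main chain is reset to the reference state whenever the step index is a multiple of a hypothetical `reset_every` parameter. The scalar-gain emulator and the reference data are toy examples, and this is not a `Trainax` implementation.

```python
import jax.numpy as jnp


def emulator(gain, u):
    # Hypothetical scalar-gain "network"
    return gain * u


def teacher_forcing_loss(gain, trajectory, num_rollout_steps, reset_every):
    """`trajectory[t]` is the reference state at time level t (index 0 = initial condition)."""
    pred = trajectory[0]
    loss = 0.0
    for t in range(num_rollout_steps):
        if t % reset_every == 0:
            # Reset the main chain with the state from the reference chain
            pred = trajectory[t]
        pred = emulator(gain, pred)
        loss = loss + jnp.mean((pred - trajectory[t + 1]) ** 2)
    return loss


# Reference data decaying by a factor of 0.9 per step
trajectory = 0.9 ** jnp.arange(5.0)[:, None] * jnp.ones((5, 3))
print(teacher_forcing_loss(1.0, trajectory, 4, reset_every=1))  # ≈ 0.0299754
print(teacher_forcing_loss(1.0, trajectory, 4, reset_every=2))  # ≈ 0.0763462
```

With `reset_every=1`, every network call starts from a reference state, so no autoregressive network rollout happens at all; larger values allow short rollouts between resets.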

### Three Steps teacher forcing with reset every step

<p align="center">
    <img src="https://fkoehler.site/predictor-learning-setups/tf-3-1-true-primal.svg" width="550">
</p>

In [24]:
# TODO Implementation

### Four Steps teacher forcing with reset every second step

<p align="center">
    <img src="https://fkoehler.site/predictor-learning-setups/tf-4-2-true-primal.svg" width="750">
</p>

In [25]:
# TODO implementation

### Four Steps teacher forcing with reset every second step and no backpropagation through time

<p align="center">
    <img src="https://fkoehler.site/predictor-learning-setups/tf-4-2-true-no_net_bptt.svg" width="750">
</p>

In [26]:
# TODO implementation

## How about correction learning?

All of the setups above can also be used for correction learning, i.e., when the
emulator is not a pure network but contains a (coarse, differentiable) solver
component. For example, in the case of sequential correction:

<p align="center">
    <img src="https://fkoehler.site/corrector-configurations/sequential-corrector-primal.svg" width="350">
</p>

See [this website](https://fkoehler.site/corrector-configurations/) for potential
corrector layouts and options to cut gradients within them.

None of these layouts are provided by `Trainax`; they are only shown to
illustrate that the configurations of `Trainax` can be used in a more general
context.
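For illustration, a minimal sketch of a sequential corrector, assuming that the coarse solver is applied first and the network then corrects its output. `coarse_solver` and `correction_net` are hypothetical. Because the composition is itself just a stepper `u -> u_next`, it could stand in for the emulator in any of the rollouts above; since the coarse solver is written in JAX, gradients flow through it.

```python
import jax
import jax.numpy as jnp


def coarse_solver(u):
    # Hypothetical cheap, differentiable solver (explicit Euler for du/dt = -u, dt = 0.05)
    return u - 0.05 * u


def correction_net(params, u):
    # Hypothetical learnable correction (a scalar gain)
    return params["a"] * u


def sequential_corrector(params, u):
    # Coarse solver first, then the network corrects its output
    return correction_net(params, coarse_solver(u))


u = jnp.ones(3)
print(sequential_corrector({"a": 2.0}, u))  # ≈ [1.9 1.9 1.9]
grads = jax.grad(lambda p: sequential_corrector(p, u).sum())({"a": 2.0})
print(grads["a"])  # ≈ 3 * 0.95 = 2.85
```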