# Correction Learning and Differentiable Physics

In this notebook, a neural emulator $f_\theta$ is trained to mimic the simulator
for the 1d advection equation $\mathcal{P}$. However, its receptive field is not
sufficient to propagate the state at the given CFL number or difficulty
$\gamma_1$. To assist the neural emulator, a *corrected stepper* is built that
contains a defective numerical scheme $\tilde{\mathcal{P}}$ and the network's
task is to correct its output. Here, this defective scheme is only aware of
"half the difficulty" $\tilde{\gamma}_1 = \gamma_1/2$. However, if the defective
scheme already takes care of half the difficulty, the neural network only needs
half the receptive field to correct it (in the sequential setup).

We will train the corrected stepper both in a one-step mode which does not
require the defective/coarse physics to be differentiable, and in a supervised
rollout training requiring differentiable physics.

In [None]:
import jax
import jax.numpy as jnp
import equinox as eqx
import matplotlib.pyplot as plt
from typing import Callable
import optax
from tqdm.autonotebook import tqdm

In [None]:
import exponax as ex

In [None]:
from IPython.display import HTML

We will work in 1d and choose a relatively coarse discretization. The problem
should be more or less agnostic to the spatial resolution as long as all modes of
the initial conditions are resolved properly.

In [None]:
# Number of grid points of the 1d domain.
NUM_POINTS = 48
# Difficulty gamma_1 (the "CFL number or difficulty" from the introduction);
# the steppers below are built with difficulty -GAMMA_1 and -GAMMA_1 / 2.
GAMMA_1 = 6.5

# Training data: number of trajectories, PRNG seed for their initial
# conditions, and number of steps each trajectory is rolled out for.
NUM_TRAIN_SAMPLES = 40
TRAIN_DATA_SEED = 773
TRAIN_TEMPORAL_HORIZON = 50

# Test data: a different seed than the training data and a longer horizon
# (200 vs. 50 steps) than was seen during training.
NUM_TEST_SAMPLES = 30
TEST_DATA_SEED = 774
TEST_TEMPORAL_HORIZON = 200

In [None]:
# Reference ("fine") simulator at the full difficulty and the defective
# ("coarse") scheme that only accounts for half of it. Both are built with the
# same constructor and `order=1`; only the (negated) difficulty differs.
fine_stepper, coarse_stepper = (
    ex.normalized.DiffultyLinearStepperSimple(
        1, NUM_POINTS, difficulty=-gamma, order=1
    )
    for gamma in (GAMMA_1, GAMMA_1 / 2)
)

The distribution of initial conditions is a truncated Fourier series with up to
5 modes. We limit its amplitude to 1 to ease plotting.

In [None]:
# Random truncated Fourier series as initial-condition distribution;
# `max_one=True` limits each sample's maximum absolute value to 1.
ic_distribution = ex.ic.RandomTruncatedFourierSeries(1, max_one=True)

Let's create the set of initial conditions out of it

In [None]:
# Draw NUM_TRAIN_SAMPLES initial conditions on NUM_POINTS grid points; the
# fixed seed makes the training set reproducible.
train_ic_set = ex.build_ic_set(
    ic_distribution,
    num_points=NUM_POINTS,
    num_samples=NUM_TRAIN_SAMPLES,
    key=jax.random.PRNGKey(TRAIN_DATA_SEED),
)

Visualizing the initial states, we indeed see that they are combinations of
Fourier modes, each scaled to a maximum absolute value of 1.

In [None]:
# Plot the first three initial states (channel 0 of each) in one figure; the
# trailing `;` suppresses the cell's return-value output.
ex.viz.plot_state_1d(train_ic_set[:3, 0, :]);

Rolling them out produces the train data set.

In [None]:
# Roll out every training initial condition with the reference (fine) stepper.
# `include_init=True` keeps the initial state, giving
# TRAIN_TEMPORAL_HORIZON + 1 states per trajectory.
train_trj_set = jax.vmap(
    ex.rollout(fine_stepper, TRAIN_TEMPORAL_HORIZON, include_init=True)
)(train_ic_set)

In [None]:
# Expected shape: (NUM_TRAIN_SAMPLES, TRAIN_TEMPORAL_HORIZON + 1, 1, NUM_POINTS)
train_trj_set.shape

Let's quickly do the same for the test set to have it at hand

In [None]:
# Same construction as the training set, but with its own seed and size.
test_ic_set = ex.build_ic_set(
    ic_distribution,
    num_points=NUM_POINTS,
    num_samples=NUM_TEST_SAMPLES,
    key=jax.random.PRNGKey(TEST_DATA_SEED),
)

In [None]:
# Reference test trajectories over the (longer) test horizon, initial state
# included.
test_trj_set = jax.vmap(
    ex.rollout(fine_stepper, TEST_TEMPORAL_HORIZON, include_init=True)
)(test_ic_set)

In [None]:
# Expected shape: (NUM_TEST_SAMPLES, TEST_TEMPORAL_HORIZON + 1, 1, NUM_POINTS)
test_trj_set.shape

Let's visualize a couple of the train trajectories. Since we have a high CFL
number and such a low spatial resolution, the spatio-temporal plot looks glitchy.

In [None]:
# Spatio-temporal facet plot of the training trajectories; with
# `facet_over_channels=False` the panels presumably range over trajectories
# rather than channels (confirm in the exponax viz docs).
ex.viz.plot_spatio_temporal_facet(
    train_trj_set, facet_over_channels=False, figsize=(12, 6)
);

An animation more clearly shows that it is just advection happening (it is just
very fast!).

In [None]:
# Animate the first three training trajectories over time.
# `ex.viz.animate_state_1d` takes a (time, channels, points) trajectory, so the
# (samples, time, points) slice has its first two axes swapped: every frame is
# one time step and the three trajectories are drawn as three lines. (Slicing
# `[:, :3]` instead would animate over samples and show only 3 time steps.)
HTML(
    ex.viz.animate_state_1d(
        jnp.swapaxes(train_trj_set[:3, :, 0, :], 0, 1)
    ).to_jshtml()
)

## Training the neural emulator naively as predictor

Let's build a simple convolutional ResNet with periodic padding.

In [None]:
class ResBlockPeriodic1d(eqx.Module):
    """Residual block of two periodic convolutions with kernel size 3.

    Before each convolution the last axis is circularly padded by one cell on
    each side, so the spatial size is preserved and the domain is treated as
    periodic. The activation follows the first convolution and, once more, the
    sum of the second convolution and the skip connection.
    """

    conv_1: eqx.nn.Conv1d
    conv_2: eqx.nn.Conv1d
    activation: Callable

    def __init__(
        self,
        channels: int,
        activation: Callable,
        *,
        key,
    ):
        # One PRNG key per convolution (in order conv_1, conv_2); both keep the
        # number of channels unchanged.
        self.conv_1, self.conv_2 = (
            eqx.nn.Conv1d(channels, channels, kernel_size=3, key=conv_key)
            for conv_key in jax.random.split(key)
        )
        self.activation = activation

    def periodic_padding(
        self,
        x,
    ):
        """Wrap-pad the last axis of a 2d ``x`` by one cell on each side."""
        keep_channels = (0, 0)
        one_cell_each_side = (1, 1)
        return jnp.pad(x, (keep_channels, one_cell_each_side), mode="wrap")

    def __call__(self, x):
        residual = x
        hidden = self.activation(self.conv_1(self.periodic_padding(x)))
        hidden = self.conv_2(self.periodic_padding(hidden))
        return self.activation(hidden + residual)


class ResNetPeriodic1d(eqx.Module):
    """Single-channel periodic 1d ResNet: lifting, residual blocks, projection.

    A kernel-size-1 convolution lifts the single input channel to
    ``hidden_channels``. Then ``num_blocks`` ``ResBlockPeriodic1d`` blocks are
    applied, and a kernel-size-1 convolution projects back to one channel.

    Only the residual blocks widen the receptive field. Each block contains two
    periodic convolutions of kernel size 3, which adds 2 points per side. The
    whole network therefore sees ``2 * num_blocks`` neighbours on each side.
    """

    lifting: eqx.nn.Conv1d
    # Variable-length tuple with one entry per residual block (not a 1-tuple).
    blocks: tuple[ResBlockPeriodic1d, ...]
    projection: eqx.nn.Conv1d

    def __init__(
        self,
        hidden_channels: int,
        num_blocks: int,
        activation: Callable,
        *,
        key,
    ):
        """Build the network.

        Args:
            hidden_channels: Channel width inside the residual blocks.
            num_blocks: Number of residual blocks.
            activation: Activation used inside every residual block.
            key: JAX PRNG key, split into one key per layer.
        """
        lifting_key, *block_keys, projection_key = jax.random.split(key, 2 + num_blocks)
        self.lifting = eqx.nn.Conv1d(1, hidden_channels, kernel_size=1, key=lifting_key)
        self.blocks = tuple(
            ResBlockPeriodic1d(hidden_channels, activation=activation, key=block_key)
            for block_key in block_keys
        )
        self.projection = eqx.nn.Conv1d(
            hidden_channels, 1, kernel_size=1, key=projection_key
        )

    def __call__(self, x):
        x = self.lifting(x)
        for block in self.blocks:
            x = block(x)
        x = self.projection(x)
        return x

In [None]:
def dataloader(
    data,
    *,
    batch_size: int,
    key,
):
    """Yield one shuffled epoch of ``data`` in mini-batches along axis 0.

    Args:
        data: Array whose leading axis indexes samples.
        batch_size: Maximum number of samples per batch. The last batch is
            smaller when ``data.shape[0]`` is not a multiple of it.
        key: JAX PRNG key used to draw the permutation of the samples.

    Yields:
        ``data[idx]`` for consecutive chunks ``idx`` of one random permutation,
        so every sample appears exactly once per call.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    n_samples = data.shape[0]
    permutation = jax.random.permutation(key, n_samples)

    # Integer striding replaces the previous float32 ``jnp.ceil(n / b)`` batch
    # count, which could round wrongly for large sample counts. The slice end
    # is clipped automatically for the last, partial batch.
    for start in range(0, n_samples, batch_size):
        yield data[permutation[start : start + batch_size]]


def cycling_dataloader(
    data,
    *,
    batch_size: int,
    num_steps: int,
    key,
    return_info: bool = False,
):
    """Yield exactly ``num_steps`` mini-batches, reshuffling at every epoch.

    Args:
        data: Array whose leading axis indexes samples.
        batch_size: Maximum number of samples per batch (see ``dataloader``).
        num_steps: Total number of batches to yield across epochs. Zero or
            negative values yield nothing.
        key: JAX PRNG key; a fresh subkey is split off for each epoch.
        return_info: If true, yield ``(batch, epoch_id, batch_id)`` tuples
            instead of bare batches.

    Raises:
        ValueError: If ``data`` contains no samples while ``num_steps`` is
            positive, since no batch could ever be produced.
    """
    if num_steps <= 0:
        return
    if data.shape[0] == 0:
        raise ValueError("cannot draw batches from an empty dataset")

    epoch_id = 0
    total_step_id = 0

    while True:
        # New subkey per epoch -> a new shuffle order in every epoch.
        key, subkey = jax.random.split(key)

        for batch_id, sub_data in enumerate(
            dataloader(data, batch_size=batch_size, key=subkey)
        ):
            if total_step_id >= num_steps:
                return

            if return_info:
                yield sub_data, epoch_id, batch_id
            else:
                yield sub_data

            total_step_id += 1

        epoch_id += 1

Let's train a ResNet with two residual blocks (each containing two convolutions
of kernel size 3) on the train dataset. We will train with one-step supervised
learning, so we first have to substack the data into input-output pairs by
creating windows of size two.

In [None]:
LEARNING_RATE = 3e-4
OPTIMIZER = optax.adam(LEARNING_RATE)
NUM_STEPS = 20_000
SHUFFLE_KEY = jax.random.PRNGKey(99)
BATCH_SIZE = 16


# Cut every training trajectory into windows of two consecutive states, i.e.
# (input, target) pairs for one-step supervised learning.
one_substacked_train_trj_set = jax.vmap(
    ex.stack_sub_trajectories,
    in_axes=(0, None),
)(train_trj_set, 2)
# Merge the two batch axes
one_train_dataset = jnp.concatenate(one_substacked_train_trj_set)

prediction_neural_emulator = ResNetPeriodic1d(
    32, 2, jax.nn.relu, key=jax.random.PRNGKey(0)
)

opt_state = OPTIMIZER.init(eqx.filter(prediction_neural_emulator, eqx.is_array))


def one_step_loss_fn(model, batch):
    """Mean squared error of the model's one-step prediction.

    ``batch[:, 0]`` holds the inputs and ``batch[:, 1]`` the next states.
    """
    x, y = batch[:, 0], batch[:, 1]
    y_hat = jax.vmap(model)(x)
    return jnp.mean((y - y_hat) ** 2)


@eqx.filter_jit
def step_fn(model, state, batch):
    """Apply one optimizer update; return the new model, optimizer state and loss."""
    loss, grads = eqx.filter_value_and_grad(one_step_loss_fn)(model, batch)
    updates, new_opt_state = OPTIMIZER.update(grads, state, model)
    new_model = eqx.apply_updates(model, updates)
    return new_model, new_opt_state, loss


shuffle_key = SHUFFLE_KEY
train_loss_history = []

# Using tqdm as a context manager closes the progress bar when training ends
# instead of leaving it open.
with tqdm(total=NUM_STEPS) as p_meter:
    for batch in cycling_dataloader(
        one_train_dataset, batch_size=BATCH_SIZE, num_steps=NUM_STEPS, key=shuffle_key
    ):
        prediction_neural_emulator, opt_state, loss = step_fn(
            prediction_neural_emulator, opt_state, batch
        )
        train_loss_history.append(loss)
        p_meter.update(1)

In [None]:
# Training loss per update step on a logarithmic scale.
plt.semilogy(train_loss_history)

Let's use the final network state to make a prediction trajectory on all the
test initial states.

In [None]:
# Autoregressive rollout of the trained emulator from every test initial
# condition, initial state included (matching the reference data layout).
prediction_trj = jax.vmap(
    ex.rollout(prediction_neural_emulator, TEST_TEMPORAL_HORIZON, include_init=True)
)(test_ic_set)

And compute the mean nRMSE at every time step of the rollout.

In [None]:
# `in_axes=1` maps the metric over the time axis, giving one value per time
# step (presumably averaged over the test trajectories, as the name suggests).
mean_nRMSE_trj = jax.vmap(ex.metrics.mean_nRMSE, in_axes=1)(
    prediction_trj, test_trj_set
)

We see that the rollout already diverges after two time steps. This is caused by
an insufficient receptive field.

In [None]:
# Mean nRMSE per time step. The x values are derived from the metric itself,
# so the plot keeps working if TEST_TEMPORAL_HORIZON is changed.
plt.plot(jnp.arange(len(mean_nRMSE_trj)), mean_nRMSE_trj)
plt.ylim(-0.05, 1.05)
plt.xlabel("Time Step")
plt.ylabel("Mean nRMSE")

Just as a baseline, let us train a ResNet with more blocks (four instead of two).

In [None]:
LEARNING_RATE = 3e-4
OPTIMIZER = optax.adam(LEARNING_RATE)
NUM_STEPS = 20_000
SHUFFLE_KEY = jax.random.PRNGKey(99)
BATCH_SIZE = 16


# Same (input, next state) pairs as for the first predictor.
one_substacked_train_trj_set = jax.vmap(
    ex.stack_sub_trajectories,
    in_axes=(0, None),
)(train_trj_set, 2)
# Merge the two batch axes
one_train_dataset = jnp.concatenate(one_substacked_train_trj_set)

# Below we increased the number of blocks from 2 to 4
prediction_neural_emulator_more_reception = ResNetPeriodic1d(
    32, 4, jax.nn.relu, key=jax.random.PRNGKey(0)
)

opt_state = OPTIMIZER.init(
    eqx.filter(prediction_neural_emulator_more_reception, eqx.is_array)
)


def one_step_loss_fn(model, batch):
    """Mean squared error of the model's one-step prediction.

    ``batch[:, 0]`` holds the inputs and ``batch[:, 1]`` the next states.
    """
    x, y = batch[:, 0], batch[:, 1]
    y_hat = jax.vmap(model)(x)
    return jnp.mean((y - y_hat) ** 2)


@eqx.filter_jit
def step_fn(model, state, batch):
    """Apply one optimizer update; return the new model, optimizer state and loss."""
    loss, grads = eqx.filter_value_and_grad(one_step_loss_fn)(model, batch)
    updates, new_opt_state = OPTIMIZER.update(grads, state, model)
    new_model = eqx.apply_updates(model, updates)
    return new_model, new_opt_state, loss


shuffle_key = SHUFFLE_KEY
train_loss_history = []

# Using tqdm as a context manager closes the progress bar when training ends
# instead of leaving it open.
with tqdm(total=NUM_STEPS) as p_meter:
    for batch in cycling_dataloader(
        one_train_dataset, batch_size=BATCH_SIZE, num_steps=NUM_STEPS, key=shuffle_key
    ):
        prediction_neural_emulator_more_reception, opt_state, loss = step_fn(
            prediction_neural_emulator_more_reception, opt_state, batch
        )
        train_loss_history.append(loss)
        p_meter.update(1)

Notice the distinctly different loss curve!

In [None]:
# Training loss of the four-block predictor on a logarithmic scale.
plt.semilogy(train_loss_history)

Let's again create the rollout and compute the mean_nRMSE

In [None]:
# Autoregressive test rollout of the four-block predictor (initial state
# included).
prediction_trj = jax.vmap(
    ex.rollout(
        prediction_neural_emulator_more_reception,
        TEST_TEMPORAL_HORIZON,
        include_init=True,
    )
)(test_ic_set)

In [None]:
# Per-time-step mean nRMSE (metric mapped over the time axis, `in_axes=1`).
mean_nRMSE_trj = jax.vmap(ex.metrics.mean_nRMSE, in_axes=1)(
    prediction_trj, test_trj_set
)

At first glance, the rollout looks equally bad over the full horizon of 200 time steps.

In [None]:
# Mean nRMSE per time step over the full test horizon; x values follow the
# metric's length instead of a hard-coded horizon.
plt.plot(jnp.arange(len(mean_nRMSE_trj)), mean_nRMSE_trj)
plt.ylim(-0.05, 1.05)
plt.xlabel("Time Step")
plt.ylabel("Mean nRMSE")

Zooming in a bit, we see that the new predictor does not immediately explode.
Its performance is still not good, but at least better than before.

Feel free to play around with the number of blocks and other parameters!

In [None]:
# Same curve, zoomed into the first 25 steps; x values follow the metric's
# length instead of a hard-coded horizon.
plt.plot(jnp.arange(len(mean_nRMSE_trj)), mean_nRMSE_trj)
plt.ylim(-0.05, 1.05)
plt.xlim(0, 25)
plt.xlabel("Time Step")
plt.ylabel("Mean nRMSE")

## Correction Learning

Let's start by creating an `Equinox` wrapper that enables us to consider the
sequential corrector as one deep learning module.

In [None]:
class SequentialCorrector(eqx.Module):
    """Corrected stepper: a frozen coarse predictor followed by a neural corrector.

    Calling the module applies ``coarse_predictor`` to the state and then
    ``neural_corrector`` to that coarse prediction. Gradients with respect to
    the coarse predictor's own arrays are stopped, so they are zero during
    training. Gradients still flow through the coarse predictor with respect
    to its input, which is what rollout training relies on.
    """

    coarse_predictor: eqx.Module
    neural_corrector: eqx.Module

    def __call__(
        self,
        x,
    ):
        # Stop gradients on the array leaves of the coarse predictor only.
        # Passing the whole module to `jax.lax.stop_gradient` would also map
        # over its non-array leaves: Python floats would be turned into
        # arrays, and leaves without a dtype (e.g. the activation function
        # stored in a neural coarse predictor) would make it fail.
        coarse_predictor_detached = jax.tree_util.tree_map(
            lambda leaf: jax.lax.stop_gradient(leaf) if eqx.is_array(leaf) else leaf,
            self.coarse_predictor,
        )

        coarse_prediction = coarse_predictor_detached(x)
        corrected_prediction = self.neural_corrector(coarse_prediction)

        return corrected_prediction

Let's again use the ResNet with two blocks, now as a corrector network, and train
the composite module just as before.

In [None]:
LEARNING_RATE = 3e-4
OPTIMIZER = optax.adam(LEARNING_RATE)
NUM_STEPS = 20_000
SHUFFLE_KEY = jax.random.PRNGKey(99)
BATCH_SIZE = 16


# Same (input, next state) pairs as for the pure predictors.
one_substacked_train_trj_set = jax.vmap(
    ex.stack_sub_trajectories,
    in_axes=(0, None),
)(train_trj_set, 2)
# Merge the two batch axes
one_train_dataset = jnp.concatenate(one_substacked_train_trj_set)

# Again using only two blocks
correcter_network = ResNetPeriodic1d(32, 2, jax.nn.relu, key=jax.random.PRNGKey(0))
corrected_stepper = SequentialCorrector(coarse_stepper, correcter_network)

opt_state = OPTIMIZER.init(eqx.filter(corrected_stepper, eqx.is_array))


def one_step_loss_fn(model, batch):
    """Mean squared error of the model's one-step prediction.

    ``batch[:, 0]`` holds the inputs and ``batch[:, 1]`` the next states.
    """
    x, y = batch[:, 0], batch[:, 1]
    y_hat = jax.vmap(model)(x)
    return jnp.mean((y - y_hat) ** 2)


@eqx.filter_jit
def step_fn(model, state, batch):
    """Apply one optimizer update; return the new model, optimizer state and loss."""
    loss, grads = eqx.filter_value_and_grad(one_step_loss_fn)(model, batch)
    updates, new_opt_state = OPTIMIZER.update(grads, state, model)
    new_model = eqx.apply_updates(model, updates)
    return new_model, new_opt_state, loss


shuffle_key = SHUFFLE_KEY
train_loss_history = []

# Using tqdm as a context manager closes the progress bar when training ends
# instead of leaving it open.
with tqdm(total=NUM_STEPS) as p_meter:
    for batch in cycling_dataloader(
        one_train_dataset, batch_size=BATCH_SIZE, num_steps=NUM_STEPS, key=shuffle_key
    ):
        corrected_stepper, opt_state, loss = step_fn(corrected_stepper, opt_state, batch)
        train_loss_history.append(loss)
        p_meter.update(1)

Note that the training loss history now looks similar to the one we got for the
predictor with sufficient receptive field.

In [None]:
# One-step training loss of the corrected stepper on a logarithmic scale.
plt.semilogy(train_loss_history)

In [None]:
# Autoregressive test rollout of the one-step-trained corrected stepper.
prediction_trj = jax.vmap(
    ex.rollout(corrected_stepper, TEST_TEMPORAL_HORIZON, include_init=True)
)(test_ic_set)

In [None]:
# Per-time-step mean nRMSE (metric mapped over the time axis, `in_axes=1`).
mean_nRMSE_trj = jax.vmap(ex.metrics.mean_nRMSE, in_axes=1)(
    prediction_trj, test_trj_set
)

In [None]:
# Mean nRMSE per time step over the full test horizon; x values follow the
# metric's length instead of a hard-coded horizon.
plt.plot(jnp.arange(len(mean_nRMSE_trj)), mean_nRMSE_trj)
plt.ylim(-0.05, 1.05)
plt.xlabel("Time Step")
plt.ylabel("Mean nRMSE")

See that we are now even better than the predictor with more blocks!

In [None]:
# Same curve, zoomed into the first 25 steps; x values follow the metric's
# length instead of a hard-coded horizon.
plt.plot(jnp.arange(len(mean_nRMSE_trj)), mean_nRMSE_trj)
plt.ylim(-0.05, 1.05)
plt.xlabel("Time Step")
plt.ylabel("Mean nRMSE")
plt.xlim(0, 25)

## Rollout training

It turns out that corrected steppers greatly benefit from rollout training! Let's
do a training with three autoregressive supervised steps. This requires setting
**windows of length 4** (the initial state plus three reference states). This also
slightly reduces the number of samples available per epoch. However, we keep the
number of update steps fixed, so we automatically compensate for this.

The training will be slightly longer because of the additional computation per
update step.

In [None]:
LEARNING_RATE = 3e-4
OPTIMIZER = optax.adam(LEARNING_RATE)
NUM_STEPS = 20_000
SHUFFLE_KEY = jax.random.PRNGKey(99)
BATCH_SIZE = 16
# Number of autoregressive steps unrolled per training sample. Each window
# holds the initial state plus one reference state per unrolled step, so the
# window length and the loss loop below both follow from this single value.
NUM_ROLLOUT_STEPS = 3


rollout_substacked_train_trj_set = jax.vmap(
    ex.stack_sub_trajectories,
    in_axes=(0, None),
)(train_trj_set, NUM_ROLLOUT_STEPS + 1)
# Merge the two batch axes
rollout_train_dataset = jnp.concatenate(rollout_substacked_train_trj_set)

# Again using only two blocks
correcter_network = ResNetPeriodic1d(32, 2, jax.nn.relu, key=jax.random.PRNGKey(0))
corrected_stepper_rollout_trained = SequentialCorrector(
    coarse_stepper, correcter_network
)

opt_state = OPTIMIZER.init(eqx.filter(corrected_stepper_rollout_trained, eqx.is_array))


def rollout_loss_fn(model, batch):
    """Sum of per-step MSEs over an autoregressive rollout of ``model``.

    ``batch[:, 0]`` is the initial state and ``batch[:, 1:]`` are the
    reference states. The number of unrolled steps is taken from the window
    length, so it always matches the data. Each prediction is fed back into
    the model, so gradients flow through all earlier steps, including the
    coarse stepper's dependence on its input.
    """
    ic, ref_trj = batch[:, 0], batch[:, 1:]
    pred = ic
    loss = 0.0
    for i in range(ref_trj.shape[1]):
        pred = jax.vmap(model)(pred)
        loss += jnp.mean((ref_trj[:, i] - pred) ** 2)

    return loss


@eqx.filter_jit
def step_fn(model, state, batch):
    """Apply one optimizer update; return the new model, optimizer state and loss."""
    loss, grads = eqx.filter_value_and_grad(rollout_loss_fn)(model, batch)
    updates, new_opt_state = OPTIMIZER.update(grads, state, model)
    new_model = eqx.apply_updates(model, updates)
    return new_model, new_opt_state, loss


shuffle_key = SHUFFLE_KEY
train_loss_history = []

# Using tqdm as a context manager closes the progress bar when training ends
# instead of leaving it open.
with tqdm(total=NUM_STEPS) as p_meter:
    for batch in cycling_dataloader(
        rollout_train_dataset,
        batch_size=BATCH_SIZE,
        num_steps=NUM_STEPS,
        key=shuffle_key,
    ):
        corrected_stepper_rollout_trained, opt_state, loss = step_fn(
            corrected_stepper_rollout_trained, opt_state, batch
        )
        train_loss_history.append(loss)
        p_meter.update(1)

Note that the rollout loss level will be higher than the one-step loss level
(roughly by a factor of three) because we simply added up the losses of all three
time levels.

In [None]:
# Summed multi-step rollout loss on a logarithmic scale; its level is not
# directly comparable to the one-step loss curves above.
plt.semilogy(train_loss_history)

In [None]:
# Autoregressive test rollout of the rollout-trained corrected stepper.
prediction_trj = jax.vmap(
    ex.rollout(
        corrected_stepper_rollout_trained, TEST_TEMPORAL_HORIZON, include_init=True
    )
)(test_ic_set)

In [None]:
# Per-time-step mean nRMSE (metric mapped over the time axis, `in_axes=1`).
mean_nRMSE_trj = jax.vmap(ex.metrics.mean_nRMSE, in_axes=1)(
    prediction_trj, test_trj_set
)

In [None]:
# Mean nRMSE per time step over the full test horizon; x values follow the
# metric's length instead of a hard-coded horizon.
plt.plot(jnp.arange(len(mean_nRMSE_trj)), mean_nRMSE_trj)
plt.ylim(-0.05, 1.05)
plt.xlabel("Time Step")
plt.ylabel("Mean nRMSE")

In [None]:
# Same curve, zoomed into the first 25 steps; x values follow the metric's
# length instead of a hard-coded horizon.
plt.plot(jnp.arange(len(mean_nRMSE_trj)), mean_nRMSE_trj)
plt.ylim(-0.05, 1.05)
plt.xlabel("Time Step")
plt.ylabel("Mean nRMSE")
plt.xlim(0, 25)