In [None]:
%pip install equinox

In [None]:
import jax
import jax.numpy as jnp
import equinox as eqx
import optax
import matplotlib.pyplot as plt
from typing import Callable, List
import scipy
from tqdm.autonotebook import tqdm

In [None]:
# File locations: the .npy dataset that is loaded below, and the target file
# the trained model is serialised to at the end of the notebook.
DATASET_PATH = "/kaggle/input/heat1d-dataset/heat1d_dataset.npy"
MODEL_PATH = "/kaggle/working/heat1d.eqx"

In [None]:
# Load the full dataset array from DATASET_PATH.
# NOTE(review): later cells read data[:, 0], data[:, 1] and data[:, 2], so axis 1
# presumably carries three fields per sample — confirm against the dataset.
data = jnp.load(DATASET_PATH)

In [None]:
# Display the dimensions of the loaded dataset.
data.shape

In [None]:
# Split the three fields stored along axis 1 of `data`:
#   index 0 -> y_ic (plotted below as "Initial condition")
#   index 1 -> a    (plotted below as "Alpha")
#   index 2 -> y_f  (plotted below as "At t=1")
# A singleton axis is inserted at position 1 so each field has a channel axis.
y_ic = data[:, 0, jnp.newaxis, :]
a = data[:, 1, jnp.newaxis, :]
y_f = data[:, 2, jnp.newaxis, :]

In [None]:
# Spatial mesh on [0, 2*pi], with as many points as the last axis of `a`.
# NOTE(review): linspace includes the endpoint 2*pi; if the data was generated
# on a periodic grid that excludes it, this mesh is slightly off — confirm.
mesh = jnp.linspace(0, 2 * jnp.pi, a.shape[-1])

In [None]:
# Plot the three fields of the first sample (sample 0, channel 0) over the mesh.
plt.plot(mesh, y_ic[0, 0], label="Initial condition")
plt.plot(mesh, a[0, 0], label="Alpha")
plt.plot(mesh, y_f[0, 0], label="At t=1")
plt.grid()
plt.legend()

In [None]:
# Assemble the network input by stacking three channels along axis 1:
# the mesh coordinates (one copy per sample), y_ic and a.
mesh_shape_corrected = jnp.repeat(mesh[None, None, :], a.shape[0], axis=0)
input_with_mesh = jnp.concatenate([mesh_shape_corrected, y_ic, a], axis=1)

In [None]:
# Display the dimensions of the assembled network input.
input_with_mesh.shape

In [None]:
# The first 70% of the samples (in stored order, no shuffling) are used for
# training; the remainder is held out as the test set.
TRAIN_SIZE = int(a.shape[0] * 0.7)

train_x = input_with_mesh[:TRAIN_SIZE]
train_y = y_f[:TRAIN_SIZE]
test_x = input_with_mesh[TRAIN_SIZE:]
test_y = y_f[TRAIN_SIZE:]

In [None]:
class SpectralConv1d(eqx.Module):
    """Convolution carried out in Fourier space on the lowest `modes` frequencies.

    The input is transformed with a real FFT, the first `modes` frequency
    coefficients are mixed across channels with learned complex weights, all
    remaining coefficients are set to zero, and the result is transformed back
    to the original number of spatial points.
    """
    # Real and imaginary parts of the complex weights,
    # each of shape (in_channels, out_channels, modes).
    real_weights: jax.Array
    imag_weights: jax.Array
    in_channels: int
    out_channels: int
    modes: int

    def __init__(
            self,
            in_channels,
            out_channels,
            modes,
            *,
            key
    ):
        """Initialise the layer.

        Args:
            in_channels: number of channels of the input.
            out_channels: number of channels of the output.
            modes: number of low-frequency Fourier modes that are kept.
            key: JAX PRNG key used to draw the initial weights.
        """
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes = modes

        # Weights are drawn uniformly from [-scale, scale]; the range shrinks
        # as the product of input and output channel counts grows.
        scale = 1.0 / (in_channels * out_channels)

        real_key, imag_key = jax.random.split(key, 2)
        self.real_weights = jax.random.uniform(real_key, (in_channels, out_channels, modes), minval=-scale, maxval=scale)
        self.imag_weights = jax.random.uniform(imag_key, (in_channels, out_channels, modes), minval=-scale, maxval=scale)

    def complex_mult1d(self, x_hat, w):
        """Contract the input-channel axis: (i, M) x (i, o, M) -> (o, M), per mode."""
        return jnp.einsum("iM,ioM->oM", x_hat, w)

    def __call__(self, x):
        """Apply the spectral convolution to one sample.

        Args:
            x: array of shape (in_channels, spatial_points).

        Returns:
            Array of shape (out_channels, spatial_points).
        """
        _, spatial_points = x.shape

        x_hat = jnp.fft.rfft(x) #(in_channels, spatial_points//2 + 1) since a real-valued transform

        # The real FFT only provides spatial_points//2 + 1 frequencies. When the
        # layer was built with more modes than that (e.g. it is evaluated on a
        # coarser grid), use only the modes that exist instead of failing with
        # a shape mismatch in the einsum below.
        n_modes = min(self.modes, x_hat.shape[-1])

        x_hat_under_modes = x_hat[:, :n_modes] #(in_channels, n_modes)
        weights = (self.real_weights + 1j * self.imag_weights)[:, :, :n_modes] #Complex weights
        out_hat_under_modes = self.complex_mult1d(x_hat_under_modes, weights) #(out_channels, n_modes)

        # Frequencies above the retained modes stay zero.
        out_hat = jnp.zeros((self.out_channels, x_hat.shape[-1]), dtype=x_hat.dtype) #(out_channels, spatial_points//2+1)
        out_hat = out_hat.at[:, :n_modes].set(out_hat_under_modes)

        # Passing n explicitly restores the original length for odd and even sizes alike.
        out = jnp.fft.irfft(out_hat, n=spatial_points)

        return out

In [None]:
class FNOBlock1d(eqx.Module):
    """One FNO layer: spectral convolution plus a kernel-size-1 bypass convolution.

    Both branches receive the same input; their outputs are summed and passed
    through the activation function.
    """
    spectral_conv: SpectralConv1d
    bypass_conv: eqx.nn.Conv1d
    activation: Callable

    def __init__(self, in_channels, out_channels, modes, activation, *, key):
        """Build both branches, each from its own half of the split `key`."""
        keys = jax.random.split(key, 2)
        self.activation = activation
        self.spectral_conv = SpectralConv1d(in_channels, out_channels, modes, key=keys[0])
        self.bypass_conv = eqx.nn.Conv1d(in_channels, out_channels, kernel_size=1, key=keys[1])

    def __call__(self, x):
        """Return activation(spectral_conv(x) + bypass_conv(x))."""
        spectral_branch = self.spectral_conv(x)
        bypass_branch = self.bypass_conv(x)
        return self.activation(spectral_branch + bypass_branch)

In [None]:
class FNO1d(eqx.Module):
    """1D Fourier Neural Operator.

    A kernel-size-1 lifting convolution maps the input channels to `width`
    channels, `n_blocks` FNO blocks operate at that width, and a kernel-size-1
    projection convolution maps back to the output channels.
    """
    lifting: eqx.nn.Conv1d
    fno_blocks: List[FNOBlock1d]
    projection: eqx.nn.Conv1d

    def __init__(self, in_channels, out_channels, modes, width, activation, n_blocks=4, *, key):
        """Create all sub-layers, splitting a fresh subkey off `key` for each one."""
        key, lifting_key = jax.random.split(key)
        self.lifting = eqx.nn.Conv1d(in_channels, width, kernel_size=1, key=lifting_key)

        blocks = []
        for _ in range(n_blocks):
            key, block_key = jax.random.split(key)
            blocks.append(FNOBlock1d(width, width, modes, activation, key=block_key))
        self.fno_blocks = blocks

        _, projection_key = jax.random.split(key)
        self.projection = eqx.nn.Conv1d(width, out_channels, kernel_size=1, key=projection_key)

    def __call__(self, x):
        """Run lifting, then every FNO block in order, then the projection."""
        hidden = self.lifting(x)
        for block in self.fno_blocks:
            hidden = block(hidden)
        return self.projection(hidden)

In [None]:
# Instantiate the model. The three input channels are the ones concatenated
# into `input_with_mesh` (mesh, y_ic, a); the single output channel is
# compared against y_f during training.
fno = FNO1d(
    in_channels=3,
    out_channels=1,
    modes=16,
    width=64,
    activation=jax.nn.silu,
    n_blocks=4,
    key=jax.random.PRNGKey(0)  # fixed seed for the weight initialisation
)

In [None]:
def dataloader(
        key,
        dataset_x,
        dataset_y,
        batch_size
):
    """Yield shuffled mini-batches that cover the dataset exactly once.

    Args:
        key: JAX PRNG key that determines the shuffling order.
        dataset_x: inputs, indexed by sample along axis 0.
        dataset_y: targets, indexed by sample along axis 0 in the same order.
        batch_size: number of samples per batch; the last batch is smaller
            when the sample count is not a multiple of it.

    Yields:
        Tuples (batch_x, batch_y) selected with the same shuffled indices.
    """
    n_samples = dataset_x.shape[0]
    # Exact integer ceiling division: avoids a device round trip and the
    # float32 rounding that a floating-point ceil has for very large counts.
    n_batches = -(-n_samples // batch_size)

    permutation = jax.random.permutation(key, n_samples)

    for batch_id in range(n_batches):
        start = batch_id * batch_size
        end = min(start + batch_size, n_samples)

        batch_indices = permutation[start:end]

        yield dataset_x[batch_indices], dataset_y[batch_indices]

In [None]:
def loss_fn(model, x, y):
    """Mean squared error between the batched model output and the targets.

    `model` acts on a single sample; it is vectorised over axis 0 of `x`.
    """
    batched_model = jax.vmap(model)
    residual = batched_model(x) - y
    return jnp.mean(residual ** 2)

# AdaBelief optimiser with learning rate 1e-3. The optimiser state is built
# from the array-valued leaves of the model only.
optimizer = optax.adabelief(learning_rate=1e-3)
opt_state = optimizer.init(eqx.filter(fno, eqx.is_array))

@eqx.filter_jit
def make_step(model, state, x, y):
    """Perform one optimisation step on a mini-batch.

    Args:
        model: the current model.
        state: the current optimiser state.
        x: batch of inputs.
        y: batch of targets.

    Returns:
        Tuple (new_model, new_state, loss, val_loss). `loss` is the batch loss
        and `val_loss` the loss on the full test set; both are evaluated with
        the model as it was before the update.
    """
    loss, grad = eqx.filter_value_and_grad(loss_fn)(model, x, y)
    val_loss = loss_fn(model, test_x, test_y)
    # Hand the optimiser only the array leaves, matching how its state was
    # initialised. Transforms that read the parameters (e.g. weight decay)
    # cannot handle the model's non-array leaves.
    params = eqx.filter(model, eqx.is_array)
    updates, new_state = optimizer.update(grad, state, params)
    new_model = eqx.apply_updates(model, updates)

    return new_model, new_state, loss, val_loss

# One entry per optimisation step is appended to each history list.
loss_history, val_loss_history = [], []

shuffle_key = jax.random.PRNGKey(10)

for epoch in tqdm(range(1000)):
    # A fresh subkey per epoch gives a different shuffling order each time.
    shuffle_key, subkey = jax.random.split(shuffle_key)
    for batch_x, batch_y in dataloader(subkey, train_x, train_y, batch_size=100):
        fno, opt_state, loss, val_loss = make_step(fno, opt_state, batch_x, batch_y)
        loss_history.append(loss)
        val_loss_history.append(val_loss)

In [None]:
# Write the leaves of the trained model to MODEL_PATH.
eqx.tree_serialise_leaves(MODEL_PATH, fno)

In [None]:
# Training and validation loss per optimisation step on a log-scaled y-axis.
# The training curve was recorded but not shown before; labels and a legend
# make the figure readable on its own.
ax = plt.subplot()
ax.set_yscale('log')
ax.plot(loss_history, label="Training (mini-batch)")
ax.plot(val_loss_history, label="Validation")
ax.set_xlabel("Optimisation step")
ax.set_ylabel("Mean squared error")
ax.legend()
ax.grid()