In [2]:
import jax
import jax.numpy as jnp
import equinox as eqx
from typing import Callable, List

In [29]:
class SpectralConv2d(eqx.Module):
    """2D spectral convolution (Fourier layer) for Fourier Neural Operators.

    The input is transformed with a real 2D FFT over its last two axes. The
    lowest Fourier modes are multiplied by learned complex weights, which mix
    channels mode by mode. All other modes are zeroed, and the result is
    transformed back to the input grid.

    Along x (full complex FFT axis) two blocks of modes are kept, each with its
    own weight bank:
      * the lowest non-negative frequencies (indices ``[:modes_x]``);
      * the lowest negative frequencies (indices ``[-modes_x:]``).
    Keeping only the non-negative block would discard every mode with
    ``kx * ky < 0``, i.e. one diagonal orientation. Along y (``rfft``
    half-spectrum) only non-negative frequencies exist, so ``[:modes_y]``
    covers them.

    Input shape ``(in_channels, n_x, n_y)``; output shape
    ``(out_channels, n_x, n_y)``.
    """

    # Weights for the non-negative x-frequency block, shape (in, out, modes_x, modes_y).
    real_weights: jax.Array
    imag_weights: jax.Array
    # Weights for the negative x-frequency block, same shape.
    real_weights_neg: jax.Array
    imag_weights_neg: jax.Array
    in_channels: int
    out_channels: int
    modes_x: int
    modes_y: int

    def __init__(
            self,
            in_channels,
            out_channels,
            modes_x,
            modes_y,
            *,
            key,
    ):
        """
        Args:
            in_channels: number of input channels.
            out_channels: number of output channels.
            modes_x: number of x-frequencies kept per sign (non-negative and negative).
            modes_y: number of (non-negative) y-frequencies kept.
            key: JAX PRNG key used to initialise all weight banks.
        """
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes_x = modes_x
        self.modes_y = modes_y

        # Init scale shrinks with the number of channel pairs so the summed
        # spectral output starts small.
        scale = 1.0 / (in_channels * out_channels)
        shape = (in_channels, out_channels, modes_x, modes_y)

        real_key, imag_key, real_neg_key, imag_neg_key = jax.random.split(key, 4)
        self.real_weights = jax.random.uniform(real_key, shape, minval=-scale, maxval=+scale)
        self.imag_weights = jax.random.uniform(imag_key, shape, minval=-scale, maxval=+scale)
        self.real_weights_neg = jax.random.uniform(real_neg_key, shape, minval=-scale, maxval=+scale)
        self.imag_weights_neg = jax.random.uniform(imag_neg_key, shape, minval=-scale, maxval=+scale)

    def complex_mult1d(
            self,
            x_hat,
            w,
    ):
        """Mix channels per mode: ``out[o, X, Y] = sum_i x_hat[i, X, Y] * w[i, o, X, Y]``.

        Despite the name, this operates on 2D mode grids; the name is kept so
        existing callers keep working.
        """
        return jnp.einsum("iXY,ioXY->oXY", x_hat, w)

    def __call__(
            self,
            x,
    ):
        """Apply the spectral convolution.

        Args:
            x: array of shape ``(in_channels, n_x, n_y)``.

        Returns:
            Real array of shape ``(out_channels, n_x, n_y)``.
        """
        _, n_x, n_y = x.shape

        # rfft2 transforms the last two (spatial) axes by default.
        x_hat = jnp.fft.rfft2(x)
        n_ky = x_hat.shape[-1]  # n_y // 2 + 1 for rfft

        # Clip to the modes this grid actually resolves, so coarse inputs do
        # not hit shape mismatches. n_x // 2 keeps the non-negative and negative
        # x blocks from overlapping.
        mx = min(self.modes_x, n_x // 2)
        my = min(self.modes_y, n_ky)

        w_pos = (self.real_weights + 1j * self.imag_weights)[:, :, :mx, :my]
        w_neg = (self.real_weights_neg + 1j * self.imag_weights_neg)[:, :, :mx, :my]

        out_hat = jnp.zeros((self.out_channels, n_x, n_ky), dtype=x_hat.dtype)
        # Lowest non-negative x-frequencies.
        out_hat = out_hat.at[:, :mx, :my].set(
            self.complex_mult1d(x_hat[:, :mx, :my], w_pos)
        )
        # Lowest negative x-frequencies. Index with `n_x - mx` rather than
        # `-mx`, so that mx == 0 selects nothing instead of everything.
        neg_start = n_x - mx
        out_hat = out_hat.at[:, neg_start:, :my].set(
            self.complex_mult1d(x_hat[:, neg_start:, :my], w_neg)
        )

        # Passing the original spatial sizes makes odd n_y round-trip correctly.
        return jnp.fft.irfft2(out_hat, s=(n_x, n_y))


In [30]:
class FNOBlock1d(eqx.Module):
    """One FNO layer: ``activation(spectral_conv(x) + bypass_conv(x))``.

    Despite the "1d" in its name, it works on 2D inputs of shape
    ``(in_channels, n_x, n_y)``:
      * the global path is a ``SpectralConv2d``;
      * the local path is an ``eqx.nn.Conv2d`` with a 1x1 kernel, i.e. a
        per-point linear channel mix.
    """

    spectral_conv: SpectralConv2d
    bypass_conv: eqx.nn.Conv2d
    activation: Callable

    def __init__(
            self,
            in_channels,
            out_channels,
            modes_x,
            modes_y,
            activation,
            *,
            key,
    ):
        """
        Args:
            in_channels: channels of the block input.
            out_channels: channels of the block output.
            modes_x: Fourier modes kept along x by the spectral convolution.
            modes_y: Fourier modes kept along y by the spectral convolution.
            activation: nonlinearity applied to the summed paths.
            key: JAX PRNG key, split between the two sub-layers.
        """
        sub_keys = jax.random.split(key)
        self.activation = activation
        self.spectral_conv = SpectralConv2d(
            in_channels, out_channels, modes_x, modes_y, key=sub_keys[0]
        )
        # 1x1 kernel: mixes channels independently at every grid point.
        self.bypass_conv = eqx.nn.Conv2d(
            in_channels, out_channels, kernel_size=1, key=sub_keys[1]
        )

    def __call__(
            self,
            x,
    ):
        """Return ``activation(spectral_conv(x) + bypass_conv(x))``."""
        global_part = self.spectral_conv(x)
        local_part = self.bypass_conv(x)
        return self.activation(global_part + local_part)
    

In [None]:
# testing

In [None]:
class FNO1d(eqx.Module):
    """Fourier Neural Operator assembled from ``FNOBlock1d`` layers.

    Pipeline:
      1. 1x1 lifting conv to ``width`` channels;
      2. ``n_blocks`` FNO blocks;
      3. 1x1 projection conv to ``out_channels``.
    The spatial grid is unchanged throughout.

    Despite the "1d" in its name, the blocks are 2D (``SpectralConv2d`` +
    ``eqx.nn.Conv2d``). Lifting and projection are therefore 2D as well, so
    inputs have shape ``(in_channels, n_x, n_y)`` and outputs
    ``(out_channels, n_x, n_y)``.
    """

    lifting: eqx.nn.Conv2d
    fno_blocks: List[FNOBlock1d]
    projection: eqx.nn.Conv2d

    def __init__(
            self,
            in_channels,
            out_channels,
            modes,
            width,
            activation,
            n_blocks = 4,
            *,
            key,
    ):
        """
        Args:
            in_channels: channels of the input field.
            out_channels: channels of the output field.
            modes: Fourier modes kept in every block. An int uses the same
                count along both axes; a pair ``(modes_x, modes_y)`` sets them
                separately.
            width: hidden channel count used inside the FNO blocks.
            activation: nonlinearity applied at the end of every block.
            n_blocks: number of FNO blocks.
            key: JAX PRNG key.
        """
        if isinstance(modes, (tuple, list)):
            modes_x, modes_y = modes
        else:
            modes_x = modes_y = modes

        # split() returns pairwise-distinct keys, so every block gets its own
        # independent initialisation (blocks are not initialised identically).
        lifting_key, projection_key, *block_keys = jax.random.split(key, n_blocks + 2)

        # Lifting raises the channel dimension but leaves the spatial dims untouched.
        self.lifting = eqx.nn.Conv2d(in_channels, width, 1, key=lifting_key)

        self.fno_blocks = [
            FNOBlock1d(width, width, modes_x, modes_y, activation, key=block_key)
            for block_key in block_keys
        ]

        # Projection is the inverse of lifting: width channels -> out_channels.
        self.projection = eqx.nn.Conv2d(width, out_channels, 1, key=projection_key)

    def __call__(
            self,
            x,
    ):
        """Map ``(in_channels, n_x, n_y)`` to ``(out_channels, n_x, n_y)``."""
        x = self.lifting(x)

        for fno_block in self.fno_blocks:
            x = fno_block(x)

        return self.projection(x)