# 2D FNO Exercise

In [None]:
# %config InlineBackend.figure_format = 'svg'
import random
import torch
import sys
import numpy as np
import math
import matplotlib.pyplot as plt
from timeit import default_timer
from scipy.io import loadmat

sys.path.append('../')
from models import FNN1d, FNN_train, construct_model, compute_1dFourier_bases, compute_2dFourier_bases
torch.set_printoptions(precision=16)


# Load data

In [None]:
data_path = "../mytest/data/piececonst_r421_N1024_smooth1"
data = loadmat(data_path)

In [None]:
data_in = data["coeff"]
data_out = data["sol"]
print("input date size: ", data_in.shape, "output data size: ", data_out.shape)

L = 1 
Np_ref = data_in.shape[1]
grid_1d = np.linspace(0, L, Np_ref)

grid_x, grid_y = np.meshgrid(grid_1d, grid_1d)
fig, axs = plt.subplots(1, 2, figsize=(6,3))
axs[0].pcolormesh(grid_x, grid_y, data_in[0,:,:])
axs[1].pcolormesh(grid_x, grid_y, data_out[0,:,:])

# 2D Fourier Transform

We compute the Fourier transform for the last 2 dimensions of x. We first define the mode set 

\begin{align*}
K_x = \{k_x | k_x = 0,1,...,n_x/2-1, -n_x/2, -n_x/2+1, ... -1 \} \\
K_y = \{k_y | k_y = 0,1,...,n_y/2-1, -n_y/2, -n_y/2+1, ... -1 \} 
\end{align*}

Then the Fourier transform and inverse Fourier transform give the relation between 
$\{\hat{f}[k_x, k_y]: k_x \in K_x, k_y \in K_y \}$ and $\{f[j_x, j_y] : 0 \leq j_x \leq n_x, 0 \leq j_y \leq n_y\}$

\begin{align*}
   f(x,y) &= \frac{1}{n_x n_y} \sum_{k_x \in K_x, k_y \in K_y}  \hat{f}[k_x, k_y]  e^{i k_x \frac{2\pi x}{L_x} + i k_y \frac{2\pi y}{L_y}} 
   \\ 
   f[j_x, j_y] &= \frac{1}{n_x n_y} \sum_{k_x \in K_x, k_y \in K_y}  \hat{f}[k_x, k_y]  e^{i k_x \frac{2\pi j_x \Delta x}{L_x} + i k_y \frac{2\pi j_y \Delta y}{L_y}} 
   \\ 
   \hat{f}[k_x, k_y] &= \frac{n_x n_y}{L_x L_y} \int f(x, y)  e^{-i k_x \frac{2\pi x}{L_x} - i k_y \frac{2\pi y}{L_y}} dx dy\\
                &\approx \frac{n_x n_y}{L_x L_y}  \sum_{j_x = 0}^{n_x - 1}\sum_{j_y = 0}^{n_y - 1} f[j_x, j_y]  e^{-i k_x \frac{2\pi j_x \Delta x }{L_x} - i k_y \frac{2\pi j_y \Delta y }{L_y}} \Delta x \Delta y \qquad \textrm{when f has certain form, this is accurate.}\\
                &= \sum_{j_x = 0}^{n_x - 1}\sum_{j_y = 0}^{n_y - 1} f[j_x, j_y]  e^{-i k_x \frac{2\pi  j_x}{n_x}-i k_y \frac{2\pi  j_y}{n_y}}
\end{align*}

When $f(x)$ is real, we have $\hat{f}[k_x, k_y] = conj(\hat{f}[-k_x, -k_y])$, and hence `rfftn` only need to save the subset of $K$ 

\begin{align*}
&K_x^r = \{k_x | k_x = 0,1,...,n_x/2-1, -n_x/2, -n_x/2+1, ... -1\} \\
&K_y^r = \{k_y | k_y = 0,1,...,n_y/2-1, -n_y/2\} 
\end{align*}


In FNO, when we truncate to first $k_{xf}$ and $k_{yf}$ modes, we keep
\begin{align*}
& \{(k_x, k_y) | 
k_x = 0,1,...,k_{xf}-1,  -k_{xf}+1, ... -1, 
k_y = 0,1,...,k_{yf}-1\} 
\end{align*}
In the original FNO implementation, $k_x = -k_{xf}$ is also included.

\begin{align*}
f(x,y) 
&= \frac{1}{n_x n_y}  
Re\Bigl(
  \sum_{k_x=-k_{xf}+1}^{k_{xf}-1} \sum_{k_y=0}^{k_{yf}-1} \hat{f}[k_x, k_y]  \phi_{k_x, k_y}(x,y) 
+ \sum_{k_x=-k_{xf}+1}^{k_{xf}-1} \sum_{k_y=-k_{yf}-1}^{-1} \hat{f}[k_x, k_y]  \phi_{k_x, k_y}(x,y) 
\Bigr) 
\\
&= \frac{1}{n_x n_y} Re\Bigl(\sum_{k_x=-k_{xf}+1}^{k_{xf}-1} \hat{f}[k_x, 0]  \phi_{k_x, 0}(x)  +  \sum_{k_x=-k_{xf}+1}^{k_{xf}-1}\sum_{k_y=1}^{k_{yf}-1}  (\hat{f}[k_x, k_y]  \phi_{k_x, k_y}(x,y) + \hat{f}[-k_x, -k_y]  \phi_{-k_x, -k_y}(x,y))\Bigr)
\\
&= \frac{1}{n_x n_y} Re\Bigl(\sum_{k_x=-k_{xf}+1}^{k_{xf}-1} \hat{f}[k_x, 0]  \phi_{k_x, 0}(x)  +  2\sum_{k_x=-k_{xf}+1}^{k_{xf}-1}\sum_{k_y=1}^{k_{yf}-1}  \hat{f}[k_x, k_y]  \phi_{k_x, k_y}(x,y) \Bigr)
\end{align*}



In [9]:
# Here is a test, be careful, the torch will generally use float32. 
# To get accurate error estimate, use float64

downsample_ratio = 20
n_train = 2**10
x_train = torch.from_numpy(np.stack((data_in[0:n_train, 0::downsample_ratio, 0::downsample_ratio], 
                           np.tile(grid_x[0::downsample_ratio,0::downsample_ratio], (n_train,1,1)), 
                           np.tile(grid_y[0::downsample_ratio,0::downsample_ratio], (n_train,1,1))), axis=-1))
batchsize = 16
x = x_train[0:batchsize]
# batch, channel, x, y
x = x.permute(0, 3, 1, 2)
x_ft = torch.fft.rfftn(x, dim=[2,3])
print(" The shape of x is ", x.shape, " the shape of x_ft is ", x_ft.shape)

n_b, n_c, n_x, n_y = x.shape
assert(n_x%2 == 0)
n_k = n_y//2

x = x.to(torch.complex128)


# Implementation 1
Kr_x = list(range(0, n_k)) + [-n_k] + list(range(-n_k+1, 0))
Kr_y = list(range(0, n_k)) + [-n_k]
print("len(Kr_x), len(Kr_y) = ", len(Kr_x), len(Kr_y))
x_ft2 = torch.zeros(n_b, n_c, n_x, n_k+1, dtype=torch.complex128)
for j_xk in range(len(Kr_x)):
    for j_yk in range(len(Kr_y)):
        basis = torch.einsum('x,y->xy',
        torch.exp((-Kr_x[j_xk] * 2*np.pi/n_x * 1.0j * torch.linspace(0, n_x-1, n_x, dtype=torch.float64))),
        torch.exp((-Kr_y[j_yk] * 2*np.pi/n_y * 1.0j * torch.linspace(0, n_y-1, n_y, dtype=torch.float64)))
        )
        x_ft2[:, :, j_xk, j_yk] += torch.einsum('bcxy,xy->bc', x, basis)
print("Error between x_ft and x_ft2 is ", torch.norm(x_ft - x_ft2))


# Implementation 2
Kr_x = torch.tensor(list(range(0, n_k)) + [-n_k] + list(range(-n_k+1, 0)), dtype=torch.float64)
Kr_y = torch.tensor(list(range(0, n_k)) + [-n_k], dtype=torch.float64)
bases = torch.einsum('xk,yl->xykl',
                     torch.exp( torch.outer(torch.linspace(0, n_x-1, n_x, dtype=torch.float64), -2*np.pi/n_x * 1.0j * Kr_x) ),
                     torch.exp( torch.outer(torch.linspace(0, n_y-1, n_y, dtype=torch.float64), -2*np.pi/n_y * 1.0j * Kr_y) )  
                    )
x_ft2 = torch.einsum('bcxy,xykl->bckl', x, bases)
print("Error between x_ft and x_ft2 is ", torch.norm(x_ft - x_ft2))

 The shape of x is  torch.Size([16, 3, 22, 22])  the shape of x_ft is  torch.Size([16, 3, 22, 12])
len(Kr_x), len(Kr_y) =  22 12
Error between x_ft and x_ft2 is  tensor(2.3064641066838742e-11, dtype=torch.float64)
Error between x_ft and x_ft2 is  tensor(2.8806366063147641e-11, dtype=torch.float64)


In [None]:
x_ift = torch.fft.irfftn(x_ft, dim=[2,3])
print("Error between x and x_ift is ", torch.norm(x - x_ift))

# For the irfftn, it takes the real part
x_ft_test = torch.clone(x_ft)
print(x_ft_test[0, 0, 2,0], x_ft_test[0,0,-2,0])
x_ft_test[0,0,0,0] += 10.0e10j
x_ft_test[0,0, -2,0] += 10e4j
x_ft_test[0,0,  2,0] += 10e4j
print(x_ft_test.shape)
x_ift = torch.fft.irfftn(x_ft_test, dim=[2,3])
print("Error between x and x_ift is ", torch.norm(x - x_ift))


In [10]:
x_ift = torch.fft.irfftn(x_ft, dim=[2, 3])
print("Error between x and x_ift is ", torch.norm(x - x_ift))

modes = 6
#  Truncate to the first m modes
x_ft_trunc = torch.zeros(n_b, n_c, n_x, n_k + 1, dtype=torch.complex128)
x_ft_trunc[:, :, :modes, :modes] = x_ft[:, :, :modes, :modes]
x_ft_trunc[:, :, -modes + 1 :, :modes] = x_ft[:, :, -modes + 1 :, :modes]

x_ift_trunc = torch.fft.irfftn(x_ft_trunc, dim=[2, 3])

# Implementation 1
Kr_x = torch.tensor(
    list(range(0, n_k)) + [-n_k] + list(range(-n_k + 1, 0)), dtype=torch.float64
)
Kr_y = torch.tensor(list(range(0, n_k)) + [-n_k], dtype=torch.float64)
bases = torch.einsum(
    "xk,yl->xykl",
    torch.exp(
        torch.outer(
            torch.linspace(0, n_x - 1, n_x, dtype=torch.float64),
            2 * np.pi / n_x * 1.0j * Kr_x,
        )
    ),
    torch.exp(
        torch.outer(
            torch.linspace(0, n_y - 1, n_y, dtype=torch.float64),
            2 * np.pi / n_y * 1.0j * Kr_y,
        )
    ),
)
bases[:, :, :, 1:-1] *= 2.0
x_ift2 = torch.real(torch.einsum("bckl,xykl->bcxy", x_ft_trunc, bases)) / (n_x * n_y)
print("Error between x and x_ift2 is ", torch.norm(x_ift_trunc - x_ift2))

# Low Rank Implementation
Kr_x = torch.tensor(
    list(range(0, modes)) + list(range(-modes + 1, 0)), dtype=torch.float64
)
Kr_y = torch.tensor(list(range(0, modes)), dtype=torch.float64)
Kr = torch.tensor(list(range(0, modes)), dtype=torch.float64)
bases = torch.einsum(
    "xk,yl->xykl",
    torch.exp(
        torch.outer(
            torch.linspace(0, n_x - 1, n_x, dtype=torch.float64),
            2 * np.pi / n_x * 1.0j * Kr_x,
        )
    ),
    torch.exp(
        torch.outer(
            torch.linspace(0, n_y - 1, n_y, dtype=torch.float64),
            2 * np.pi / n_y * 1.0j * Kr_y,
        )
    ),
)
bases[:, :, :, 1:] *= 2.0
x_ift2 = torch.real(
    torch.einsum(
        "bckl,xykl->bcxy",
        x_ft_trunc[:, :, list(range(0, modes)) + list(range(-modes + 1, 0)), :modes],
        bases,
    )
) / (n_x * n_y)
print("Error between x and x_ift2 is ", torch.norm(x_ift_trunc - x_ift2))

Error between x and x_ift is  tensor(1.3465297236977153e-13, dtype=torch.float64)
Error between x and x_ift2 is  tensor(3.5653540552682406e-13, dtype=torch.float64)
Error between x and x_ift2 is  tensor(3.5653540552682406e-13, dtype=torch.float64)


# Spectral Transform

We compute the spectral transform for the last dimension of x. 
We first define $n_k$ orthonal bases $\{ \phi_k \}_{k = 0}^{n_k-1}$


Then the spectral transform and inverse spectral transform give the relation between 
$\{\hat{f}[k]: k \in K \}$ and $\{f(x_j) : 0 \leq j \leq n_x\}$

\begin{align*}
   f(x) &= \sum_{k=0}^{n_k-1}  \hat{f}[k]  \phi_k(x) 
   \\ 
   f(x_j) &= \sum_{k_x=0}^{n_k-1}  \hat{f}[k]  \phi_k(x_j)
   \\ 
   \hat{f}[k] &= \int  f(x)  \phi_k(x) dx  \\
                &= \sum_{j = 0}^{n_x}  f(x_j)  \phi_k(x_j) \Delta x_j
\end{align*}
Here $\phi_k(x)$ are orthogonal bases with $$\int \phi_i(x) \overline{\phi_k(x)} dx = \delta_{ik},$$ and at the discrete level
$$\sum_{j = 0}^{n_x} \phi_i(x_j) \overline{\phi_j(x_j)} \Delta x_j = \delta_{ik}$$



We set bases 
\begin{equation}
B = 
\begin{bmatrix}
\phi_0(x_j)  & \phi_1(x_j)  &  \cdots &  \phi_{n_k-1}(x_j)  
\end{bmatrix}
\end{equation}
and weighted bases 
\begin{equation}
B = 
\begin{bmatrix}
\phi_0(x_j) \Delta x_j  & \phi_1(x_j) \Delta x_j &  \cdots &  \phi_{n_k-1}(x_j) \Delta x_j
\end{bmatrix}
\end{equation}

To recover 2D Fourier Transform, we set 
\begin{align*}
\phi_k(x) = \begin{cases}
\frac{a_{k}}{\sqrt{L_xL_y}}\sin(\frac{2\pi k_x x}{L_x} + \frac{2\pi k_y y}{L_y}) & k_y > 0 \cup  k_y=0 \, k_x \geq 0\\ 
\frac{a_{k}}{\sqrt{L_xL_y}}\cos(\frac{2\pi k_x x}{L_x} + \frac{2\pi k_y y}{L_y}) & 
\end{cases}
\end{align*}
with $a_{0,0} = 1$ and $a_{k} = \sqrt{2}$.


In [16]:
torch.manual_seed(0)
np.random.seed(0)

downsample_ratio = 4

L = 1.0
Np_ref = data_in.shape[1]
grid_1d = np.linspace(0, L, Np_ref)
grid_x, grid_y = np.meshgrid(grid_1d, grid_1d)


n_train, n_test = 800, 200
data_in_ds = data_in[0:n_train, 0::downsample_ratio, 0::downsample_ratio]
grid_x_ds = grid_x[0::downsample_ratio, 0::downsample_ratio]
grid_y_ds = grid_y[0::downsample_ratio, 0::downsample_ratio]
data_out_ds = data_out[0:n_train, 0::downsample_ratio, 0::downsample_ratio]


x_train = torch.from_numpy(
    np.stack(
        (
            data_in_ds,
            np.tile(grid_x_ds, (n_train, 1, 1)),
            np.tile(grid_y_ds, (n_train, 1, 1)),
        ),
        axis=-1,
    ).astype(np.float32)
)
y_train = torch.from_numpy(data_out_ds[:, :, :, np.newaxis].astype(np.float32))


x_test = torch.from_numpy(
    np.stack(
        (
            data_in[-n_test:, 0::downsample_ratio, 0::downsample_ratio],
            np.tile(grid_x[0::downsample_ratio, 0::downsample_ratio], (n_test, 1, 1)),
            np.tile(grid_y[0::downsample_ratio, 0::downsample_ratio], (n_test, 1, 1)),
        ),
        axis=-1,
    ).astype(np.float32)
)
y_test = torch.from_numpy(
    data_out[-n_test:, 0::downsample_ratio, 0::downsample_ratio, np.newaxis].astype(
        np.float32
    )
)


n_fno_layers = 3
k_max = 33
d_f = 128
# fourier k_max
modes = [k_max] * n_fno_layers
# channel d_f
layers = [d_f] * (n_fno_layers + 1)
fc_dim = d_f
in_dim = 3
out_dim = 1
act = "gelu"

epochs = 1000
base_lr = 0.001


milestones = [200, 300, 400, 500, 800, 900]
scheduler_gamma = 0.5
pad_ratio = 0.05
scheduler = "MultiStepLR"
weight_decay = 1.0e-4
# batch_size=32
dim = 2

# scheduler = "CosineAnnealingLR"
# weight_decay = 1.0e-4
batch_size = 64

normalization_x = True
normalization_y = True
normalization_dim = []


basis_type = "Galerkin_bases"

if basis_type == "Fast_Fourier_Transform":

    k_max = k_max // 2
    modes = [k_max] * n_fno_layers

    bases = None
    wbases = None
    model_type = "FNO"
else:
    # reshape the data
    x_train = x_train.reshape(x_train.shape[0], -1, x_train.shape[-1])
    x_test = x_test.reshape(x_test.shape[0], -1, x_test.shape[-1])
    y_train = y_train.reshape(y_train.shape[0], -1, y_train.shape[-1])
    y_test = y_test.reshape(y_test.shape[0], -1, y_test.shape[-1])

    model_type = "GalerkinNO"
    Np = (Np_ref + downsample_ratio - 1) // downsample_ratio
    gridx, gridy, fbases, weights = compute_2dFourier_bases(Np, Np, k_max, L, L)

    if basis_type == "Fourier_bases":
        fbases = fbases.reshape(-1, k_max)
        weights = weights.reshape(-1)
        wfbases = fbases * np.tile(weights, (k_max, 1)).T
        bases = [torch.from_numpy(fbases.astype(np.float32))]
        wbases = [torch.from_numpy(wfbases.astype(np.float32))]

    elif basis_type == "Galerkin_bases":
        pca_data = data_out_ds.reshape((data_out_ds.shape[0], -1))
        pca_include_input = False
        pca_include_grid = False
        if pca_include_input:
            pca_data = np.vstack(
                (pca_data, data_in_ds.reshape((data_in_ds.shape[0], -1)))
            )
        if pca_include_grid:
            n_grid = 1
            pca_data = np.vstack((pca_data, np.tile(grid_x_ds, (n_grid, 1))))
            pca_data = np.vstack((pca_data, np.tile(grid_y_ds, (n_grid, 1))))
        print("Start SVD with data shape: ", pca_data.shape)
        U, S, VT = np.linalg.svd(pca_data.T, full_matrices=False)
        print(U.shape, S.shape, VT.shape)
        # the integration of the basis is 1.
        fbases = U[:, 0:k_max] / np.sqrt(L * L / Np**2)
        wfbases = L * L / Np**2 * fbases
        print(fbases.shape,wfbases.shape)
        bases = [torch.from_numpy(fbases.astype(np.float32))]
        wbases = [torch.from_numpy(wfbases.astype(np.float32))]

    else:

        print("Bases construction error")

Start SVD with data shape:  (800, 11236)
(11236, 800) (800,) (800, 800)
(11236, 33) (11236, 33)


In [None]:
config = {
    "model": {
        "model": model_type,
        "dim": dim,
        "modes": modes,
        "fc_dim": fc_dim,
        "layers": layers,
        "in_dim": in_dim,
        "out_dim": out_dim,
        "act": act,
        "pad_ratio": pad_ratio,
    },
    "train": {
        "base_lr": base_lr,
        "weight_decay": weight_decay,
        "epochs": epochs,
        "scheduler": scheduler,
        "milestones": milestones,
        "scheduler_gamma": scheduler_gamma,
        "batch_size": batch_size,
        "normalization_x": normalization_x,
        "normalization_y": normalization_y,
        "normalization_dim": normalization_dim,
    },
}


model = construct_model(config, bases, wbases)

In [None]:
print("Start training ", config["model"]["model"])
print("x_train shape: ", x_train.shape, "y_train shape: ", y_train.shape)
print("x_test shape: ", x_train.shape, "y_test shape: ", y_train.shape)

train_rel_l2_losses, test_rel_l2_losses, test_l2_losses, cost = FNN_train(
    x_train, y_train, x_test, y_test, config, model, save_model_name="models/test"
)

In [None]:
Start training  GalerkinNO
Epoch :  0  Rel. Train L2 Loss :  0.5530278962105513  Rel. Test L2 Loss :  0.3639783691614866  Test L2 Loss :  0.16688358411192894
Epoch :  10  Rel. Train L2 Loss :  0.05540479300543666  Rel. Test L2 Loss :  0.06010106788016856  Test L2 Loss :  0.027525703771971166
Epoch :  20  Rel. Train L2 Loss :  0.02729798946529627  Rel. Test L2 Loss :  0.030181090580299497  Test L2 Loss :  0.014549962885212153
Epoch :  30  Rel. Train L2 Loss :  0.023524480871856213  Rel. Test L2 Loss :  0.02733145747333765  Test L2 Loss :  0.012954483041539788
Epoch :  40  Rel. Train L2 Loss :  0.03390151506755501  Rel. Test L2 Loss :  0.032641293248161674  Test L2 Loss :  0.014019099180586636
Epoch :  50  Rel. Train L2 Loss :  0.028370561194606125  Rel. Test L2 Loss :  0.02992247836664319  Test L2 Loss :  0.013746349257417023
Epoch :  60  Rel. Train L2 Loss :  0.016377027553971857  Rel. Test L2 Loss :  0.02047357289120555  Test L2 Loss :  0.009557288809446618
Epoch :  70  Rel. Train L2 Loss :  0.016247911378741264  Rel. Test L2 Loss :  0.0221169067081064  Test L2 Loss :  0.009881779027637094
Epoch :  80  Rel. Train L2 Loss :  0.018533223308622837  Rel. Test L2 Loss :  0.019773047999478877  Test L2 Loss :  0.009210543415974826
Epoch :  90  Rel. Train L2 Loss :  0.014880945847835392  Rel. Test L2 Loss :  0.0222996415104717  Test L2 Loss :  0.01034952379995957
Epoch :  100  Rel. Train L2 Loss :  0.023593697929754853  Rel. Test L2 Loss :  0.026125550735741854  Test L2 Loss :  0.012204635073430836
Epoch :  110  Rel. Train L2 Loss :  0.0163455773727037  Rel. Test L2 Loss :  0.019539905828423798  Test L2 Loss :  0.009263549611205235
Epoch :  120  Rel. Train L2 Loss :  0.012967587856110185  Rel. Test L2 Loss :  0.020988903241232038  Test L2 Loss :  0.009709200297947973
Epoch :  130  Rel. Train L2 Loss :  0.014892612001858652  Rel. Test L2 Loss :  0.021787052624858916  Test L2 Loss :  0.009926729835569859
Epoch :  140  Rel. Train L2 Loss :  0.022582464618608356  Rel. Test L2 Loss :  0.03254837263375521  Test L2 Loss :  0.014737804944161326
Epoch :  150  Rel. Train L2 Loss :  0.017551998258568347  Rel. Test L2 Loss :  0.01922531088348478  Test L2 Loss :  0.008974029566161335
Epoch :  160  Rel. Train L2 Loss :  0.01702975877560675  Rel. Test L2 Loss :  0.01916457514744252  Test L2 Loss :  0.00896387847024016
Epoch :  170  Rel. Train L2 Loss :  0.01095853850711137  Rel. Test L2 Loss :  0.011898183031007648  Test L2 Loss :  0.005927028658334166
Epoch :  180  Rel. Train L2 Loss :  0.01998165505938232  Rel. Test L2 Loss :  0.026575087918899953  Test L2 Loss :  0.012266059755347669
Epoch :  190  Rel. Train L2 Loss :  0.010873873543459922  Rel. Test L2 Loss :  0.014671170676592737  Test L2 Loss :  0.007119098474504426
Epoch :  200  Rel. Train L2 Loss :  0.01028971589403227  Rel. Test L2 Loss :  0.013454765372443944  Test L2 Loss :  0.006566085590748116
Epoch :  210  Rel. Train L2 Loss :  0.006467006693128496  Rel. Test L2 Loss :  0.011402946896851063  Test L2 Loss :  0.0056951247388496995
Epoch :  220  Rel. Train L2 Loss :  0.007936486625112593  Rel. Test L2 Loss :  0.010351956530939788  Test L2 Loss :  0.005171463039005175
Epoch :  230  Rel. Train L2 Loss :  0.005818177392939106  Rel. Test L2 Loss :  0.009853201336227357  Test L2 Loss :  0.005036661037593149
Epoch :  240  Rel. Train L2 Loss :  0.011104872624855489  Rel. Test L2 Loss :  0.01421331736491993  Test L2 Loss :  0.006690551701467484
Epoch :  250  Rel. Train L2 Loss :  0.007700835034484044  Rel. Test L2 Loss :  0.010917078296188265  Test L2 Loss :  0.005546584783587605
Epoch :  260  Rel. Train L2 Loss :  0.0068678474926855415  Rel. Test L2 Loss :  0.011556127166841179  Test L2 Loss :  0.005775771656772122
Epoch :  270  Rel. Train L2 Loss :  0.006902240013005212  Rel. Test L2 Loss :  0.011474454251583666  Test L2 Loss :  0.005823984072776511
Epoch :  280  Rel. Train L2 Loss :  0.007895537943113595  Rel. Test L2 Loss :  0.011459440807811916  Test L2 Loss :  0.005707353935576975
Epoch :  290  Rel. Train L2 Loss :  0.007066981052048504  Rel. Test L2 Loss :  0.012559539813082665  Test L2 Loss :  0.0060487362497951835
Epoch :  300  Rel. Train L2 Loss :  0.007625468191690743  Rel. Test L2 Loss :  0.009424212155863643  Test L2 Loss :  0.004864282789640129
Epoch :  310  Rel. Train L2 Loss :  0.0034248315932927653  Rel. Test L2 Loss :  0.008168773783836514  Test L2 Loss :  0.004320120162446983
Epoch :  320  Rel. Train L2 Loss :  0.0033389161253580824  Rel. Test L2 Loss :  0.008337368431966752  Test L2 Loss :  0.0043621797958621755
Epoch :  330  Rel. Train L2 Loss :  0.0035498587385518476  Rel. Test L2 Loss :  0.008126365311909467  Test L2 Loss :  0.004305309703340754
Epoch :  340  Rel. Train L2 Loss :  0.0035115051723551005  Rel. Test L2 Loss :  0.008424282394116744  Test L2 Loss :  0.004376467477413826
Epoch :  350  Rel. Train L2 Loss :  0.0035720229643629864  Rel. Test L2 Loss :  0.008311846031574532  Test L2 Loss :  0.004340929983300157
Epoch :  360  Rel. Train L2 Loss :  0.0037024067278252915  Rel. Test L2 Loss :  0.008369034912902862  Test L2 Loss :  0.004366784778540023
Epoch :  370  Rel. Train L2 Loss :  0.003762367312447168  Rel. Test L2 Loss :  0.008014965802431107  Test L2 Loss :  0.00422419115784578
Epoch :  380  Rel. Train L2 Loss :  0.005504093074705452  Rel. Test L2 Loss :  0.011670866864733398  Test L2 Loss :  0.0056989417935255915
Epoch :  390  Rel. Train L2 Loss :  0.0038136375951580703  Rel. Test L2 Loss :  0.008301144553115591  Test L2 Loss :  0.004283903515897691

In [None]:
Fourier bases

Epoch :  0  Rel. Train L2 Loss :  0.5343242827802896  Rel. Test L2 Loss :  0.35652035661041737  Test L2 Loss :  0.16434418968856335
Epoch :  10  Rel. Train L2 Loss :  0.05252824304625392  Rel. Test L2 Loss :  0.08707353239879012  Test L2 Loss :  0.038476263638585806
Epoch :  20  Rel. Train L2 Loss :  0.03379512997344136  Rel. Test L2 Loss :  0.034905026433989406  Test L2 Loss :  0.016287033446133137
Epoch :  30  Rel. Train L2 Loss :  0.03300807299092412  Rel. Test L2 Loss :  0.035803209990262985  Test L2 Loss :  0.0160839143791236
Epoch :  40  Rel. Train L2 Loss :  0.02496281557250768  Rel. Test L2 Loss :  0.031191959278658032  Test L2 Loss :  0.01455920428270474
Epoch :  50  Rel. Train L2 Loss :  0.019304369459860027  Rel. Test L2 Loss :  0.025718713994137943  Test L2 Loss :  0.01162591646425426
Epoch :  60  Rel. Train L2 Loss :  0.019517153152264655  Rel. Test L2 Loss :  0.022306667640805244  Test L2 Loss :  0.010376633668784052
Epoch :  70  Rel. Train L2 Loss :  0.02815706399269402  Rel. Test L2 Loss :  0.024218744249083102  Test L2 Loss :  0.011378672847058624
Epoch :  80  Rel. Train L2 Loss :  0.01887817436363548  Rel. Test L2 Loss :  0.02192652691155672  Test L2 Loss :  0.01027321710716933
Epoch :  90  Rel. Train L2 Loss :  0.01543609204236418  Rel. Test L2 Loss :  0.027948095579631627  Test L2 Loss :  0.012622756243217736
Epoch :  100  Rel. Train L2 Loss :  0.018864598590880632  Rel. Test L2 Loss :  0.022070111241191626  Test L2 Loss :  0.010092632728628814
Epoch :  110  Rel. Train L2 Loss :  0.01430836075451225  Rel. Test L2 Loss :  0.022063812939450145  Test L2 Loss :  0.010145931446459144
Epoch :  120  Rel. Train L2 Loss :  0.025544452480971813  Rel. Test L2 Loss :  0.027506385231390595  Test L2 Loss :  0.012613549712114036

In [None]:
Start training  FNO
x_train shape:  torch.Size([800, 106, 106, 3]) y_train shape:  torch.Size([800, 106, 106, 1])
x_test shape:  torch.Size([800, 106, 106, 3]) y_test shape:  torch.Size([800, 106, 106, 1])
Epoch :  0  Rel. Train L2 Loss :  0.4634079229831696  Rel. Test L2 Loss :  0.2793087589740753  Test L2 Loss :  0.0018948454689234496
Epoch :  10  Rel. Train L2 Loss :  0.036598108410835266  Rel. Test L2 Loss :  0.039496539533138274  Test L2 Loss :  0.0002694653638172895