In [None]:
from __future__ import annotations

import pickle
from pathlib import Path
from timeit import default_timer

import numpy as np
import torch
from torch import nn

In [None]:
from __future__ import annotations

import torch
import torch.nn.functional as F
from torch import nn


class SpectralConv1d(nn.Module):
    """1D Fourier layer: real FFT, per-mode linear transform, inverse FFT.

    Only the lowest ``modes1`` Fourier modes are multiplied by learned complex
    weights; all higher modes are set to zero before transforming back.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        modes1: number of Fourier modes to multiply. At most
            ``floor(N / 2) + 1`` modes exist for a signal of length ``N``;
            if ``modes1`` is larger, only the available modes are used.
    """

    def __init__(self, in_channels, out_channels, modes1):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        # Number of Fourier modes to multiply, at most floor(N/2) + 1
        self.modes1 = modes1

        # Scale the uniform initialisation by 1 / (in_channels * out_channels).
        self.scale = 1 / (in_channels * out_channels)
        self.weights1 = nn.Parameter(
            self.scale
            * torch.rand(in_channels, out_channels, self.modes1, dtype=torch.cfloat)
        )

    # Complex multiplication
    def compl_mul1d(self, input, weights):
        """Contract the input-channel axis independently for every mode.

        (batch, in_channel, x), (in_channel, out_channel, x) -> (batch, out_channel, x)
        """
        return torch.einsum("bix,iox->box", input, weights)

    def forward(self, x):
        """Apply the spectral convolution along the last axis of ``x``.

        ``x`` is indexed as (batch, channel, x); the result has
        ``out_channels`` channels and the same length along the last axis.
        """
        batchsize = x.shape[0]
        # Compute Fourier coefficients up to factor of e^(- something constant)
        x_ft = torch.fft.rfft(x)
        n_freq = x_ft.size(-1)  # == x.size(-1) // 2 + 1

        # A short signal may provide fewer frequencies than `modes1`; clamp so
        # that the slices of x_ft, the weights and out_ft always agree.
        modes = min(self.modes1, n_freq)

        # Multiply relevant Fourier modes; the remaining ones stay zero.
        out_ft = torch.zeros(
            batchsize,
            self.out_channels,
            n_freq,
            device=x.device,
            dtype=torch.cfloat,
        )
        out_ft[:, :, :modes] = self.compl_mul1d(
            x_ft[:, :, :modes], self.weights1[:, :, :modes]
        )

        # Return to physical space
        return torch.fft.irfft(out_ft, n=x.size(-1))


class FNO1d(nn.Module):
    """1D Fourier Neural Operator built from four Fourier layers.

    1. Lift the input to the desired channel dimension with ``self.fc0``.
    2. Apply four layers of the integral operator ``u' = (W + K)(u)``, where
       ``W`` is the pointwise convolution ``self.w*`` and ``K`` is the spectral
       convolution ``self.conv*``. GELU follows every layer except the last.
    3. Project from the channel space to the output space with ``self.fc1``
       and ``self.fc2``.

    Args:
        num_channels: number of physical variables per time step.
        modes: number of Fourier modes kept by each spectral convolution.
        width: channel width of the hidden representation.
        initial_step: number of time steps stacked into the input features.

    ``forward(x, grid)`` concatenates ``x`` and ``grid`` along the last axis,
    so their last dimensions must sum to ``initial_step * num_channels + 1``.
    The output has shape ``(batch, x, 1, num_channels)``.
    """

    def __init__(self, num_channels, modes=16, width=64, initial_step=10):
        super().__init__()

        self.modes1 = modes
        self.width = width
        self.padding = 2  # pad the domain if input is non-periodic
        # Input features: `initial_step` snapshots of every channel plus one
        # extra feature taken from `grid`.
        self.fc0 = nn.Linear(initial_step * num_channels + 1, self.width)

        self.conv0 = SpectralConv1d(self.width, self.width, self.modes1)
        self.conv1 = SpectralConv1d(self.width, self.width, self.modes1)
        self.conv2 = SpectralConv1d(self.width, self.width, self.modes1)
        self.conv3 = SpectralConv1d(self.width, self.width, self.modes1)
        self.w0 = nn.Conv1d(self.width, self.width, 1)
        self.w1 = nn.Conv1d(self.width, self.width, 1)
        self.w2 = nn.Conv1d(self.width, self.width, 1)
        self.w3 = nn.Conv1d(self.width, self.width, 1)

        self.fc1 = nn.Linear(self.width, 128)
        self.fc2 = nn.Linear(128, num_channels)

    def forward(self, x, grid):
        """Predict the next state from stacked past states and the grid.

        Args:
            x: tensor indexed as [b, x1, t*v].
            grid: tensor concatenated to ``x`` along the last axis.

        Returns:
            Tensor of shape ``(batch, x, 1, num_channels)``.
        """
        # x dim = [b, x1, t*v]
        x = torch.cat((x, grid), dim=-1)
        x = self.fc0(x)
        x = x.permute(0, 2, 1)  # channels first for Conv1d / rfft

        # pad the domain if input is non-periodic
        x = F.pad(x, [0, self.padding])

        layers = (
            (self.conv0, self.w0),
            (self.conv1, self.w1),
            (self.conv2, self.w2),
            (self.conv3, self.w3),
        )
        last = len(layers) - 1
        for index, (spectral, pointwise) in enumerate(layers):
            x = spectral(x) + pointwise(x)
            if index != last:  # no activation after the final Fourier layer
                x = F.gelu(x)

        # Remove the padding again. `x[..., :-0]` would be an empty slice, so
        # only crop when padding was actually added.
        if self.padding > 0:
            x = x[..., : -self.padding]
        x = x.permute(0, 2, 1)
        x = self.fc1(x)
        x = F.gelu(x)
        x = self.fc2(x)
        return x.unsqueeze(-2)

In [None]:
from __future__ import annotations

import math as mt
from pathlib import Path

import h5py
import numpy as np
import torch
from torch.utils.data import Dataset


class FNODatasetSingle(Dataset):
    """Dataset backed by a single HDF5 file of PDE trajectories.

    Three file layouts are handled:

    * file name ending in ``"h5"``: one group per sample, each holding a
      ``"data"`` array of layout [nt, nx, ny, nc] and a ``"grid"`` group with
      ``"x"``, ``"y"`` and ``"t"``;
    * other files without a ``"tensor"`` dataset: separate ``"density"``,
      ``"pressure"``, ``"Vx"`` (and ``"Vy"``, ``"Vz"``) datasets of layout
      [batch, time, x, ...];
    * other files with a ``"tensor"`` dataset: a single scalar field, or a 2D
      problem with an additional ``"nu"`` input field.

    After loading, ``self.data`` is a tensor of layout
    [batch, x1, ..., xd, t, v] and ``self.grid`` holds the coordinates.

    :param filename: name of the file that contains the dataset
    :param initial_step: time steps taken as initial condition, defaults to 10
    :param saved_folder: folder that contains ``filename``
    :param reduced_resolution: stride applied to every spatial axis
    :param reduced_resolution_t: stride applied to the time axis
    :param reduced_batch: stride applied to the sample axis
    :param if_test: if True keep the leading test split, else the remainder
    :param test_ratio: fraction of the (capped) samples used as test split
    :param num_samples_max: cap on the number of samples; ``<= 0`` means all
    :raises ValueError: if the stored arrays have an unsupported rank
    """

    # Datasets read from multi-field files, keyed by the rank of the stored
    # arrays (batch, time, x[, y[, z]]). The order fixes the channel index.
    _CFD_FIELDS = {
        3: ("density", "pressure", "Vx"),
        4: ("density", "pressure", "Vx", "Vy"),
        5: ("density", "pressure", "Vx", "Vy", "Vz"),
    }
    _COORD_KEYS = ("x-coordinate", "y-coordinate", "z-coordinate")

    def __init__(
        self,
        filename,
        initial_step=10,
        saved_folder="../data/",
        reduced_resolution=1,
        reduced_resolution_t=1,
        reduced_batch=1,
        if_test=False,
        test_ratio=0.1,
        num_samples_max=-1,
    ):
        # Define path to files
        root_path = Path(Path(saved_folder).resolve()) / filename
        strides = (reduced_batch, reduced_resolution_t, reduced_resolution)

        with h5py.File(root_path, "r") as f:
            if filename[-2:] == "h5":  # SWE-2D (RDB)
                self.data, self.grid, self.tsteps_t = self._load_grouped(f, *strides)
            elif "tensor" in list(f.keys()):  # scalar equations
                self.data, self.grid = self._load_scalar(f, *strides)
            else:
                self.data, self.grid = self._load_cfd(f, *strides)

        if num_samples_max > 0:
            num_samples_max = min(num_samples_max, self.data.shape[0])
        else:
            num_samples_max = self.data.shape[0]

        # The first `test_idx` samples form the test split.
        test_idx = int(num_samples_max * test_ratio)
        if if_test:
            self.data = self.data[:test_idx]
        else:
            self.data = self.data[test_idx:num_samples_max]

        # Time steps used as initial conditions
        self.initial_step = initial_step

        self.data = self.data if torch.is_tensor(self.data) else torch.tensor(self.data)

    @staticmethod
    def _subsample(raw, reduced_batch, reduced_resolution_t, reduced_resolution):
        """Stride a [batch, time, x1..xd] array and move time behind space."""
        spatial = (slice(None, None, reduced_resolution),) * (raw.ndim - 2)
        picked = raw[
            (
                slice(None, None, reduced_batch),
                slice(None, None, reduced_resolution_t),
                *spatial,
            )
        ]
        ## convert to [x1, ..., xd, t]
        return np.moveaxis(picked, 1, -1)

    @classmethod
    def _read_grid(cls, f, n_space, reduced_resolution):
        """Build the strided coordinate grid for ``n_space`` spatial axes."""
        if n_space == 1:
            coords = np.array(f[cls._COORD_KEYS[0]], dtype=np.float32)
            return torch.tensor(
                coords[::reduced_resolution], dtype=torch.float
            ).unsqueeze(-1)
        axes = [
            torch.tensor(np.array(f[key], dtype=np.float32), dtype=torch.float)
            for key in cls._COORD_KEYS[:n_space]
        ]
        mesh = torch.meshgrid(*axes, indexing="ij")
        step = (slice(None, None, reduced_resolution),) * n_space
        return torch.stack(mesh, dim=-1)[step]

    @classmethod
    def _load_cfd(cls, f, reduced_batch, reduced_resolution_t, reduced_resolution):
        """Load the multi-field layout as ([batch, x1..xd, t, v], grid)."""
        rank = len(f["density"].shape)
        names = cls._CFD_FIELDS.get(rank)
        if names is None:
            msg = f"unsupported data rank {rank}; expected 3 (1D), 4 (2D) or 5 (3D)"
            raise ValueError(msg)

        data = None
        for channel, name in enumerate(names):
            field = cls._subsample(
                np.array(f[name], dtype=np.float32),  # batch, time, x,...
                reduced_batch,
                reduced_resolution_t,
                reduced_resolution,
            )
            if data is None:
                # Size the buffer from the strided field itself: `::step`
                # yields ceil(n / step) entries, which floor division
                # under-counts whenever n is not a multiple of step.
                data = np.zeros((*field.shape, len(names)), dtype=np.float32)
            data[..., channel] = field  # batch, x, t, ch
        return data, cls._read_grid(f, rank - 2, reduced_resolution)

    @classmethod
    def _load_scalar(cls, f, reduced_batch, reduced_resolution_t, reduced_resolution):
        """Load the ``"tensor"`` layout as ([batch, x1..xd, t, 1], grid)."""
        raw = np.array(f["tensor"], dtype=np.float32)  # batch, time, x,...
        if raw.ndim == 3:  # 1D
            data = cls._subsample(
                raw, reduced_batch, reduced_resolution_t, reduced_resolution
            )
            return data[:, :, :, None], cls._read_grid(f, 1, reduced_resolution)
        if raw.ndim == 4:  # 2D Darcy flow
            # u: label (the time axis is deliberately not strided here)
            label = cls._subsample(raw, reduced_batch, 1, reduced_resolution)
            # nu: input, given a singleton time axis so it can be prepended
            nu = np.array(f["nu"], dtype=np.float32)
            nu = nu[
                ::reduced_batch,
                None,
                ::reduced_resolution,
                ::reduced_resolution,
            ]
            nu = np.moveaxis(nu, 1, -1)
            data = np.concatenate([nu, label], axis=-1)
            # batch, x, y, t, ch
            return data[:, :, :, :, None], cls._read_grid(f, 2, reduced_resolution)
        msg = f"unsupported data rank {raw.ndim}; expected 3 (1D) or 4 (2D)"
        raise ValueError(msg)

    @staticmethod
    def _load_grouped(f, reduced_batch, reduced_resolution_t, reduced_resolution):
        """Load the one-group-per-sample layout as (data, grid, time steps)."""
        keys = sorted(f.keys())
        samples = [np.array(f[key]["data"], dtype=np.float32) for key in keys]
        data = torch.from_numpy(np.stack(samples, axis=0))  # [batch, nt, nx, ny, nc]
        data = data[
            ::reduced_batch,
            ::reduced_resolution_t,
            ::reduced_resolution,
            ::reduced_resolution,
            ...,
        ]
        data = torch.permute(data, (0, 2, 3, 1, 4))  # [batch, nx, ny, nt, nc]

        # The grid is read from one reference sample. Keep the historical
        # choice "0023" when present, otherwise fall back to the first sample
        # so files with other group names still load.
        reference = "0023" if "0023" in keys else keys[0]
        grid_group = f[reference]["grid"]
        gridx = np.array(grid_group["x"], dtype=np.float32)
        gridy = np.array(grid_group["y"], dtype=np.float32)
        mgridX, mgridY = np.meshgrid(gridx, gridy, indexing="ij")
        grid = torch.stack(
            (torch.from_numpy(mgridX), torch.from_numpy(mgridY)), dim=-1
        )
        grid = grid[::reduced_resolution, ::reduced_resolution, ...]
        tsteps = torch.from_numpy(np.array(grid_group["t"], dtype=np.float32))
        return data, grid, tsteps[::reduced_resolution_t]

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return (first ``initial_step`` time steps, full sample, grid)."""
        return self.data[idx, ..., : self.initial_step, :], self.data[idx], self.grid

In [None]:
from __future__ import annotations

import logging
import math as mt

import matplotlib.pyplot as plt
import numpy as np
import torch
from mpl_toolkits.axes_grid1 import make_axes_locatable
from torch import nn

# Device shared by the functions below: CUDA when available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Module-level logger named after this module.
logger = logging.getLogger(__name__)


def metric_func(
    pred, target, if_mean=True, Lx=1.0, Ly=1.0, Lz=1.0, iLow=4, iHigh=12, initial_step=1
):
    """Compute error metrics between a prediction and its target.

    Both tensors are laid out as (batch, nx^i..., timesteps, nc) with one, two
    or three spatial axes. The first ``initial_step`` time steps are dropped
    before anything is evaluated.

    Metrics: RMSE, normalized RMSE, error of the spatially summed field
    (conserved variables), max error, RMSE at the boundaries, and RMSE in
    Fourier space split into low / middle / high frequency bands.

    Args:
        pred: predicted field.
        target: reference field, same layout as ``pred``.
        if_mean: if True, additionally average every metric over its first
            and last axis.
        Lx, Ly, Lz: scale factors applied to the Fourier-space error.
        iLow, iHigh: mode indices separating the low / middle / high bands.
        initial_step: number of leading time steps to exclude.

    Returns:
        Tuple ``(err_RMSE, err_nRMSE, err_CSV, err_Max, err_BD, err_F)``.
        With ``if_mean=False`` the first five have shape (nc, nt) and
        ``err_F`` has shape (nc, 3, nt).

    Raises:
        ValueError: if the inputs are not 4-, 5- or 6-dimensional.
    """
    pred, target = pred.to(device), target.to(device)
    # (batch, nx^i..., timesteps, nc)
    # slice out `initial context` timesteps
    pred = pred[..., initial_step:, :]
    target = target[..., initial_step:, :]

    rank = target.dim()
    if rank not in (4, 5, 6):
        msg = f"expected 4-, 5- or 6-dimensional input (1D/2D/3D data), got {rank}"
        raise ValueError(msg)

    # Move channels next to the batch axis: (nb, nc, nx^i..., nt)
    channel_first = [0, rank - 1, *range(1, rank - 1)]
    pred = pred.permute(channel_first)
    target = target.permute(channel_first)
    idxs = target.size()
    nb, nc, nt = idxs[0], idxs[1], idxs[-1]

    # Flatten all spatial axes; `reshape` also copes with inputs whose memory
    # layout does not allow a plain `view` after the permutation.
    pred_flat = pred.reshape(nb, nc, -1, nt)
    target_flat = target.reshape(nb, nc, -1, nt)
    diff_flat = pred_flat - target_flat
    n_points = diff_flat.shape[2]  # nx, nx * ny or nx * ny * nz
    diff = diff_flat.reshape(idxs)

    # RMSE
    err_mean = torch.sqrt(torch.mean(diff_flat**2, dim=2))
    err_RMSE = torch.mean(err_mean, dim=0)
    nrm = torch.sqrt(torch.mean(target_flat**2, dim=2))
    err_nRMSE = torch.mean(err_mean / nrm, dim=0)

    # Error of the spatially summed field, normalised by the number of points
    err_CSV = torch.sqrt(
        torch.mean(
            (torch.sum(pred_flat, dim=2) - torch.sum(target_flat, dim=2)) ** 2,
            dim=0,
        )
    )
    err_CSV = err_CSV / n_points

    # worst case in all the data
    err_Max = torch.max(torch.max(torch.abs(diff_flat), dim=2)[0], dim=0)[0]

    # RMSE at the boundaries
    if rank == 4:  # 1D
        err_BD = diff[:, :, 0, :] ** 2 + diff[:, :, -1, :] ** 2
        err_BD = torch.mean(torch.sqrt(err_BD / 2.0), dim=0)
    elif rank == 5:  # 2D
        nx, ny = idxs[2:4]
        err_BD_x = diff[:, :, 0, :, :] ** 2 + diff[:, :, -1, :, :] ** 2
        err_BD_y = diff[:, :, :, 0, :] ** 2 + diff[:, :, :, -1, :] ** 2
        err_BD = (torch.sum(err_BD_x, dim=-2) + torch.sum(err_BD_y, dim=-2)) / (
            2 * nx + 2 * ny
        )
        err_BD = torch.mean(torch.sqrt(err_BD), dim=0)
    else:  # 3D
        nx, ny, nz = idxs[2:5]
        err_BD_x = diff[:, :, 0, :, :] ** 2 + diff[:, :, -1, :, :] ** 2
        err_BD_y = diff[:, :, :, 0, :] ** 2 + diff[:, :, :, -1, :] ** 2
        err_BD_z = diff[:, :, :, :, 0] ** 2 + diff[:, :, :, :, -1] ** 2
        # Keep the channel axis separate while summing over each face and
        # average over the batch afterwards, exactly like the 1D/2D branches,
        # so that err_BD has shape (nc, nt).
        err_BD = (
            torch.sum(err_BD_x.reshape(nb, nc, -1, nt), dim=-2)
            + torch.sum(err_BD_y.reshape(nb, nc, -1, nt), dim=-2)
            + torch.sum(err_BD_z.reshape(nb, nc, -1, nt), dim=-2)
        )
        err_BD = err_BD / (2 * nx * ny + 2 * ny * nz + 2 * nz * nx)
        err_BD = torch.mean(torch.sqrt(err_BD), dim=0)

    # RMSE in Fourier space, as a function of the (radial) mode index
    if rank == 4:  # 1D
        nx = idxs[2]
        pred_F = torch.fft.rfft(pred, dim=2)
        target_F = torch.fft.rfft(target, dim=2)
        _err_F = (
            torch.sqrt(torch.mean(torch.abs(pred_F - target_F) ** 2, dim=0)) / nx * Lx
        )
    elif rank == 5:  # 2D
        pred_F = torch.fft.fftn(pred, dim=[2, 3])
        target_F = torch.fft.fftn(target, dim=[2, 3])
        nx, ny = idxs[2:4]
        n_bins = min(nx // 2, ny // 2)
        _err_F = torch.abs(pred_F - target_F) ** 2
        err_F = torch.zeros([nb, nc, n_bins, nt], device=device)
        # Accumulate the squared error of mode (i, j) into the bin given by
        # the integer part of its radius; modes beyond the last bin are dropped.
        for i in range(nx // 2):
            for j in range(ny // 2):
                it = mt.floor(mt.sqrt(i**2 + j**2))
                if it > n_bins - 1:
                    continue
                err_F[:, :, it] += _err_F[:, :, i, j]
        _err_F = torch.sqrt(torch.mean(err_F, dim=0)) / (nx * ny) * Lx * Ly
    else:  # 3D
        pred_F = torch.fft.fftn(pred, dim=[2, 3, 4])
        target_F = torch.fft.fftn(target, dim=[2, 3, 4])
        nx, ny, nz = idxs[2:5]
        n_bins = min(nx // 2, ny // 2, nz // 2)
        _err_F = torch.abs(pred_F - target_F) ** 2
        err_F = torch.zeros([nb, nc, n_bins, nt], device=device)
        for i in range(nx // 2):
            for j in range(ny // 2):
                for k in range(nz // 2):
                    it = mt.floor(mt.sqrt(i**2 + j**2 + k**2))
                    if it > n_bins - 1:
                        continue
                    err_F[:, :, it] += _err_F[:, :, i, j, k]
        _err_F = torch.sqrt(torch.mean(err_F, dim=0)) / (nx * ny * nz) * Lx * Ly * Lz

    # Collapse the spectrum into three frequency bands
    err_F = torch.zeros([nc, 3, nt], device=device)
    err_F[:, 0] += torch.mean(_err_F[:, :iLow], dim=1)  # low freq
    err_F[:, 1] += torch.mean(_err_F[:, iLow:iHigh], dim=1)  # middle freq
    err_F[:, 2] += torch.mean(_err_F[:, iHigh:], dim=1)  # high freq

    if if_mean:
        return (
            torch.mean(err_RMSE, dim=[0, -1]),
            torch.mean(err_nRMSE, dim=[0, -1]),
            torch.mean(err_CSV, dim=[0, -1]),
            torch.mean(err_Max, dim=[0, -1]),
            torch.mean(err_BD, dim=[0, -1]),
            torch.mean(err_F, dim=[0, -1]),
        )
    return err_RMSE, err_nRMSE, err_CSV, err_Max, err_BD, err_F


def _unet_forward(model, inp):
    """Run a channels-first model on a channels-last input for one step.

    ``inp`` has its feature axis last. It is moved to position 1 before the
    model call, the model output is moved back to features-last, and a
    singleton axis is inserted at position -2 so the result can be
    concatenated along the time axis of the rollout.
    """
    ndim = inp.dim()
    # (batch, *spatial, features) -> (batch, features, *spatial)
    channels_first = inp.permute([0, -1, *range(1, ndim - 1)])
    # (batch, channels, *spatial) -> (batch, *spatial, channels)
    out = model(channels_first).permute([0, *range(2, ndim), 1])
    return out.unsqueeze(-2)


def _save_field_plot(field, clim_source, extent, title, xlabel, ylabel, filename):
    """Render ``field`` with ``imshow`` plus a colorbar and save it to ``filename``.

    The colour limits are taken from ``clim_source`` (the reference data) so
    that prediction and data figures share the same colour scale.
    """
    fig, ax = plt.subplots(figsize=(6.5, 6))
    h = ax.imshow(
        field.detach().cpu(),
        extent=extent,
        origin="lower",
        aspect="auto",
    )
    h.set_clim(clim_source.min().item(), clim_source.max().item())
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    cbar = fig.colorbar(h, cax=cax)
    cbar.ax.tick_params(labelsize=30)
    ax.set_title(title, fontsize=30)
    ax.tick_params(axis="x", labelsize=30)
    ax.tick_params(axis="y", labelsize=30)
    ax.set_ylabel(ylabel, fontsize=30)
    ax.set_xlabel(xlabel, fontsize=30)
    plt.tight_layout()
    plt.savefig(filename)


def metrics(
    val_loader,
    model,
    Lx,
    Ly,
    Lz,
    plot,
    channel_plot,
    model_name,
    x_min,
    x_max,
    y_min,
    y_max,
    t_min,
    t_max,
    mode="FNO",
    initial_step=None,
):
    """Roll the model out autoregressively on ``val_loader`` and report errors.

    For every batch the first ``initial_step`` time steps of the target are
    used as context; the model then predicts one step at a time, feeding its
    own output back in, until the target's time length is reached. The
    rollout is scored with ``metric_func`` and the per-batch results are
    averaged over the number of batches (each batch has equal weight, even if
    the last one is smaller).

    Args:
        val_loader: Iterable yielding ``(xx, yy)`` in ``"Unet"`` mode or
            ``(xx, yy, grid)`` in ``"FNO"`` mode. The second-to-last axis of
            ``xx``/``yy`` is time and the last axis is the channel.
        model: Called as ``model(inp)`` (Unet) or ``model(inp, grid)`` (FNO).
        Lx, Ly, Lz: Forwarded unchanged to ``metric_func``.
        plot: If truthy, save prediction/data figures (1D and 2D data only)
            and the per-time-step error curve.
        channel_plot: Index of the channel to plot.
        model_name: Prefix for all files written when ``plot`` is set
            (``_pred.pdf``, ``_data.pdf``, ``mse_time.npz``).
        x_min, x_max, y_min, y_max, t_min, t_max: Axis extents of the figures.
        mode: ``"FNO"`` or ``"Unet"``; ``"PINN"`` is not implemented.
        initial_step: Number of context time steps. Must be an int; the
            ``None`` default is not usable for the rollout.

    Returns:
        Tuple ``(err_RMSE, err_nRMSE, err_CSV, err_Max, err_BD, err_F)`` of
        NumPy arrays holding the batch-averaged metrics.

    Raises:
        NotImplementedError: If ``mode == "PINN"``.
        ValueError: If ``mode`` is unknown or ``val_loader`` yields no batch.
    """
    if mode == "PINN":
        raise NotImplementedError
    if mode not in ("Unet", "FNO"):
        msg = f"Unsupported mode {mode!r}; expected 'Unet', 'FNO' or 'PINN'."
        raise ValueError(msg)

    err_sums = None
    n_batches = 0
    with torch.no_grad():
        for batch in val_loader:
            if mode == "FNO":
                xx, yy, grid = batch
                grid = grid.to(device)
            else:
                xx, yy = batch
            xx = xx.to(device)
            yy = yy.to(device)

            # Seed the prediction with the ground-truth context window.
            pred = yy[..., :initial_step, :]
            # Flatten (time, channel) of the context into one feature axis.
            inp_shape = [*xx.shape[:-2], -1]

            for _t in range(initial_step, yy.shape[-2]):
                inp = xx.reshape(inp_shape)
                if mode == "FNO":
                    im = model(inp, grid)
                else:
                    im = _unet_forward(model, inp)
                pred = torch.cat((pred, im), -2)
                # Slide the context window: drop the oldest step, append the new one.
                xx = torch.cat((xx[..., 1:, :], im), dim=-2)

            batch_errs = metric_func(
                pred,
                yy,
                if_mean=True,
                Lx=Lx,
                Ly=Ly,
                Lz=Lz,
                initial_step=initial_step,
            )

            if err_sums is None:
                err_sums = list(batch_errs)
                # Keep the first sample of the first batch for plotting.
                pred_plot = pred[:1]
                target_plot = yy[:1]
                val_l2_time = torch.zeros(yy.shape[-2]).to(device)
            else:
                err_sums = [acc + err for acc, err in zip(err_sums, batch_errs)]

            # RMSE per time step, reduced over batch, space and channel.
            # Accumulated for every batch, including the first one.
            mean_dim = (*range(yy.dim() - 2), -1)
            val_l2_time += torch.sqrt(torch.mean((pred - yy) ** 2, dim=mean_dim))
            n_batches += 1

    if n_batches == 0:
        msg = "val_loader yielded no batches; nothing to evaluate."
        raise ValueError(msg)

    # Average over the number of batches (not the last zero-based index).
    err_RMSE, err_nRMSE, err_CSV, err_Max, err_BD, err_F = (
        np.array(err.detach().cpu() / n_batches) for err in err_sums
    )
    logger.info(f"RMSE: {err_RMSE:.5f}")
    logger.info(f"normalized RMSE: {err_nRMSE:.5f}")
    logger.info(f"RMSE of conserved variables: {err_CSV:.5f}")
    logger.info(f"Maximum value of rms error: {err_Max:.5f}")
    logger.info(f"RMSE at boundaries: {err_BD:.5f}")
    logger.info(f"RMSE in Fourier space: {err_F}")

    val_l2_time = val_l2_time / n_batches

    if plot:
        # yy is (batch, *spatial, time, channel) -> number of spatial axes.
        dim = len(yy.shape) - 3
        plt.ioff()
        if dim == 1:
            # Space-time diagram of the selected channel.
            extent = [t_min, t_max, x_min, x_max]
            reference = target_plot[..., channel_plot]
            _save_field_plot(
                pred_plot[..., channel_plot].squeeze(),
                reference,
                extent,
                "Prediction",
                "$t$",
                "$x$",
                model_name + "_pred.pdf",
            )
            _save_field_plot(
                reference.squeeze(),
                reference,
                extent,
                "Data",
                "$t$",
                "$x$",
                model_name + "_data.pdf",
            )
        elif dim == 2:
            # Snapshot of the selected channel at the final time step.
            extent = [x_min, x_max, y_min, y_max]
            reference = target_plot[..., -1, channel_plot]
            _save_field_plot(
                pred_plot[..., -1, channel_plot].squeeze().t(),
                reference,
                extent,
                "Prediction",
                "$x$",
                "$y$",
                model_name + "_pred.pdf",
            )
            _save_field_plot(
                reference.squeeze().t(),
                reference,
                extent,
                "Data",
                "$x$",
                "$y$",
                model_name + "_data.pdf",
            )

        # Per-time-step error over the predicted (non-context) steps.
        filename = model_name + "mse_time.npz"
        np.savez(
            filename,
            t=torch.arange(initial_step, yy.shape[-2]).cpu(),
            mse=val_l2_time[initial_step:].detach().cpu(),
        )

    return err_RMSE, err_nRMSE, err_CSV, err_Max, err_BD, err_F

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# -------------------- Configuration --------------------
# Dataset file name and the folder handed to the dataset class as saved_folder.
filename = "1D_Advection_Sols_beta0.1_reduced.hdf5"
base_path = "/content/drive/MyDrive/"

# Training settings
if_training = True
continue_training = False
num_workers = 0
batch_size = 50
initial_step = 10  # number of ground-truth time steps given to the model as context
t_train = 200  # last time step unrolled during training (clamped to the data length later)
epochs = 500
learning_rate = 1e-3
scheduler_step = 100  # StepLR: multiply the learning rate by scheduler_gamma every N epochs
scheduler_gamma = 0.5
model_update = 1  # run validation / checkpointing every model_update epochs

# FNO model parameters
num_channels = 1
modes = 12
width = 20

# NOTE(review): the four settings below are not consumed by the FNO1d run in
# this notebook; presumably they belong to another architecture - confirm.
N_layers = 4
N_res = 4
N_res_neck = 4
channel_multiplier = 16

# Dataset preprocessing options
single_file = True
reduced_resolution = 1
reduced_resolution_t = 1
reduced_batch = 1

# Plotting and bounds
plot = True
# Index of the channel to plot. It is used to index the last tensor axis, so
# it must be an integer: a bool would insert a new axis instead of selecting
# a channel.
channel_plot = 0
x_min, x_max = 0.0, 1.0
y_min, y_max = 0.0, 1.0
t_min, t_max = 0.0, 1.0

training_type = "autoregressive"

In [None]:
# -------------------- Output naming --------------------
# The run name is the dataset file name without its last five characters
# (the ".hdf5" extension) plus an "_FNO" tag; the checkpoint file is the run
# name with a ".pt" extension.
model_name = f"{filename[:-5]}_FNO"
model_path = f"{model_name}.pt"

In [None]:
# Dataset used for training. `if_test` is left at its default here
# (presumably selecting the training split - confirm against FNODatasetSingle).
train_data = FNODatasetSingle(
    filename,
    saved_folder=base_path,
    initial_step=initial_step,
    reduced_batch=reduced_batch,
    reduced_resolution=reduced_resolution,
    reduced_resolution_t=reduced_resolution_t,
)

In [None]:
# Dataset used for validation: same file and preprocessing options as the
# training dataset, but built with if_test=True (presumably the held-out
# split - confirm against FNODatasetSingle).
val_data = FNODatasetSingle(
    filename,
    if_test=True,
    saved_folder=base_path,
    initial_step=initial_step,
    reduced_batch=reduced_batch,
    reduced_resolution=reduced_resolution,
    reduced_resolution_t=reduced_resolution_t,
)

In [None]:
# Batch iterators: training batches are reshuffled every epoch, validation
# batches keep the dataset order.
train_loader = torch.utils.data.DataLoader(
    train_data,
    shuffle=True,
    batch_size=batch_size,
    num_workers=num_workers,
)
val_loader = torch.utils.data.DataLoader(
    val_data,
    shuffle=False,
    batch_size=batch_size,
    num_workers=num_workers,
)

In [None]:
# Peek at one validation batch and cap the training rollout length at the
# size of the target's second-to-last (time) axis.
_, _data, _ = next(iter(val_loader))
t_train = min(t_train, _data.size(-2))

In [None]:
# Build the 1D FNO from the configured hyper-parameters, then move it to the
# selected device.
model = FNO1d(
    num_channels=num_channels,
    modes=modes,
    width=width,
    initial_step=initial_step,
)
model = model.to(device)

In [None]:
# Adam with a small weight decay; the learning rate is multiplied by
# scheduler_gamma every scheduler_step epochs.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=learning_rate,
    weight_decay=1e-4,
)
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=scheduler_step,
    gamma=scheduler_gamma,
)
loss_fn = nn.MSELoss(reduction="mean")

# Bookkeeping for the training loop: first epoch index and the best
# validation loss seen so far (used to decide when to checkpoint).
start_epoch = 0
loss_val_min = float("inf")

In [None]:
# Autoregressive training. For each batch the model is unrolled from the
# first `initial_step` ground-truth steps up to `t_train`, feeding every
# prediction back in as the newest context step; the per-step MSE is summed
# into `loss` and back-propagated once per batch.
for ep in range(start_epoch, epochs):
    model.train()
    t1 = default_timer()
    train_l2_step = 0
    train_l2_full = 0

    for xx, yy, grid in train_loader:
        xx = xx.to(device)
        yy = yy.to(device)
        grid = grid.to(device)

        loss = 0
        # Start the rollout from the ground-truth context window.
        pred = yy[..., :initial_step, :]
        # Collapse the (time, channel) axes of the context into one feature axis.
        inp_shape = [*xx.shape[:-2], -1]

        for t in range(initial_step, t_train):
            y = yy[..., t : t + 1, :]
            inp = xx.reshape(inp_shape)
            im = model(inp, grid)
            _batch = im.size(0)
            loss += loss_fn(im.reshape(_batch, -1), y.reshape(_batch, -1))
            pred = torch.cat((pred, im), -2)
            # Drop the oldest context step and append the fresh prediction.
            xx = torch.cat((xx[..., 1:, :], im), dim=-2)

        train_l2_step += loss.item()

        # MSE of the whole rollout (context included) against the target.
        _batch = yy.size(0)
        _yy = yy[..., :t_train, :]
        l2_full = loss_fn(pred.reshape(_batch, -1), _yy.reshape(_batch, -1))
        train_l2_full += l2_full.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    if ep % model_update == 0:
        # Validation: same rollout, but over the full time length of the target.
        val_l2_step = 0
        val_l2_full = 0
        model.eval()
        with torch.no_grad():
            for xx, yy, grid in val_loader:
                xx = xx.to(device)
                yy = yy.to(device)
                grid = grid.to(device)

                loss = 0
                pred = yy[..., :initial_step, :]
                inp_shape = [*xx.shape[:-2], -1]

                for t in range(initial_step, yy.shape[-2]):
                    y = yy[..., t : t + 1, :]
                    inp = xx.reshape(inp_shape)
                    im = model(inp, grid)
                    _batch = im.size(0)
                    loss += loss_fn(im.reshape(_batch, -1), y.reshape(_batch, -1))
                    pred = torch.cat((pred, im), -2)
                    xx = torch.cat((xx[..., 1:, :], im), dim=-2)

                val_l2_step += loss.item()

                # Score only the predicted steps inside the training horizon.
                _pred = pred[..., initial_step:t_train, :]
                _yy = yy[..., initial_step:t_train, :]
                val_l2_full += loss_fn(
                    _pred.reshape(_batch, -1),
                    _yy.reshape(_batch, -1),
                ).item()

            # Checkpoint whenever the summed validation loss improves.
            if val_l2_full < loss_val_min:
                loss_val_min = val_l2_full
                torch.save(
                    {
                        "epoch": ep,
                        "model_state_dict": model.state_dict(),
                        "optimizer_state_dict": optimizer.state_dict(),
                        "loss": loss_val_min,
                    },
                    model_path,
                )

    t2 = default_timer()
    scheduler.step()
    # `loss` here is whatever batch loss was computed last in this epoch.
    print(
        "epoch: {0}, loss: {1:.5f}, t2-t1: {2:.5f}, trainL2: {3:.5f}, testL2: {4:.5f}".format(
            ep,
            loss.item(),
            t2 - t1,
            train_l2_full,
            val_l2_full,
        )
    )

print("Training complete.")

epoch: 0, loss: 0.07390, t2-t1: 39.82128, trainL2: 14.83780, testL2: 0.04598
epoch: 1, loss: 0.04386, t2-t1: 37.53585, trainL2: 0.23311, testL2: 0.02793
epoch: 2, loss: 0.03394, t2-t1: 39.87717, trainL2: 0.16751, testL2: 0.02176
epoch: 3, loss: 0.03089, t2-t1: 37.49741, trainL2: 0.13998, testL2: 0.01925
epoch: 4, loss: 0.02565, t2-t1: 37.52052, trainL2: 0.11707, testL2: 0.01595
epoch: 5, loss: 0.02725, t2-t1: 37.50794, trainL2: 0.11484, testL2: 0.01648
epoch: 6, loss: 0.02504, t2-t1: 37.14256, trainL2: 0.11605, testL2: 0.01477
epoch: 7, loss: 0.04539, t2-t1: 37.36807, trainL2: 0.13849, testL2: 0.02435
epoch: 8, loss: 0.02738, t2-t1: 38.31197, trainL2: 0.20923, testL2: 0.01726
epoch: 9, loss: 0.04240, t2-t1: 37.46873, trainL2: 0.08299, testL2: 0.02978
epoch: 10, loss: 0.01879, t2-t1: 37.38861, trainL2: 0.07575, testL2: 0.01060
epoch: 11, loss: 0.02050, t2-t1: 37.58353, trainL2: 0.07073, testL2: 0.01180
epoch: 12, loss: 0.03899, t2-t1: 37.46557, trainL2: 0.12429, testL2: 0.02307
epoch: 1