In [1]:
import os
import time 
import scipy
import numpy as np 
import matplotlib 
import matplotlib.pyplot as plt 
plt.style.use('./utils/tecplot.mplstyle')

import torch 
import torch.nn as nn 
import torch.nn.functional as F
import torch.optim as optim
import torch.utils as utils 
from torch.utils.data import Dataset, DataLoader
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# import neuralop 
# from utilities import *
from utils.neuraloperator import * 

from pathlib import Path
from torch import Tensor
from typing import Any, List, Tuple, Mapping, Optional, Iterable, Union, Dict, Literal

## libraries for CFD
# import cupy as cp
import h5py 
import yaml

from tqdm import tqdm
import wandb

cuda


In [2]:
config={
    'name': '2dHIT_FNO', 
    'model': {
        'name': 'FNO',
        'params': {
            'n_modes': (32, 32), 
            'in_channels': 1, 
            'out_channels': 1, 
            'hidden_channels': 32, 
            'lifting_channel_ratio': 4, 
            'projection_channel_ratio': 4, 

            # 'factorization':'tucker',
            # 'implementation':'factorized',
            # 'rank': 2, 

            'norm': 'instance_norm', # (None, 'instance_norm', 'group_norm', 'ada_in'); 'ada_in'은 에러남 
            # 'fno_skip': 'linear', # ('linear', 'soft-gating', 'identity')
            # 'channel_mlp_skip': 'linear', # ('soft-gating', 'linear', 'identity')
            # 'positional_embedding': None, # (None, 'grid', GridEmbedding2D, GridEmbeddingnD)
            # 'implementation': 'factorized', # ('factorized', 'reconstructed')
            # 'fft_norm': 'forward', # ('forward', 'ortho', 'backward')
            # 'fno_block_precision': 'full', # ('full', 'mixed', 'half')

            # 'SpectralConv_initializer': 'zeros', # ('normal', 'uniform', 'constant', 'ones', 'zeros', 'eye', 'dirac', 'xavier_uniform', 'xavier_normal', 'kaiming_uniform', 'kaiming_normal', 'trunc_normal', 'orthogonal', 'sparse')
            # # 'SpectralConv_initializer_param': 0.1, # (None, ) 
            # 'SpectralConv_non_linearity': 'gelu', # ('gelu', 'relu', 'silu', 'tanh', 'sigmoid', 'leakyrelu') 
        
        },
        'device': device,
    },
    
    'epochs': 50, 
    'optimizer': {
        'loss_fn': 'h1',
        'name': 'Adam',
        'params': {
            'lr': 1e-3,
        },

        'scheduler': {
            'name': 'CosineAnnealingLR', 
        },
    },
    'data': [
       { 'nu': 0.000225, 'leadtime': 0.25, # 'idx_leadtime': 1, 
         'stage': 'fit',
        'params': {
            'base_path': r'D:\RESEARCH\2DIso\Data\nu=0.00025_n=256_fDNS=64/', 
            'dataset_name': f'2dHIT_nu=0.000225_n=256_T=14.5_fDNS=64', 
            
            'n_input': 1, 
            'n_output':1, 
            'batch_size': 64,
            'num_workers': 0, 

            'Ndata_train':500,  
            'Ndata_val':50, 
        },
        'normalization': {
            'normalization_path': r'D:\RESEARCH\2DIso\Data\nu=0.00025_n=256_fDNS=64/' + 'config.yaml',
        }
        },
        {
         'nu': 0.000225, 'leadtime': 0.25, # 'idx_leadtime': 1, 
         'stage': 'valid',
        'params': {
            'base_path': r'D:\RESEARCH\2DIso\Data\nu=0.00025_n=256/', 
            'dataset_name': f'2dHIT_nu=0.000225_n=256_T=14.5', 
            
            'n_input': 1, 
            'n_output':1, 
            'batch_size': 64,
            'num_workers': 0, 

            'Ndata_val':1, 
            # 'load_data': False, 
        },
        'normalization': {
            'normalization_path': r'D:\RESEARCH\2DIso\Data\nu=0.00025_n=256/' + 'config.yaml',
        }
        },
        # {
        #  'nu': 0.000225, 'leadtime': 0.25, # 'idx_leadtime': 1, 
        #  'stage': 'valid',
        # 'params': {
        #     'base_path': r'D:\RESEARCH\2DIso\Data\nu=0.00025_n=256_fDNS=64/', 
        #     'dataset_name': f'2dHIT_nu=0.000225_n=256_T=14.5_fDNS=64', 
        #     'target_base_path': r'D:\RESEARCH\2DIso\Data\nu=0.00025_n=256/', 
        #     'target_dataset_name': f'2dHIT_nu=0.000225_n=256_T=14.5', 
            
        #     'n_input': 1, 
        #     'n_output':1, 
        #     'batch_size': 64,
        #     'num_workers': 0, 

        #     'Ndata_val':1, 
        #     # 'load_data': False, 
        # },
        # 'normalization': {
        #     'normalization_path': r'D:\RESEARCH\2DIso\Data\nu=0.00025_n=256/' + 'config.yaml',
        # }
        # },
       
    ]
    
}

In [3]:
### CFD equation parameters ###
integral_timescales = {
    1e-3: 1.870162606239319,
    5e-4: 1.5808453559875488,
    0.000225: 1.3786518573760986, 
    1e-4: 1.3038111189,
    5e-05: 1.15829598903656, 
}

In [4]:
from utils.CFDFunction import *
from utils.Losses import *
from utils.Plots import *

In [5]:
import os
import numpy as np 
import torch 
import h5py
import glob
import time
from typing import Optional, Sequence

def high2low_box(
    x: torch.Tensor,
    scale_factors: Union[float, Sequence[float]] = 0.5,
    *,
    keepdim: bool = False,
    dim: int = 2,
    fft_norm: str = "backward",
    last_var_axis: bool = False,
) -> torch.Tensor:
    """Fourier‑domain box‑style scaling (down / up sampling).

    The function supports tensors shaped either as
        • (batch, channels, *spatial_dims)
        • (batch, channels, *spatial_dims, num_var)

    In the second case the last axis (``num_var``) is *not* included in the FFT
    and therefore remains unchanged.
    """

    if isinstance(scale_factors, float):
        scale_factors = [scale_factors] * dim
    scale_factors = list(scale_factors)
    if len(scale_factors) != dim:
        raise ValueError("`scale_factors` length must equal `dim`.")

    if x.ndim < dim + 2:
        raise ValueError(
            f"Input tensor must have at least {dim + 2} dimensions (got {x.ndim})."
        )

    if last_var_axis:
        spatial_axes = list(range(-(dim + 1), -1))
    else:
        spatial_axes = list(range(-dim, 0))
    spatial_shape = [x.shape[i] for i in spatial_axes]

    sf_tensor = torch.tensor(scale_factors, dtype=x.dtype, device=x.device)
    if torch.allclose(sf_tensor, torch.ones_like(sf_tensor)):
        return x.clone()

    Fx = torch.fft.fftn(x, dim=spatial_axes, norm=fft_norm)

    if (sf_tensor < 1).any():
        Fx = torch.fft.fftshift(Fx, dim=spatial_axes)

        if keepdim:
            mask = torch.ones_like(Fx, dtype=torch.bool)
            for rel_ax, (n, sf) in enumerate(zip(spatial_shape, scale_factors)):
                trim = int(round((1 - sf) * n / 2))
                if trim == 0:
                    continue
                abs_ax = spatial_axes[rel_ax] % Fx.ndim  # positive index
                low = torch.arange(trim, device=x.device)
                high = torch.arange(n - trim, n, device=x.device)
                mask.index_fill_(abs_ax, low, False)
                mask.index_fill_(abs_ax, high, False)
            Fx = Fx * mask
        else:
            slices = [slice(None)] * Fx.ndim
            for rel_ax, (n, sf) in enumerate(zip(spatial_shape, scale_factors)):
                trim = int(round((1 - sf) * n / 2))
                slices[spatial_axes[rel_ax]] = slice(trim, n - trim)
            Fx = Fx[tuple(slices)]

        Fx = torch.fft.ifftshift(Fx, dim=spatial_axes)

    # ----------------------------- up‑sampling branch -------------------------
    if (sf_tensor > 1).any():
        Fx = torch.fft.fftshift(Fx, dim=spatial_axes)

        # F.pad pads the *last* k dims; build list accordingly
        last_axes = list(range(-len(spatial_axes) - (1 if last_var_axis else 0), 0))
        pads: List[int] = []
        for ax in Reflectiond(last_axes):
            if ax in spatial_axes:
                rel = spatial_axes.index(ax)
                n = spatial_shape[rel]
                sf = scale_factors[rel]
                extra = int(round((sf - 1) * n))
                pads.extend([extra // 2, extra - extra // 2])
            else:  # var axis → no pad
                pads.extend([0, 0])
        Fx = F.pad(Fx, pads, mode="constant", value=0.0)
        Fx = torch.fft.ifftshift(Fx, dim=spatial_axes)

    # --------------------------------------------------- inverse FFT ----------
    x_out = torch.fft.ifftn(Fx, dim=spatial_axes, norm=fft_norm).real

    # ------------------------------ match statistics --------------------------
    def _mom(t: torch.Tensor):
        mean = t.mean(dim=spatial_axes, keepdim=True)
        std = t.std(dim=spatial_axes, unbiased=False, keepdim=True).clamp_min(1e-12)
        return mean, std

    mean_in, std_in = _mom(x)
    mean_out, std_out = _mom(x_out)

    x_out = (x_out - mean_out) / std_out * std_in + mean_in
    return x_out

class HIT2dDataset(Dataset):
    def __init__(self, 
        path: Optional[str] = None,
        base_path: Optional[str] = None,
        dataset_name: Optional[str] = None,
        split_name: Optional[str] = None,

        target_path: Optional[str] = None,
        target_base_path: Optional[str] = None,
        target_dataset_name: Optional[str] = None,
        
        normalization:Optional[callable] = None,
        transform:Optional[callable] = None,

        n_input: int = 1,
        n_output: int = 1,
        n_stride: int = 0,
        max_rollout_steps=100,
        max_n_sim: Optional[int] = None,
        # batch_size:int=32, 
        
        # num_iteration_per_data:int=None, 
        # isDataAugmentation:bool=False,
        load_data:bool=True,
        verbose:bool=True,
        **kwargs, 
        ):
        super().__init__()
        self.path = path
        self.base_path = base_path
        self.dataset_name = dataset_name
        self.split_name = split_name
        if not path: 
            path = os.path.join(self.base_path, self.split_name, self.dataset_name) # f"{self.base_path}/{self.split_name}/{self.dataset_name}*"
        self.path = sorted(glob.glob(f"{path}*"))
        assert self.path, f"Error: Dataset path {path} does not exist."

        self.target_path = self.path
        if target_path is not None or target_base_path is not None:
            self.target_base_path = target_base_path
            self.target_dataset_name = target_dataset_name
            self.target_split_name = split_name
            if not target_path: 
                target_path = os.path.join(self.target_base_path, self.split_name, self.target_dataset_name) # f"{self.base_path}/{self.split_name}/{self.dataset_name}*"
            self.target_path = sorted(glob.glob(f"{target_path}*"))
            assert self.target_path, f"Error: Dataset path {target_path} does not exist."

        self.normalization = normalization
        self.transform = transform

        self.max_rollout_steps = max_rollout_steps
        self.n_input = n_input
        self.n_output = n_output
        self.n_stride = n_stride
        self.max_n_sim = max_n_sim if max_n_sim else np.inf
        
        self.verbose=verbose
        self.kwargs = kwargs

        self._build_metadata()
        self._calc_len()
        self.load_data = load_data
        if load_data:
            self.data = self._load_data(self.path)
            self.target_data = self._load_data(self.target_path) if self.target_path is not None else self.data

    def _calc_len(self):
        self.n_sim = min(self.max_n_sim, sum(self.n_sim_per_file)) # self.n_sim = sum(self.n_sim_per_file) # len(self.data)
        self.n_steps_per_sim = self.Nt
        self.n_windows_per_sim = self.n_steps_per_sim - (self.n_input + self.n_output + self.n_stride) + 1
        self.len = self.n_sim * self.n_windows_per_sim

    def __len__(self):
        return self.len
    
    def __getitem__(self, idx:int):
        data = self._load_one_sample(idx)
        data = self._preprocess_data(data)
        if self.transform:
            data = self.transform(data)
        if self.normalization:
            data = self.normalization(data)      
        data = self._postprocess_data(data)
        return data
    
    def _load_one_sample(self, idx:int):
        isim = idx // self.n_windows_per_sim
        it = idx % self.n_windows_per_sim
        if self.load_data:
            data = self.data[isim, it:it + self.n_input] # (n_in, channels, *datashape)
            target = self.target_data[isim, it + self.n_input + self.n_stride:it + self.n_input + self.n_stride + self.n_output] # (n_out, channels, *datashape)
        else: 
            for i in range(len(self.n_sim_per_file)):
                ifile = i 
                if isim < sum(self.n_sim_per_file[:i+1]): break 
            
            isim = isim - sum(self.n_sim_per_file[:i])
            with h5py.File(self.path[ifile], 'r') as f:
                data = f["fields"]['vorticity'][isim, it:it + self.n_input] # (n_in, channels, *datashape)
                data = torch.from_numpy(data)
                
            with h5py.File(self.target_path[ifile], 'r') as f:
                target = f["fields"]['vorticity'][isim, it + self.n_input + self.n_stride:it + self.n_input + self.n_stride + self.n_output] # (n_out, channels, *datashape)
                target = torch.from_numpy(target)
        return data, target

    def _load_data(self, path=None,):
        if not path: path = self.path
        data = []
        self.n_sim = 0
        if self.verbose: start_time = time.time()
        for p in path:
            with h5py.File(p, 'r') as f:
                _n_sim = f.attrs["n_trajectories"]
                
                end = min(self.max_n_sim - self.n_sim, _n_sim)
                vorticity = f["fields"]['vorticity'][:end]# vorticity = f["fields"]['vorticity'][:]
                if self.verbose: 
                    print(f'Data Loaded from{p}. shape: {vorticity.shape}, memory: {vorticity.nbytes/1e6} (MB)')
            
            data.append(vorticity)
            self.n_sim += len(vorticity)
            if self.max_n_sim <= self.n_sim: break
            
        data = np.concatenate(data, axis=0)
        memory = data.nbytes
        data = torch.from_numpy(data)
        # self.data = self.data.reshape(self.data.size(0) * self.data.size(1), self.data.size()[2:]) # self.data = torch.concat(self.data, dim=0)
        if self.verbose: print(f'Data Loaded. shape: {data.shape}, dtype: {data.dtype}, time: {time.time() - start_time} (sec), memory: {memory/1e6} (MB)')
        self._calc_len()
        return data
        
    def _build_metadata(self):
        
        with h5py.File(self.path[0], 'r') as f:
            self.nu = f['scalars']['nu'][()]

            self.t = t = f["dimensions"]["t"][:]
            self.x = x = f["dimensions"]["x"][:]
            self.y = y = f["dimensions"]["y"][:]
            self.X, self.Y = np.meshgrid(x, y)

            self.Lx = x[-1] - x[0]
            self.Ly = y[-1] - y[0]
            self.dx = x[1] - x[0]
            self.dy = y[1] - y[0]
            self.Nx = len(x)
            self.Ny = len(y)
            self.dt = t[1] - t[0]
            self.Nt = len(t)
            self.T = t[-1]
            self.t0 = t[0]

        self.n_sim_per_file = []
        for path in self.path:
            with h5py.File(path, 'r') as f:
                n_sim = f.attrs["n_trajectories"]
                self.n_sim_per_file.append(n_sim)
                if self.max_n_sim <= sum(self.n_sim_per_file): break
        
        self.metadata = {
            'nu': self.nu,

            'Lx': self.Lx,
            'Ly': self.Ly,
            'dx': self.dx,
            'dy': self.dy,
            'Nx': self.Nx, 
            'Ny': self.Ny,
            'x': self.x,
            'y': self.y,

            'dt': self.dt,
            'Nt': self.Nt,
            'T': self.T,
            't0': self.t0,
            't': self.t,
        }
        if self.verbose:
            print(f"Loaded dataset metadata: {self.metadata.keys()}")
        return self.metadata
    
    def _preprocess_data(self, data):
        data, target = data

        x = [data, target] # x = torch.cat([data, target], dim=0)

        return x
    def _postprocess_data(self, data):
        data, target = data[0], data[1]

        # data = data.flatten(0, 1) # data.reshape(data.shape[0] * data.shape[1], *data.shape[2:]) # (n_in * channels, *datashape)
        # target = target.flatten(0, 1) # target.reshape(target.shape[0] * target.shape[1], *target.shape[2:]) # (n_out * channels, *datashape)

        return data, target

class CustomDataModule:
    def __init__(self, 
                 batch_size:int=32, 
                 Ndata_train:int=500, 
                 Ndata_val:int=50, 
                 Ndata_test:int=100, 
                 num_workers:int=0,
                 transform:Optional[callable]=None,
                 normalization:Optional[callable]=None,
                 *args, **kwargs
                 ):
        self.batch_size = batch_size
        self.Ndata_train = Ndata_train
        self.Ndata_val = Ndata_val
        self.Ndata_test = Ndata_test

        self.num_workers = num_workers

        self.transform = transform
        self.normalization = normalization

        self.args = args
        self.kwargs = kwargs

    
    def setup(self, stage:str=None):
        if stage in ['train', 'fit', None]:
            self.train_dataset = HIT2dDataset(split_name='train', max_n_sim=self.Ndata_train, transform=self.transform, normalization=self.normalization, *self.args, **self.kwargs, )
            print(f'Train dataset: {len(self.train_dataset)}')
            
        if stage in ['valid', 'fit', None]:
            self.val_dataset = HIT2dDataset(split_name='valid', max_n_sim=self.Ndata_val, transform=None, normalization=self.normalization, *self.args, **self.kwargs, )
            print(f'Val dataset: {len(self.val_dataset)}')

        if stage in ['test', None]:
            self.test_dataset = HIT2dDataset(split_name='valid', max_n_sim=self.Ndata_test, transform=None, normalization=self.normalization, *self.args, **self.kwargs, )
            print(f'Test dataset: {len(self.test_dataset)}')
    
    def prepare_data(self):
        pass

    def train_dataloader(self):
        if self.train_dataset is None: self.setup(stage='train')
        return torch.utils.data.DataLoader(self.train_dataset, 
                                           batch_size = self.batch_size, 
                                           shuffle=True, 
                                           num_workers=self.num_workers, 
                                           pin_memory=True,
                                           drop_last=True, 
                                           )
    
    def val_dataloader(self):
        if self.val_dataset is None: self.setup(stage='valid')
        return torch.utils.data.DataLoader(self.val_dataset, 
                                           batch_size = self.batch_size, 
                                           shuffle=False, 
                                           num_workers=self.num_workers,
                                           pin_memory=True,
                                           drop_last=True,
                                           )

    def test_dataloader(self):
        if self.test_dataset is None:self.setup(stage='test')
        test_batch_size = self.batch_size # self.test_dataset.n_windows_per_sim # // 8
        return torch.utils.data.DataLoader(self.test_dataset, 
                                           batch_size = test_batch_size, 
                                           shuffle=False, 
                                           num_workers=self.num_workers, 
                                           pin_memory=True,
                                           drop_last=True,
                                           )


class RandomShift(nn.Module):
    def __init__(self, 
        shifts:Union[float, Sequence[float]]=(0.5, 0.5), 
        dims:Union[int, Sequence[int]]=(-2, -1),
        *args, **kwargs
        ):
        super(RandomShift, self).__init__()
        self.shifts = shifts
        self.dims = dims

    def forward(self, 
        x, 
        shifts:Union[float, Sequence[float]]=None, 
        dims:Union[int, Sequence[int]]=None, 
        ):
        if shifts is None: shifts = self.shifts
        if dims is None: dims = self.dims
        
        return self.RandomShift(x, shifts=self.shifts, dims=self.dims)
    
    def RandomShift(self,
        x, 
        shifts:Union[float, Sequence[float]]=(0.5, 0.5), 
        dims:Union[int, Sequence[int]]=(-2, -1)
        ):
        '''
        Random shift the input tensor along the spatial dimensions. 
        Parameters:
        - x (torch.tensor): input tensor, shape 
        - shifts (float or sequence of float): maximum shift fraction along each dimension. 
            If float, the same shift fraction is applied to all dimensions.
            If sequence of float, the length must be equal to the number of spatial dimensions.
            The shift fraction is relative to the size of the dimension.
        - dims (int or sequence of int): dimensions to apply the shift. 
            If int, the same dimension is applied to all spatial dimensions.
            If sequence of int, the length must be equal to the number of spatial dimensions.
        
        Returns:
        - x (torch.tensor): shifted tensor, shape 
        '''
        if isinstance(shifts, float): shifts = [shifts] * len(dims)
        assert len(shifts) == len(dims), "Length of shifts and dims must be equal to the number of spatial dimensions."

        shifts = [int(np.random.uniform(-s, s) * x.shape[d]) for s, d in zip(shifts, dims)]
        x = torch.roll(x, shifts=shifts, dims=dims)
        return x

class Reflection(nn.Module):
    def __init__(self, 
        p: float=0.5,
        dims:Union[int, Sequence[int]]=(-2, -1),
        *args, **kwargs
        ):
        super().__init__()
        self.p = p
        self.dims = dims

    def forward(self, 
        x, 
        p = None, 
        dims:Union[int, Sequence[int]]=None, 
        ):
        if dims is None: dims = self.dims
        if p is None: p = self.p
        
        return self.Reflection(x, p=p, dims=self.dims)
    
    def Reflection(self,
        x, 
        p: float=0.5, 
        dims:Union[int, Sequence[int]]=(-2, -1)
        ):
        '''
        Reflection the input tensor along the spatial dimensions. 
        Parameters:
        - x (torch.tensor): input tensor, shape 
        - dims (int or sequence of int): dimensions to apply the Reflection. 
            If int, the same dimension is applied to all spatial dimensions.
            If sequence of int, the length must be equal to the number of spatial dimensions.
        
        Returns:
        - x (torch.tensor): Reflection tensor, shape 
        '''
        if isinstance(dims, int): dims = [dims]
        
        for d in dims:
            if np.random.rand() < p: x = torch.flip(x, dims=(d,))
        return x

class Reverse(nn.Module):
    def __init__(self, 
        p: float=0.5,
        *args, **kwargs
        ):
        super().__init__()
        self.p = p

    def forward(self, 
        x, 
        p = None, 
        ):
        if p is None: p = self.p
        
        return self.Reverse(x, p=p)
    
    def Reverse(self,
        x, 
        p: float=0.5, 
        ):
        '''
        Reverse the input tensor along the time dimension. 
        Parameters:
        - x (torch.tensor): input tensor, shape 
        - dim (int): dimension to apply the Reverse. 
        
        Returns:
        - x (torch.tensor): Reverse tensor, shape 
        '''
        if np.random.rand() < p: x = -x 
        return x
    
class CustomTransform(nn.Module):
    def __init__(self, 
                *args, **kwargs, 
                ):
        super(CustomTransform, self).__init__()
        # self.file_dir = file_dir
        # self.Ndata = Ndata
        self.kwargs = kwargs 
        self.build_transform()

    def build_transform(self):
        self.transforms = nn.Sequential(
            RandomShift(shifts=0.5, dims=(-2, -1)),
            Reflection(p=0.5, dims=(-2, -1)),
            Reverse(p=0.5),
        )
        return self.transforms
    
    def forward(self, x):
        x = self.transform(x)
        return x
    
    def transform(self, x):
        for t in self.transforms:
            x = t(x)
        return x

    def inverse_transform(self, x):
        for t in self.transforms[::-1]:
            x = t.inverse_transform(x)
        return x

class Standardize(nn.Module):
    def __init__(self, 
        mean:Optional[float]=None, std:Optional[float]=None, 
        normalization_path:Optional[str]=None,

        base_path: Optional[str] = None,
        dataset_name: Optional[str] = None,
        split_name: Optional[str] = None,
        *args, **kwargs, 
        ):
        super(Standardize, self).__init__()
        self.args = args
        self.kwargs = kwargs

        ## build_transform
        if mean is not None and std is not None: 
            self.mean = mean
            self.std = std
        elif normalization_path is not None:
            self.load_stats(normalization_path)
        elif dataset_name is not None: 
            path = os.path.join(base_path, split_name, dataset_name)
            self.calc_stats_from_dataset(path)
        else:
            assert False, "Error: must provide either (mean, std), normalization_path, or (base_path, dataset_name, split_name) to calculate mean and std."
    
    def load_stats(self, path=None):
        assert os.path.exists(path), f"Error: normalization path {path} does not exist."
        with open(path, "r") as f:
            stats = yaml.safe_load(f)['statistics']

            self.mean = stats['mean']
            self.std = stats['std']
        return self.mean, self.std
    
    def calc_stats_from_dataset(self, path=None):
        dataset = HIT2dDataset(path=path, load_data=True, *self.args, **self.kwargs).load_data()
        self.mean = dataset.data.mean().item()
        self.std = dataset.data.std().item()
        return self.mean, self.std

    def forward(self, x):
        data, target = x
        data = self.normalize(data, params=None)[0]
        target = self.normalize(target, params=None)[0]
        return [data, target]
    
    def normalize(self, inpt: Any, params: Dict[str, Any]):
        return (inpt - self.mean) / self.std

    def denormalize(self, inpt: Any, params: Dict[str, Any]):
        return inpt * self.std + self.mean

In [6]:
from neuraloperator import FNO 

In [7]:
import torch
import torch_optimizer 
from utils.refer.soap import SOAP

SCHEDULERS = {
    'StepLR': lambda optimizer, **params: torch.optim.lr_scheduler.StepLR(
        optimizer, **(params | {'step_size': 10, 'gamma': 0.1})
    ),
    'ExponentialLR': lambda optimizer, **params: torch.optim.lr_scheduler.ExponentialLR(
        optimizer, **(params | {'gamma': 0.95})
    ),
    'MultiStepLR': lambda optimizer, **params: torch.optim.lr_scheduler.MultiStepLR(
        optimizer, **(params | {'milestones': [30, 80], 'gamma': 0.1})
    ),
    'CosineAnnealingLR': lambda optimizer, **params: torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, **(params | {'T_max': 10})
    ),
    'CosineAnnealingWarmRestarts': lambda optimizer, **params: torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, **(params | {'T_0': 10, 'T_mult': 2})
    ),
    'OneCycleLR': lambda optimizer, **params: torch.optim.lr_scheduler.OneCycleLR(
        optimizer, **(params | {'max_lr': 0.1, 'total_steps': 100})
    ),
    'ReduceLROnPlateau': lambda optimizer, **params: torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, **(params | {'mode': 'min', 'factor': 0.1, 'patience': 10})
    ),
}


OPTIMIZERS = {
    'SGD': lambda model, **params: torch.optim.SGD(
        model, **({'lr': 0.01, 'momentum': 0.9} | params)
    ),
    'Adam': lambda model, **params: torch.optim.Adam(
        model, **({'lr': 1e-3, 'betas': (0.9, 0.999)} | params)
    ),
    'AdamW': lambda model, **params: torch.optim.AdamW(
        model, **({'lr': 1e-3, 'weight_decay': 1e-2} | params)
    ),
    'RMSprop': lambda model, **params: torch.optim.RMSprop(
        model, **({'lr': 1e-2, 'alpha': 0.99, 'momentum': 0.9} | params)
    ),
    'Adagrad': lambda model, **params: torch.optim.Adagrad(
        model, **({'lr': 1e-2} | params)
    ),
    'Adadelta': lambda model, **params: torch.optim.Adadelta(
        model, **({'lr': 1.0, 'rho': 0.9} | params)
    ),
    'Adamax': lambda model, **params: torch.optim.Adamax(
        model, **({'lr': 2e-3} | params)
    ),
    'NAdam': lambda model, **params: torch.optim.NAdam(
        model, **({'lr': 2e-3, 'betas': (0.9, 0.999)} | params)
    ),
    'Shampoo': (
        (lambda model, **params: torch_optimizer.Shampoo(
            model, **({'lr': 1e-3, 'momentum': 0.9, 'weight_decay': 0.0} | params)
        )) 
    ),
    'SOAP': (
        (lambda model, **params: SOAP(
            model, **({'lr': 1e-3,} | params)
        )) 
    ),
}


In [8]:

class CustomModule(nn.Module):
    def __init__(self, 
                 model:Optional[callable]= None,
                 modelconfig:Optional[dict]=None, 
                 optconfig:Optional[dict]=None, 
                 
                 loss_fn: Optional[callable] = nn.MSELoss(),
                 transformer:Optional[callable]=None, 
                 **kwargs, 
                 ):
        super().__init__()
        
        self.modelconfig = modelconfig
        self.optconfig = optconfig
        self.hparams = kwargs
        self.device = device

        self.model = self.build_model() if model is None else model.to(device)

        self.optimizer = None
        self.scheduler = None
        if optconfig is not None: self.configure_optimizers()

        self.loss_fn = None 
        if optconfig is not None: 
            self.loss_fn = loss_fn
        self.transformer = transformer

    def build_model(self):
        self.model = FNO(**self.modelconfig['params'],).to(self.device)

        return self.model
    
    def configure_optimizers(self):
        if self.optimizer is not None: 
            return self.optimizer, self.scheduler
        self.optimizer = OPTIMIZERS[self.optconfig['name']](
            self.model.parameters(), 
            **self.optconfig['params']
        )

        self.scheduler = None
        if self.optconfig['scheduler']: 
            self.scheduler = SCHEDULERS[self.optconfig['scheduler']['name']](
                self.optimizer, 
                **self.optconfig['scheduler'].get('params', {}),
                ) 
        return self.optimizer, self.scheduler
    
    def forward(self, data, output_shape=None):
        return self.model(data.to(self.device), output_shape=output_shape)
    
    def training_step(self, batch, batch_idx, ret_log:bool=False):
        if self.transformer: batch = self.transformer(batch)
        data, target = batch
        
        # data = data[:, 0]
        # target = target[:, 0]
        
        self.optimizer.zero_grad()
        output = self.forward(data.to(self.device), output_shape=target.shape[2:])
        loss = self.loss_fn(output, target.to(self.device))
        loss.backward()

        self.optimizer.step()
        
        self.log('loss', loss.item(), on_step=True)
        return loss
    
    @torch.no_grad()
    def validation_step(self, batch, batch_idx, ret_log:bool=True):
        if self.transformer: batch = self.transformer(batch)
        data, target = batch

        # data = data[:, 0]
        # target = target[:, 0]

        output = self.forward(data.to(self.device), output_shape=target.shape[2:])
        loss = self.loss_fn(output, target.to(self.device))

        self.log('val_loss', loss.item(), on_step=False)
        if ret_log: 
            log = {}
            log.update({"loss": loss.item(), "output": output[:].detach().cpu(), 'batch': [data, target.detach().cpu()]})
            return log
        return output

    @torch.no_grad()
    def test_step(self, batch, batch_idx, ret_log:bool=True):
        if self.transformer: batch = self.transformer(batch)
        
        data, target = batch

        # data = data[:, 0]
        # target = target[:, 0]

        output = self.forward(data.to(self.device), output_shape=target.shape[2:])

        if ret_log: 
            log = {}
            log.update({# "loss": loss.item(), 
                        "output": output[:].detach().cpu(), 'batch': [data.detach().cpu(), target.detach().cpu()]})
            return log
        return output

    def save(self, path, checkpoint:dict={}):
        checkpoint['model_state_dict'] = self.model.state_dict()
        checkpoint['optimizer_state_dict'] = self.optimizer.state_dict()
        if self.scheduler: checkpoint[f'scheduler'] = self.scheduler.state_dict()

        torch.save(checkpoint, path + '.pth')

    def load(self,
        path, 
        verbose=True, 
        load_opt:bool=True,
        freeze:bool=False,
        ):
        if os.path.exists(path + '.pth') and os.path.getsize(path + '.pth'):
            checkpoint = torch.load(path + '.pth', weights_only=False)

            self.model.load_state_dict(
                checkpoint['model_state_dict'], 
                strict=False,
                )
            if freeze: 
                for name, param in self.model.named_parameters():
                    if name in checkpoint['model_state_dict']:
                        param.requires_grad = False
            if verbose: print(f'model loaded from {path}')
            if load_opt:
                if self.optimizer and 'optimizer_state_dict' in checkpoint: 
                    self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                    if verbose: print(f'optimizer loaded from {path}')
                if self.scheduler and 'scheduler' in checkpoint: 
                    self.scheduler.load_state_dict(checkpoint['scheduler'])
                    if verbose: print(f'scheduler loaded from {path}')
    
    def log(self, name, value, prog_bar=False, logger=None, on_step=None, on_epoch=None, reduce_fx='mean', enable_graph=False, sync_dist=False, sync_dist_group=None, add_dataloader_idx=True, batch_size=None, metric_attribute=None, rank_zero_only=False):
        commit = False if on_step is False else True
        wandb.log({name: value}, commit=commit)

In [9]:
from utils.BaseCallback import BaseCallback
from utils.CFDFunction import calc_energy_spectrum, calc_pdf, calc_phase_error
from utils.Plots import plot_2d_surface, plot_spectrum, plot_pdf

class CustomCallback(BaseCallback):
    def __init__(self, 
                 criterions, 
                 file_dir:str='', 
                 fname:str ='',
                 device:str =None, 
                 val_every_n_iter:int=None,
                 ):
        super(CustomCallback, self).__init__()
        self.criterions = criterions # DL criterion + CFD criterion + DL part-of-loss
        self.file_dir = file_dir
        self.fname = fname
        self.device=device
        
    def on_validation_epoch_start(self, trainer, pl_module):
        self.time_stamp = time.time()
        trainer._current_val_return = {'loss': [], 'time': []}
        for key in self.criterions.keys():
            trainer._current_val_return[key] = []

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
        data, target = outputs['batch'] # batch
        # trainer._current_val_return['batch'] = outputs['batch'] # batch
        # trainer._current_val_return['output'] = outputs['output']
        # trainer._current_val_return['loss'].append(outputs['loss'])

        for key in self.criterions.keys():
            value = self.criterions[key](outputs['output'].to(self.device), target.to(self.device)).item()

            trainer._current_val_return[key].append(value)
            
        self.N = target.shape[-1]
    
    def on_validation_end(self, trainer, pl_module, fname='') -> None:
        fname = self.fname + fname
        # data, target = trainer._current_val_return['batch']
        # output = trainer._current_val_return['output']
        # if trainer.logger: trainer.logger.log({'val_loss': np.mean(trainer._current_val_return['loss'])})

        epoch = trainer.current_epoch
        print(f'  Validation Epoch: {epoch}', end='')
        for key in self.criterions.keys():
            value = np.mean(trainer._current_val_return[key])
            print(f", {key}: {value:.6f}", end='')

            if trainer.logger: trainer.logger.log({f'Val_{self.N}/epoch'+key: value})
        print()

        plt.close('all')

     


    def on_test_batch_start(self, trainer, pl_module, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
        self.time_stamp = time.time()

    def on_test_batch_end(self, trainer, pl_module, outputs, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
        data, target = batch
        trainer._current_test_return['batch'] = batch
        output = trainer._current_test_return['output'] = outputs['output']
        
        runtime = time.time() - self.time_stamp
        if trainer.logger: trainer.logger.log({'Test/time': runtime})
        print(f'Test sample {batch_idx}(time: {runtime:.6f}, batch size: {len(output)})', end='')
        for key in self.criterions.keys():
            value = self.criterions[key](outputs['output'].to(self.device), target.to(self.device)).item()

            trainer._current_test_return[key].append(value)
            print(f", {key}: {value:.6f}", end='')
        print()
        
        ## enstrophy spectrum
        spectrum1 = calc_energy_spectrum(target[:,0,:], dim=2) # np.mean([calc_enstrophy_spectrum(target[i,0,:]) for i in range(len(target))], axis=0); 
        spectrum2 = calc_energy_spectrum(output[:,0,:], dim=2) # np.mean([calc_enstrophy_spectrum(output[i,0,:]) for i in range(len(output))], axis=0)
        ## vorticity pdf
        bin1, pdf1 = calc_pdf(target[:,0,:], dim=2) # np.mean([calc_pdf(target[i,0,:]) for i in range(len(target))], axis=0); 
        bin2, pdf2 = calc_pdf(output[:,0,:], dim=2) # np.mean([calc_pdf(output[i,0,:]) for i in range(len(output))], axis=0)

        # ### solution field 
        plot_2d_surface(target[0,0,...,0], output[0,0,...,0], # axes=[ax1,ax2], 
                        kwargs={'figsize': (8, 4)}
                        )
        plt.savefig(self.file_dir + self.fname + f'_Test_sample={batch_idx}_field.png')
        # plt.show()
        plt.close()

        fig = plt.figure(figsize=(8,4))
        ax = fig.add_subplot(1, 2, 1)
        plot_spectrum(spectrum1, spectrum2, ax=ax, kwargs={'title': f'enstrophy spectrum'})

        ax = fig.add_subplot(1, 2, 2)
        plot_pdf(bin1, pdf1, bin2, pdf2, ax=ax, kwargs={'title': f'pdf'})
        plt.tight_layout()
        plt.savefig(self.file_dir + self.fname + f'_Test_sample={batch_idx}_stat.png')
        # plt.show()
        plt.close()

        ## save .wandb
        Nt = len(output)
        for it in range(Nt):
            image = np.concatenate((target[it,0,:], output[it,0,:]), axis=1)
            if trainer.logger: trainer.logger.log({'Test/'+f'sample_{batch_idx}/'+'field': wandb.Image(image)})
        if trainer.logger: trainer.logger.log({'Test/spectrum': wandb.plot.line_series(
            xs=[np.arange(0, spectrum1.shape[-1]+1), np.arange(0, spectrum2.shape[-1]+1)], 
            ys = [spectrum1, spectrum2],
            keys = ['ground truth', 'prediction'],
            xname = ['k', 'enstrophy spectrum'],
            title='enstrophy_spectrum', 
            )})
        if trainer.logger: trainer.logger.log({'Test/pdf': wandb.plot.line_series(
            xs=[bin1, bin2], 
            ys = [pdf1, pdf2],
            keys = ['ground truth', 'prediction'],
            xname = ['w', 'pdf'],
            title='pdf', 
            )})

    def on_test_end(self, trainer, pl_module) -> None:
        data, target = trainer._current_test_return['batch']
        output = trainer._current_test_return['output']

        print('Test ')
        for key in self.criterions.keys():
            value = np.mean(trainer._current_test_return[key])
            print(f", {key}: {value:.6f}", end='')

            if trainer.logger: trainer.logger.log({'Test/'+key: value})
        print()

    def preprocessing(self, x):
        n_modes = self.n_modes
        x = self.split(x, modes=n_modes)
        # x = self.scaling(x, dim=2, isNormalize=False)
        return x 
    
    def postprocessing(self, x):
        n_modes = self.n_modes
        # x = self.unscaling(x)
        x = self.unsplit(x, modes=n_modes)
        return x
    
    def scaling(self, x, 
                dim=1, 
                isNormalize:bool=False):
        dim = list(range(-dim, 0))
        if isNormalize:
            self.mean = x.amin(dim=dim, keepdim=True)
            self.std = x.amax(dim=dim, keepdim=True) - self.mean
        else: 
            self.mean = x.mean(dim=dim, keepdim=True)
            self.std = x.std(dim=dim, keepdim=True)
        return (x - self.mean) / self.std
    
    def unscaling(self, x):
        return x * self.std + self.mean
    
    def split(self, x, 
              modes=None
              ):

        b, c, *data_shape = x.shape
        fft_dims = list(range(-len(data_shape), 0))
        Fx = torch.fft.fft2(x, dim=fft_dims, norm='forward')
        Fx_ = torch.zeros((len(modes), b, c, *data_shape), dtype=Fx.dtype, device=x.device)
        
        k = [torch.fft.fftfreq(n, d=1./n) for n in data_shape]
        # k += [torch.fft.rfftfreq(data_shape[-1], d=1./data_shape[-1])] if Fx.dtype in [torch.float16, torch.float32, torch.float64] else [torch.fft.fftfreq(data_shape[-1], d=1./data_shape[-1])]


        k = torch.meshgrid(k, indexing='ij')
        k = torch.stack(k)
        k = torch.sqrt(torch.sum(k**2, axis=0)).to(x.device)
        for i in range(len(modes)):
            k1, k2 = modes[i]
            idx = (k1 <= k) & (k <= k2)
            Fx_[i] = Fx * idx

        x_ = torch.fft.ifft2(Fx_, dim=fft_dims, norm='forward').real # x_ = Fx_ # # 
        
        return x_

    def unsplit(self, x_, 
              modes=None
              ):
        n, b, c, *data_shape = x_.shape
        fft_dims = list(range(-len(data_shape), 0))
        
        Fx_ = torch.fft.fft2(x_, dim=fft_dims, norm='forward') # Fx_ = x_ # 
        
        Fx = torch.zeros((b, c, *data_shape), dtype=Fx_.dtype, device=x_.device)
        k = [torch.fft.fftfreq(n, d=1./n) for n in data_shape[:-1]]
        k += [torch.fft.rfftfreq(data_shape[-1], d=1./data_shape[-1])] if Fx.dtype in [torch.float16, torch.float32, torch.float64] else [torch.fft.fftfreq(data_shape[-1], d=1./data_shape[-1])]
        k = torch.meshgrid(k, indexing='ij')
        k = torch.stack(k)
        k = torch.sqrt(torch.sum(k**2, axis=0)).to(x_.device)
        for i in range(len(modes)):
            k1, k2 = modes[i]
            idx = (k1 <= k) & (k <= k2)
            Fx[..., idx] = Fx_[i, ..., idx]

        x = torch.fft.ifft2(Fx, dim=fft_dims, norm='forward').real
        return x


In [10]:
##### set parameters #####
### save parameters ### 
data_dir = config['data'][0]['params']['base_path'] # '../Data/nu=0.001_n=128/'
data_fname = config['data'][0]['params']['dataset_name'] # f'2dHIT_nu=0.001_n=128_T=11.5'

metadata = HIT2dDataset(path=data_dir + 'train/' + data_fname, load_data=False).metadata
print(metadata)

nu = metadata['nu']
dt = metadata['dt']
N = metadata['Nx']

Loaded dataset metadata: dict_keys(['nu', 'Lx', 'Ly', 'dx', 'dy', 'Nx', 'Ny', 'x', 'y', 'dt', 'Nt', 'T', 't0', 't'])
{'nu': 0.000225, 'Lx': 1.5402125848364265, 'Ly': 1.5402125848364265, 'dx': 0.024447818806927406, 'dy': 0.024447818806927406, 'Nx': 64, 'Ny': 64, 'x': array([0.        , 0.02444782, 0.04889564, 0.07334346, 0.09779128,
       0.12223909, 0.14668691, 0.17113473, 0.19558255, 0.22003037,
       0.24447819, 0.26892601, 0.29337383, 0.31782164, 0.34226946,
       0.36671728, 0.3911651 , 0.41561292, 0.44006074, 0.46450856,
       0.48895638, 0.51340419, 0.53785201, 0.56229983, 0.58674765,
       0.61119547, 0.63564329, 0.66009111, 0.68453893, 0.70898675,
       0.73343456, 0.75788238, 0.7823302 , 0.80677802, 0.83122584,
       0.85567366, 0.88012148, 0.9045693 , 0.92901711, 0.95346493,
       0.97791275, 1.00236057, 1.02680839, 1.05125621, 1.07570403,
       1.10015185, 1.12459967, 1.14904748, 1.1734953 , 1.19794312,
       1.22239094, 1.24683876, 1.27128658, 1.2957344 , 1.320182

In [11]:
from utils.Losses import (RMSLoss, TKELoss, DissipationLoss, RelambdaLoss, R2, 
                          BSMSE, )
from utils.utilities import HsLoss


test_criterion={
    'l2': nn.MSELoss(), 
    'h1': H1Loss(d=2, reduction='mean'), 
    'h2': HsLoss(k=2, reduction='mean'), # HsLoss(k=2, group=False, size_average=True), 
    
    'fRMS_k<8': BSMSE(kmax=8, dim=(-2, -1), mode='spectral', isRelative=True),
    'fRMS_8<k<16': BSMSE(kmin=8, kmax=16, dim=(-2, -1), mode='spectral', isRelative=True),
    'fRMS_k>16': BSMSE(kmin=16, dim=(-2, -1), mode='spectral', isRelative=True),
    'fRMS_k<kmax': BSMSE(kmax=16, dim=(-2, -1), mode='spectral', isRelative=True),
    'fRMS_k>kmax': BSMSE(kmin=16, dim=(-2, -1), mode='spectral', isRelative=True),
    'fRMS_k>train': BSMSE(kmin=32, dim=(-2, -1), mode='spectral', isRelative=True),
    'fRMS_k<train': BSMSE(kmax=32, dim=(-2, -1), mode='spectral', isRelative=True),

    'vor_rms': RMSLoss(dim=2, isRelative=True),
    # 'tke': TKELoss(dim=2, isRelative=True), 
    # 'dissipation': DissipationLoss(nu=nu, dim=2, isRelative=True), 
    # 'R_lambda': RelambdaLoss(nu=nu, dim=2, isRelative=True), 
    'R_squared': R2(), 
    }

##### setup dataset #####
dataloaders = {'train':[], 'valid': [], 'test': []}
for i, data_config in enumerate(config['data']):
    transform = CustomTransform()
    normalization = Standardize(normalization_path=data_config['normalization']['normalization_path'])

    idx_leadtime = data_config['idx_leadtime'] if 'idx_leadtime' in data_config else int(data_config['leadtime'] * integral_timescales[nu] / dt) 
    dm = CustomDataModule(**data_config['params'], 
        n_stride=idx_leadtime, 
        # Ndata_train=500, # 500, 
        # Ndata_val=50, 
        # Ndata_test=50,
        # transform=transform,
        normalization=normalization,
        load_data=False,
        )
    stage = data_config['stage']
    dm.setup(stage=stage)
    if stage in ['fit', 'valid']:
        dataloaders['valid'].append(dm.val_dataloader())
    if stage in ['fit', 'train']:
        dataloaders['train'].append(dm.train_dataloader())
    if stage in ['test']:
        dataloaders['test'].append(dm.test_dataloader())


Loaded dataset metadata: dict_keys(['nu', 'Lx', 'Ly', 'dx', 'dy', 'Nx', 'Ny', 'x', 'y', 'dt', 'Nt', 'T', 't0', 't'])
Train dataset: 182500
Loaded dataset metadata: dict_keys(['nu', 'Lx', 'Ly', 'dx', 'dy', 'Nx', 'Ny', 'x', 'y', 'dt', 'Nt', 'T', 't0', 't'])
Val dataset: 18250
Loaded dataset metadata: dict_keys(['nu', 'Lx', 'Ly', 'dx', 'dy', 'Nx', 'Ny', 'x', 'y', 'dt', 'Nt', 'T', 't0', 't'])
Val dataset: 365


In [12]:
# from utils.neuraloperator import FNO
from torchinfo import summary

model = FNO(**config['model']['params'])
pl_module = CustomModule(
        model=model, 
        optconfig=config['optimizer'], 
        )
batch_size = config['data'][0]['params']['batch_size']
print(pl_module.device)
summary(
        pl_module.model,
        input_size=(batch_size, 1, N, N),       # batch_size, in_channels, H, W
        col_names=("input_size", "output_size", "num_params", "trainable"),
        depth=5,                           # how many nested modules to show
        verbose=1
    )


cuda
Layer (type:depth-idx)                   Input Shape               Output Shape              Param #                   Trainable
FNO                                      [64, 1, 64, 64]           [64, 1, 64, 64]           --                        True
├─GridEmbeddingND: 1-1                   [64, 1, 64, 64]           [64, 3, 64, 64]           --                        --
├─ChannelMLP: 1-2                        [64, 3, 64, 64]           [64, 32, 64, 64]          --                        True
│    └─ModuleList: 2-1                   --                        --                        --                        True
│    │    └─Conv1d: 3-1                  [64, 3, 4096]             [64, 128, 4096]           512                       True
│    │    └─Conv1d: 3-2                  [64, 128, 4096]           [64, 32, 4096]            4,128                     True
├─FNOBlocks: 1-3                         [64, 32, 64, 64]          [64, 32, 64, 64]          1,677,648                 True


Layer (type:depth-idx)                   Input Shape               Output Shape              Param #                   Trainable
FNO                                      [64, 1, 64, 64]           [64, 1, 64, 64]           --                        True
├─GridEmbeddingND: 1-1                   [64, 1, 64, 64]           [64, 3, 64, 64]           --                        --
├─ChannelMLP: 1-2                        [64, 3, 64, 64]           [64, 32, 64, 64]          --                        True
│    └─ModuleList: 2-1                   --                        --                        --                        True
│    │    └─Conv1d: 3-1                  [64, 3, 4096]             [64, 128, 4096]           512                       True
│    │    └─Conv1d: 3-2                  [64, 128, 4096]           [64, 32, 4096]            4,128                     True
├─FNOBlocks: 1-3                         [64, 32, 64, 64]          [64, 32, 64, 64]          1,677,648                 True
│    

In [13]:
from utils.Trainer import Trainer
from utils.DLCallbacks import CheckPoint, EarlyStopping
from utils.CFDCallback_Spectrum import CFDCallback_Spectrum
from utils.CFDCallback_PhaseError import CFDCallback_PhaseError
from utils.CFDCallback_PDF import CFDCallback_PDF
from utils.CFDCallback_Field import CFDCallback_Field

def main(device='cuda'):
    logger = None 
    logger = wandb.init(project=config['name'], 
                        # name=config['name'], 
                        config=config
                        )
    if 'seed' in config: torch.manual_seed(config['seed'])

    data_config = config['data'][0]
    fname=f"2dHIT_{config['model']['name']}_nu={data_config['nu']}"
    fname += f"_T={data_config['idx_leadtime']}" if 'idx_leadtime' in data_config else f"_T={data_config['leadtime']}TL"
    if logger: fname += f'_{logger.name}'

    optconfig = config['optimizer']
    if 'loss_fn' in config: optconfig['loss_fn'] = config['loss_fn']
    if 'lr' in config: optconfig['params']['lr'] = config['lr']
    modelconfig = {**config['model'], }
    modelconfig['params'].update({k: v for k, v in config.items() if k in config['model']['params']})
    

    model = FNO(**config['model']['params'])
    pl_module = CustomModule(
            model=model, 
            optconfig=config['optimizer'], 
            )
    # pl_module.load('./pretrain/2dHIT_FNO_nu=5e-05_T=0.1TL_different-dust-234_top0', 
    #     # load_opt=False, 
    #     # freeze=True, 
    #     )
    if logger: wandb.watch(pl_module, log="all", log_freq=100)
    callbacks = [
            CFDCallback_Spectrum(file_dir='./result/', fname=fname), 
            CFDCallback_PhaseError(file_dir='./result/', fname=fname), 
            CFDCallback_PDF(file_dir='./result/', fname=fname), 
            CFDCallback_Field(file_dir='./result/', fname=fname), 
            CustomCallback(criterions=test_criterion, file_dir='./result/', fname=fname, device=device,
                        ),
            CheckPoint(
                ckpt_name=fname, ckpt_path='./checkpoint/', 
                every_n_epoch = 1,
                criterion=nn.MSELoss(), mode = 'min', 
                load_ckpt=False, 
                ),
            EarlyStopping(criterion=nn.MSELoss(), # neuralop.H1Loss(d=2, reductions='mean'),# 
                            mode='min', min_delta=1e-5, 
                            patience=5, verbose=True, divergence_threshold=1e3, 
                            stopping_threshold=1e-3, 
                            ),
            ]
    trainer = Trainer(
        max_epochs=config['epochs'],
        check_val_every_n_epochs=1,
        # check_val_every_n_iter=30000, 
        enable_progress_bar=True,
        callbacks=callbacks, 
        logger=logger,
        )
    
    trainer.fit(
        model=pl_module,
        train_dataloaders=dataloaders['train'], 
        val_dataloaders=dataloaders['valid'],
        # datamodule = datamodules,
        )
    if logger: wandb.finish()
    return 

In [14]:
if __name__ == "__main__":
    main()

wandb: Currently logged in as: hjungwon034 (jungwonheo) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin


  flat = flat.type(torch.cuda.FloatTensor)
100%|██████████| 2851/2851 [06:29<00:00,  7.32it/s]
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
100%|██████████| 285/285 [00:41<00:00,  6.79it/s]


  Validation Epoch: 0, l2: 0.049116, h1: 0.374333, h2: 0.744673, fRMS_k<8: 0.003922, fRMS_8<k<16: 0.034504, fRMS_k>16: 0.353342, fRMS_k<kmax: 0.009065, fRMS_k>kmax: 0.353342, fRMS_k>train: 0.905721, fRMS_k<train: 0.042560, vor_rms: 0.023264, R_squared: 0.950771


100%|██████████| 5/5 [00:04<00:00,  1.13it/s]


  Validation Epoch: 0, l2: 0.108379, h1: 0.796934, h2: 0.993781, fRMS_k<8: 0.005050, fRMS_8<k<16: 0.037208, fRMS_k>16: 0.589371, fRMS_k<kmax: 0.011046, fRMS_k>kmax: 0.589371, fRMS_k>train: 0.938619, fRMS_k<train: 0.055409, vor_rms: 0.029998, R_squared: 0.895725


100%|██████████| 2851/2851 [06:17<00:00,  7.55it/s]
100%|██████████| 285/285 [00:40<00:00,  6.97it/s]


  Validation Epoch: 1, l2: 0.039128, h1: 0.332830, h2: 0.684872, fRMS_k<8: 0.002820, fRMS_8<k<16: 0.026979, fRMS_k>16: 0.278233, fRMS_k<kmax: 0.006879, fRMS_k>kmax: 0.278233, fRMS_k>train: 0.816578, fRMS_k<train: 0.033237, vor_rms: 0.013988, R_squared: 0.960767


100%|██████████| 5/5 [00:04<00:00,  1.22it/s]


  Validation Epoch: 1, l2: 0.095504, h1: 0.771436, h2: 0.984941, fRMS_k<8: 0.003750, fRMS_8<k<16: 0.029875, fRMS_k>16: 0.525552, fRMS_k<kmax: 0.008608, fRMS_k>kmax: 0.525552, fRMS_k>train: 0.917860, fRMS_k<train: 0.043530, vor_rms: 0.025283, R_squared: 0.908083


 10%|▉         | 275/2851 [00:38<05:59,  7.16it/s]


KeyboardInterrupt: 