In [36]:
from typing import Dict

import numpy as np
import torch
import torch.nn.functional as F

import modulus
from modulus.hydra import instantiate_arch, ModulusConfig
from modulus.key import Key
from modulus.models.layers.spectral_layers import fourier_derivatives
from modulus.node import Node

from modulus.solver import Solver
from modulus.domain import Domain
from modulus.domain.constraint import SupervisedGridConstraint
from modulus.domain.validator import GridValidator
from modulus.dataset import DictGridDataset
from modulus.utils.io.plotter import GridValidatorPlotter
from modulus.utils.io.vtk import grid_to_vtk

import hydra
from hydra import compose, initialize

from omegaconf import DictConfig, OmegaConf

from utilities import download_FNO_dataset, load_FNO_dataset
from ops import dx, ddx
from dataset.data import data_generator

In [51]:
class Wave(torch.nn.Module):
    "Custom Wave PDE definition for PINO"

    def __init__(self):
        super().__init__()

    def forward(self, input_var: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        # get inputs
        u = input_var["sol"]
        ic = input_var["IC"]
        c = 1.0
        
        dxf = 1.0 / u.shape[-2]
        dyf = 1.0 / u.shape[-1]
        
        dduddx_exact = input_var["sol__x__x"]
        dduddy_exact = input_var["sol__y__y"]
        # compute wave equation
        wave = (
            1.0
            + (c * dduddx_exact)
            + (c * dduddy_exact)
        )

        # Zero outer boundary
        wave = F.pad(wave[:, :, 2:-2, 2:-2], [2, 2, 2, 2], "constant", 0)
        # Return darcy
        output_var = {
            "wave": dxf * wave,
        }  # weight boundary loss higher
        return output_var

In [85]:
with initialize(config_path="conf"):
    cfg = compose(config_name="config_PINO")
    print(cfg)

{'training': {'max_steps': 10000, 'grad_agg_freq': 1, 'rec_results_freq': 1000, 'rec_validation_freq': '${training.rec_results_freq}', 'rec_inference_freq': '${training.rec_results_freq}', 'rec_monitor_freq': '${training.rec_results_freq}', 'rec_constraint_freq': '${training.rec_results_freq}', 'save_network_freq': 1000, 'print_stats_freq': 100, 'summary_freq': 1000, 'amp': False, 'amp_dtype': 'float16', 'ntk': {'use_ntk': False, 'save_name': None, 'run_freq': 1000}}, 'graph': {'func_arch': False, 'func_arch_allow_partial_hessian': True}, 'stop_criterion': {'metric': None, 'min_delta': None, 'patience': 50000, 'mode': 'min', 'freq': 1000, 'strict': False}, 'profiler': {'profile': False, 'start_step': 0, 'end_step': 100, 'name': 'nvtx'}, 'network_dir': '.', 'initialization_network_dir': '', 'save_filetypes': 'vtk', 'summary_histograms': False, 'jit': False, 'jit_use_nvfuser': True, 'jit_arch_mode': 'only_activation', 'jit_autograd_nodes': False, 'cuda_graphs': False, 'cuda_graph_warmup'

In [76]:
input_keys = [
    Key("IC")
]
output_keys = [
    Key("sol"),
]

In [77]:
dim = 2
N = 128
Nx = 128
Ny = 128
l = 0.1
L = 1.0
sigma = 1.0
Nu = None
dt = 1.0e-4
save_int = int(1e-2/dt)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [64]:
# invar_train, outvar_train = data_generator(dim, N, L, Nu, l, sigma, device, cfg['custom']['ntrain'], Nx, Ny, dt, save_int)
# invar_test, outvar_test = data_generator(dim, N, L, Nu, l, sigma, device, cfg['custom']['ntest'], Nx, Ny, dt, save_int)

print('IC shape: ', invar_test['IC'].shape)
print('Sol shape: ', outvar_test['sol'].shape)

IC shape:  torch.Size([2, 1, 128, 128])
Sol shape:  torch.Size([2, 101, 128, 128])


In [66]:
# outvar_train['sol'] = outvar_train['sol'][:,100:,:,:]
# outvar_test['sol'] = outvar_test['sol'][:,100:,:,:]
print('Sol shape: ', outvar_train['sol'].shape)

Sol shape:  torch.Size([10, 1, 128, 128])


In [67]:
train_dataset = DictGridDataset(invar_train, outvar_train)
test_dataset = DictGridDataset(invar_test, outvar_test)

train_dataset.__getitem__(0)[2]['sol'].shape

torch.Size([1, 128, 128])

In [86]:
decoder_net = instantiate_arch(
        cfg=cfg.arch.decoder,
        output_keys=output_keys,
    )
fno = instantiate_arch(
        cfg=cfg.arch.fno,
        input_keys=[input_keys[0]],
        decoder_net=decoder_net,
    )

In [87]:
inputs = [
    "sol",
    "IC"
]
wave_node = Node(
    inputs=inputs,
    outputs=["wave"],
    evaluate=Wave(),
    name="Wave Node",
)
nodes = [fno.make_node('fno'), wave_node]

In [88]:
domain = Domain()

In [89]:
supervised = SupervisedGridConstraint(
    nodes=nodes,
    dataset=train_dataset,
    batch_size=cfg.batch_size.grid,
)
domain.add_constraint(supervised, "supervised")

In [96]:
val = GridValidator(
    nodes,
    dataset=test_dataset,
    batch_size=cfg.batch_size.validation,
    plotter=GridValidatorPlotter(),
    requires_grad=True,
)
domain.add_validator(val, "test")

In [97]:
slv = Solver(cfg, domain)

In [98]:
slv.solve()

error: <modulus.utils.io.plotter.GridValidatorPlotter object at 0x7f08a6ddc0d0>.__call__ raised an exception: index 2 is out of bounds for axis 0 with size 2


KeyboardInterrupt: 