In [None]:
"""Train a neuraloperator FNO (wrapped as a PhysicsNeMo Module) on The Well, following https://docs.nvidia.com/physicsnemo/latest/user-guide/simple_training_example.html"""

from pathlib import Path
import physicsnemo
import torch
# from physicsnemo.datapipes.benchmarks.darcy import Darcy2D
from physicsnemo.metrics.general.mse import mse
# from physicsnemo.models.fno.fno import FNO
import logging
from the_well.data.datamodule import WellDataModule
from the_well.data.datasets import WellDataset
from the_well.benchmark.metrics import VRMSE, RMSE
from einops import rearrange

# Show INFO-level log messages emitted while the datamodule / datasets load.
logging.basicConfig(level=logging.INFO)

# Windowing: the model consumes `n_steps_input` frames and predicts `n_steps_output`.
n_steps_input, n_steps_output = 4, 1
well_dataset_name = "turbulent_radiative_layer_2D"

# Datamodule that provides the train / val / test dataloaders used further down.
ae_data_module = WellDataModule(
    train_dataset=WellDataset,
    batch_size=4,
    n_steps_input=n_steps_input,
    n_steps_output=n_steps_output,
    well_dataset_name=well_dataset_name,
    well_base_path="./exploratory/data/the_well/datasets",
)

# Directory for run artifacts of this dataset / model combination.
output_path = Path("./exploratory/data/the_well/runs").joinpath(
    f"{well_dataset_name}_fno_physicsnemo"
)

In [None]:
# class PhysicsNemoFNOModel(physicsnemo.models.fno.fno.FNO):
from einops import rearrange
from torch import nn
from neuralop.models.fno import FNO
from dataclasses import dataclass
from physicsnemo.models.meta import ModelMetaData
from physicsnemo.models.module import Module

@dataclass
class MetaData(ModelMetaData):
    """PhysicsNeMo model metadata attached to ``FNONemo`` (passed as ``meta=`` in its ``__init__``)."""

    # Model name recorded in the metadata.
    name: str = "FNONemo"
    # Optimization
    # NOTE(review): how PhysicsNeMo acts on each flag depends on the installed
    # version -- confirm against the ModelMetaData documentation.
    jit: bool = False
    cuda_graphs: bool = True
    amp_cpu: bool = True
    amp_gpu: bool = True

class FNONemo(Module):
    """PhysicsNeMo ``Module`` wrapping a neuraloperator ``FNO`` for The Well tensors.

    The Well batches are time-major and channels-last, ``(b, t, h, w, c)``. The
    2D FNO expects channels-first ``(b, channels, h, w)``, so the time axis is
    folded into the channel axis on the way in and unfolded on the way out.

    Args:
        *args, **kwargs: Forwarded unchanged to ``neuralop.models.fno.FNO``.
            This notebook passes ``in_channels = n_steps_input * n_channels`` and
            ``out_channels = n_steps_output * n_channels``.
        n_steps_output: Number of predicted time steps packed into the FNO's
            output channels. Defaults to 1, in which case every output channel
            is treated as a field of a single predicted step.

    Raises:
        ValueError: If ``n_steps_output`` is smaller than 1.
    """

    # Class-level annotation only; the instance attribute is assigned in __init__.
    nn: FNO

    def __init__(self, *args, n_steps_output: int = 1, **kwargs):
        super().__init__(meta=MetaData())
        if n_steps_output < 1:
            raise ValueError(f"n_steps_output must be >= 1, got {n_steps_output}")
        # Needed by forward() to split the FNO output channels back into (t, c).
        self.n_steps_output = n_steps_output
        self.nn = FNO(*args, **kwargs)

    def forward(self, x):
        """Predict the next ``n_steps_output`` frames from a window of input frames.

        Args:
            x: Tensor laid out as ``(b, t, h, w, c)``.

        Returns:
            Tensor laid out as ``(b, n_steps_output, h, w, c_out)``, where
            ``c_out = out_channels // n_steps_output``.
        """
        # TODO: explore normalization options
        # Fold time into channels, time-major: the channel index is t * c + c_i.
        x = rearrange(x, "b t h w c -> b (t c) h w")
        y_pred = self.nn(x)
        # Unfold using the same (t c) ordering as the input. With the default
        # n_steps_output=1 this equals the previous "b c h w -> b 1 h w c".
        return rearrange(y_pred, "b (t c) h w -> b t h w c", t=self.n_steps_output)

# MyFNOModelFromTorch = Module.from_torch(MyFNO, meta=MetaData)


In [None]:
# Training DataLoader built from the datamodule configured in the first cell.
dataloader = ae_data_module.train_dataloader()


In [None]:
# Peek at one training batch to read off tensor dimensions. This reuses the
# `dataloader` from the previous cell rather than constructing a second loader.
batch = next(iter(dataloader))

In [None]:
# Infer tensor dimensions from the sample batch; the model below is sized from n_channels.
_, n_time_steps, height, width, n_channels = batch["input_fields"].shape

# Sanity-check the datamodule windowing against the configured step counts, and
# check that the targets carry the same fields as the inputs. The model's
# out_channels = n_steps_output * n_channels relies on this.
assert n_time_steps == n_steps_input, (n_time_steps, n_steps_input)
assert batch["output_fields"].shape[1] == n_steps_output, batch["output_fields"].shape
assert batch["output_fields"].shape[-1] == n_channels, batch["output_fields"].shape

In [None]:
# Seed torch's global RNG so weight initialisation (and any torch-driven
# shuffling in the DataLoader) is repeatable when this cell is re-run.
torch.manual_seed(0)

dataloader = ae_data_module.train_dataloader()
model = FNONemo(
    n_modes=(16, 16),
    hidden_channels=16,
    in_channels=n_steps_input * n_channels,
    out_channels=n_steps_output * n_channels,
)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# The lambda's argument counts scheduler.step() calls. Stepping once per EPOCH
# decays the LR from 1e-2 to ~4e-4 over 20 epochs. Stepping per batch, as
# before, drove it to ~0 within a few hundred iterations.
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda epoch: 0.85**epoch
)

n_epochs = 20
model.train()
for epoch in range(n_epochs):
    for batch_idx, batch in enumerate(dataloader):
        optimizer.zero_grad()
        x = batch["input_fields"]
        y_true = batch["output_fields"]
        y_pred = model(x)
        loss = mse(y_pred, y_true)
        loss.backward()
        optimizer.step()

        print(f"Epoch: {epoch}, Iteration: {batch_idx}. Loss: {loss.item()}")
    scheduler.step()

In [None]:
# https://docs.nvidia.com/physicsnemo/latest/user-guide/simple_training_example.html#running-inference-on-trained-models

# Save the *trained* model (this cell runs after the training loop) into the
# run directory, then reload it the way a separate inference process would.
output_path.mkdir(parents=True, exist_ok=True)
checkpoint_path = output_path / "checkpoint.mdlus"
model.save(str(checkpoint_path))

# Inference code

# The parameters to instantiate the model will be loaded from the checkpoint.
model_inf = physicsnemo.Module.from_checkpoint(str(checkpoint_path))


In [None]:
# NOTE(review): val_dataloader() has been observed to crash here; root cause not
# yet identified. The traceback is caught and logged so Restart & Run All can
# continue past this exploratory cell. val_batch is None when fetching fails.
try:
    val_batch = next(iter(ae_data_module.val_dataloader()))
except Exception:
    logging.exception("Fetching a batch from ae_data_module.val_dataloader() failed")
    val_batch = None

In [None]:

# Put the reloaded model in evaluation mode before predicting.
model_inf.eval()

# Predict on the first test batch without autograd tracking, then score the
# prediction against the ground-truth frames so the cell displays a result.
with torch.inference_mode():
    batch = next(iter(ae_data_module.test_dataloader()))
    output = model_inf(batch["input_fields"])
    test_batch_mse = mse(output, batch["output_fields"])

test_batch_mse.item()
