In [None]:
"""https://docs.nvidia.com/physicsnemo/latest/user-guide/simple_training_example.html."""

from pathlib import Path
import physicsnemo
import torch
# from physicsnemo.datapipes.benchmarks.darcy import Darcy2D
from physicsnemo.metrics.general.mse import mse
# from physicsnemo.models.fno.fno import FNO
import logging
from the_well.data.datamodule import WellDataModule
from the_well.data.datasets import WellDataset
from the_well.benchmark.metrics import VRMSE, RMSE
from einops import rearrange

# Base of the experimental tree; datasets are read from
# <root_path>/exploratory/data/the_well/datasets (see well_base_path below).
root_path = Path("../../autoemulate/autoemulate/experimental/")

# Seed torch's global RNG so weight initialisation (and, presumably, default
# DataLoader shuffling inside the datamodule) is repeatable across runs.
seed = 0
torch.manual_seed(seed)

# Make a datamodule
logging.basicConfig(level=logging.INFO)
n_steps_input = 4  # history length (time steps) per input sample
n_steps_output = 1  # time steps to predict per sample
batch_size = 4
well_dataset_name = "turbulent_radiative_layer_2D"
ae_data_module = WellDataModule(
    well_base_path=str(root_path / "exploratory/data/the_well/datasets"),
    well_dataset_name=well_dataset_name,
    n_steps_input=n_steps_input,
    n_steps_output=n_steps_output,
    batch_size=batch_size,
    train_dataset=WellDataset,
)
# Output location for this run, relative to the current working directory.
output_path = f"{well_dataset_name}_fno_physicsnemo"

In [None]:
import torch

# Confirm a CUDA device is visible to this kernel; the boolean is the cell's output.
cuda_available = torch.cuda.is_available()
cuda_available


In [None]:

# Build the training loader and pull a single batch to inspect tensor shapes.
dataloader = ae_data_module.train_dataloader()
batch = next(dataloader_iter := iter(dataloader))

In [None]:
# input_fields is unpacked as (batch, time, height, width, channels), the same
# "b t h w c" order the rearrange calls in later cells assume.
(_, n_time_steps, height, width, n_channels) = tuple(batch["input_fields"].shape)

In [None]:
from spatio_temporal_forecasting.AR_FNO import AutoregressiveFNO
from spatio_temporal_forecasting.fno_emulator import MultivariableFNO

# Fall back to CPU so the notebook still runs where no CUDA device is visible.
device = "cuda" if torch.cuda.is_available() else "cpu"

# n_vars=1 because inputs/targets are sliced to their first channel ([..., :1])
# before being fed to the model (presumably n_vars counts physical channels).
fno_base = MultivariableFNO(
    n_vars=1,
    n_modes=(16, 16),
    hidden_channels=16,
    n_layers=4,
    use_skip_connections=False,
)
# Tie the autoregressive window to the datamodule settings instead of repeating
# the literals, so changing n_steps_input/n_steps_output cannot desync the model.
model = AutoregressiveFNO(
    fno_model=fno_base, t_in=n_steps_input, t_out=n_steps_output
).to(device)

dataloader = ae_data_module.train_dataloader()

In [None]:
# Preview a model input: keep only the first channel and move channels ahead of
# the spatial axes ("b t h w c" -> "b t c h w"), matching the training loop.
# NOTE: x is left on the CPU here; call x.to(device) before passing it to model(x).
batch = next(iter(dataloader))
x = rearrange(batch["input_fields"][..., :1], "b t h w c -> b t c h w")

In [None]:
from physicsnemo.launch.logging import LaunchLogger, PythonLogger

n_epochs = 2

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# Exponential LR decay by 0.85 per *epoch*: the scheduler is stepped once at the
# end of each epoch below. Stepping it per mini-batch would multiply the LR by
# 0.85 every batch and drive it to ~0 within a few dozen batches.
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda epoch: 0.85**epoch
)

# Initialize the loggers: PythonLogger for messages, LaunchLogger for metrics.
logger = PythonLogger("main")
LaunchLogger.initialize()

logger.info("Starting Training!")
model.train()
for epoch in range(n_epochs):
    with LaunchLogger("train", epoch=epoch) as launchlog:
        for batch in ae_data_module.train_dataloader():
            optimizer.zero_grad()
            # Targets stay channel-last ("b t h w c"); keep the first channel only.
            y_true = batch["output_fields"].to(device)[..., :1]
            # Model input is channel-first per time step ("b t c h w").
            x = rearrange(
                batch["input_fields"].to(device)[..., :1], "b t h w c -> b t c h w"
            )
            # Return predictions to channel-last so they align with y_true.
            y_pred = rearrange(model(x), "b t c h w -> b t h w c")
            loss = mse(y_pred, y_true)
            loss.backward()
            optimizer.step()

            launchlog.log_minibatch({"Loss": loss.detach().cpu().numpy()})

        # Log the LR used during this epoch, then decay it for the next one.
        launchlog.log_epoch({"Learning Rate": optimizer.param_groups[0]["lr"]})
        scheduler.step()
logger.info("Finished Training!")

In [None]:
import torch

# Record which PyTorch build this notebook ran against (shown as the cell output).
torch_version = torch.__version__
torch_version