# PhysicsNeMo FNO Training with Hydra Configuration

This notebook demonstrates how to train an autoregressive Fourier Neural Operator (FNO) on a dataset from The Well using NVIDIA PhysicsNeMo, with all configuration managed through Hydra.

In [None]:
"""https://docs.nvidia.com/physicsnemo/latest/user-guide/simple_training_example.html."""

import logging
from pathlib import Path

import physicsnemo
import torch
from einops import rearrange
from hydra import compose, initialize
from omegaconf import OmegaConf
from physicsnemo.launch.logging import LaunchLogger, PythonLogger
from physicsnemo.metrics.general.mse import mse
from the_well.benchmark.metrics import VRMSE, RMSE
from the_well.data.datamodule import WellDataModule
from the_well.data.datasets import WellDataset

from spatio_temporal_forecasting.AR_FNO import AutoregressiveFNO
from spatio_temporal_forecasting.fno_emulator import MultivariableFNO

# Base directory (relative to this notebook) under which the dataset folder is resolved.
root_path = Path("..", "..", "autoemulate", "autoemulate", "experimental")

In [None]:
# Compose the config named "config" from the "configs" directory. Using
# `initialize` as a context manager scopes the Hydra initialisation to this block.
with initialize(config_path="configs", version_base=None):
    cfg = compose(config_name="config")

# Echo the composed configuration as YAML so the run is self-describing.
print("Configuration loaded:", OmegaConf.to_yaml(cfg), sep="\n")

## Load Configuration with Hydra

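The preceding cell uses Hydra to compose the configuration named `config` from the YAML files in the `configs/` directory; the resulting `cfg` object supplies the dataset, model and device settings used in the rest of the notebook.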

In [None]:
# Configure root logging at INFO level before the data module is constructed.
logging.basicConfig(level=logging.INFO)

# Every dataset setting is read from the `dataset` section of the Hydra config.
dataset_cfg = cfg.dataset
datasets_dir = root_path / "exploratory/data/the_well/datasets"

ae_data_module = WellDataModule(
    well_base_path=str(datasets_dir),
    well_dataset_name=dataset_cfg.name,
    n_steps_input=dataset_cfg.n_steps_input,
    n_steps_output=dataset_cfg.n_steps_output,
    batch_size=dataset_cfg.batch_size,
    train_dataset=WellDataset,
)

# Output location taken from the config.
output_path = cfg.output_path

## Setup Data Module

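The preceding cell builds the `WellDataModule` from the `dataset` section of the configuration (dataset name, number of input and output time steps, and batch size).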

In [None]:
# Report whether a CUDA device is visible to PyTorch
# (torch is already imported in the first code cell).
torch.cuda.is_available()


In [None]:
# Build the training dataloader and pull a single batch from it. `batch` is a
# mapping; the following cells index it with the "input_fields" key.
dataloader = ae_data_module.train_dataloader()
# Keep the iterator in its own name so further batches can be drawn with next().
dataloader_iter = iter(dataloader)
batch = next(dataloader_iter)

In [None]:
# Read the per-sample dimensions of the input fields; the leading (batch) axis is
# skipped. Slicing rather than unpacking into `_` avoids overwriting IPython's `_`
# last-output variable, which IPython stops updating once user code assigns to it.
n_time_steps, height, width, n_channels = batch["input_fields"].shape[1:]

In [None]:
# Build the FNO model from the Hydra config.
# (AutoregressiveFNO and MultivariableFNO are imported in the first code cell.)

# Resolve the target device. When the config names a CUDA device but this machine
# has none, fall back to the CPU with a warning rather than failing in `.to(device)`.
device = cfg.device
if isinstance(device, str) and device.startswith("cuda") and not torch.cuda.is_available():
    logging.warning("CUDA device %r requested but unavailable; falling back to CPU.", device)
    device = "cpu"

# Base FNO whose hyperparameters all come from the `model` section of the config.
fno_base = MultivariableFNO(
    n_vars=cfg.model.n_vars,
    n_modes=tuple(cfg.model.n_modes),  # converted from the config's list to a tuple
    hidden_channels=cfg.model.hidden_channels,
    n_layers=cfg.model.n_layers,
    use_skip_connections=cfg.model.use_skip_connections,
)

# Wrap the base FNO in the autoregressive model and move it to the target device.
model = AutoregressiveFNO(
    fno_model=fno_base,
    t_in=cfg.model.t_in,
    t_out=cfg.model.t_out,
).to(device)

# Fresh training dataloader; the next cell draws a batch from it.
dataloader = ae_data_module.train_dataloader()

In [None]:
# Draw one batch and put its inputs into the layout the model is fed in the
# training loop.
batch = next(iter(dataloader))

# Keep only the first channel, then move the channel axis ahead of the two
# spatial axes: (b, t, h, w, c) -> (b, t, c, h, w).
x = rearrange(batch["input_fields"][..., :1], "b t h w c -> b t c h w")

# Display the resulting shape as a quick sanity check.
x.shape

In [None]:
# Train the model for a few epochs and log progress through PhysicsNeMo's loggers.
# (LaunchLogger and PythonLogger are imported in the first code cell.)

# Optimisation settings for this short demonstration run.
N_EPOCHS = 2
LEARNING_RATE = 0.01
LR_DECAY_PER_EPOCH = 0.85

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# Exponential decay of the learning rate. The schedule is advanced once per EPOCH
# (see the end of the epoch loop). Advancing it once per minibatch would multiply
# the rate by 0.85 on every batch and drive it to effectively zero within the
# first epoch.
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda epoch_idx: LR_DECAY_PER_EPOCH**epoch_idx
)

# Initialize the loggers
logger = PythonLogger("main")  # General python logger
LaunchLogger.initialize()

logger.info("Starting Training!")
model.train()
for epoch in range(N_EPOCHS):
    # Wrap each epoch in a LaunchLogger so minibatch losses are reported per epoch.
    with LaunchLogger("train", epoch=epoch) as launchlog:
        for batch in ae_data_module.train_dataloader():
            optimizer.zero_grad()

            # Keep only the first channel; slicing before the transfer means only
            # that channel is moved to the device.
            x = batch["input_fields"][..., :1].to(device)
            y_true = batch["output_fields"][..., :1].to(device)

            # The model is fed channels-first tensors (b, t, c, h, w); its output is
            # rearranged back to channels-last so it lines up with `y_true`.
            x = rearrange(x, "b t h w c -> b t c h w")
            y_pred = rearrange(model(x), "b t c h w -> b t h w c")

            loss = mse(y_pred, y_true)
            loss.backward()
            optimizer.step()

            # Log the loss as a plain Python float.
            launchlog.log_minibatch({"Loss": loss.item()})

        # Log the learning rate that was used during this epoch.
        launchlog.log_epoch({"Learning Rate": optimizer.param_groups[0]["lr"]})

    # Decay the learning rate once per epoch, after this epoch's optimizer steps.
    scheduler.step()
logger.info("Finished Training!")

In [None]:
# Record the PyTorch version used for this run
# (torch is already imported in the first code cell).
torch.__version__