In [None]:
import matplotlib.pyplot as plt
import logging
from autoemulate.experimental.emulators.the_well import TheWellFNO, TheWellFNOWithLearnableWeights
from pathlib import Path
from the_well.data import WellDataModule, WellDataset


In [None]:
from omegaconf import OmegaConf

In [None]:
# Build the data module
# Emit INFO-level log records so library progress messages are visible.
logging.basicConfig(level=logging.INFO)

# Dataset name and window sizes; presumably n_steps_input / n_steps_output are
# the number of time steps fed in and predicted per sample — TODO confirm.
well_dataset_name = "turbulent_radiative_layer_2D"
n_steps_input, n_steps_output = 4, 1

ae_data_module = WellDataModule(
    train_dataset=WellDataset,
    batch_size=4,
    n_steps_output=n_steps_output,
    n_steps_input=n_steps_input,
    well_dataset_name=well_dataset_name,
    well_base_path="../data/the_well/datasets",
)

# Run directory for this dataset; later handed to TrainerParams as output_path.
output_path = Path("../data/the_well/runs", f"{well_dataset_name}_fno")

In [None]:
# Plot example
# Take the first batch from the validation loader and show one 2D slice of it.
batch = next(iter(ae_data_module.val_dataloader()))

# The indexing fixes "input_fields" at rank 5; presumably the axes are
# (batch, time step, x, y, field) — TODO confirm against the dataset layout.
fig, ax = plt.subplots()
image = ax.imshow(batch["input_fields"][0, 0, :, :, 0])

# A colour scale and a title let the figure stand on its own when skimmed.
fig.colorbar(image, ax=ax)
ax.set_title(f"{well_dataset_name}: input_fields[0, 0, :, :, 0]")
plt.show()

In [None]:
import torch

from autoemulate.experimental.emulators.the_well import (
    DefaultChannelsFirstFormatterWithTime, TrainerParams, TheWellFNOWithTime
)
from the_well.data.data_formatter import DefaultChannelsFirstFormatter
from the_well.benchmark.metrics import VRMSE, RMSE

# Choose the compute device at run time instead of hard-coding "mps", so the
# notebook also runs on hosts without Apple MPS. MPS is checked first to keep
# the previous behaviour on Apple-silicon machines; then CUDA; then CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Initialize the emulator
# NOTE(review): TheWellFNOWithLearnableWeights (imported in the first cell) was
# previously left here as a commented-out alternative to TheWellFNO;
# presumably it accepts the same arguments — confirm before swapping.
em = TheWellFNO(
    formatter_cls=DefaultChannelsFirstFormatter,
    loss_fn=VRMSE(),
    datamodule=ae_data_module,
    trainer_params=TrainerParams(
        # Run directory built in the data-module cell.
        output_path=str(output_path),
        max_rollout_steps=100,
        device=device,
        optimizer_params={"lr": 1e-3},
    )
)


In [None]:
# Fit the model
# Calls the emulator's fit() with no arguments; presumably it trains using the
# datamodule and TrainerParams supplied when `em` was constructed — confirm.
em.fit()

In [None]:
# Rollout validation
# Run the trainer's validation loop over the rollout validation loader and keep
# the returned value so it can be printed in a later cell.
# NOTE(review): the exact effect of full=True is not visible here — confirm.
valid_results = em.trainer.validation_loop(
    ae_data_module.rollout_val_dataloader(),
    full=True,
    valid_or_test="rollout_valid",
)

In [None]:
# Rollout test
# Same trainer validation loop, this time fed the rollout test loader and
# tagged "rollout_test"; the returned value is printed in a later cell.
test_results = em.trainer.validation_loop(
    ae_data_module.rollout_test_dataloader(),
    full=True,
    valid_or_test="rollout_test",
)

In [None]:
# The following cell also uses this `pprint` name, so keep the import in this form.
from pprint import pprint
# Pretty-print the value returned by the rollout validation pass.
pprint(valid_results)

In [None]:
# Pretty-print the value returned by the rollout test pass
# (relies on `pprint` imported in the previous cell).
pprint(test_results)