In [None]:
from pathlib import Path
from the_well.data import WellDataModule
import matplotlib.pyplot as plt
import logging
from autoemulate.experimental.emulators.the_well import TheWellFNO

# Configure the root logger: INFO threshold, each record prefixed with
# timestamp, logger name and level.
logging.basicConfig(
    format="%(asctime)s %(name)s %(levelname)s: %(message)s",
    level=logging.INFO,
)

In [None]:
# Simulate advection-diffusion data that a later cell wraps in an autoemulate
# datamodule (no the_well datamodule is used in this cell).
from autoemulate.simulations.advection_diffusion import AdvectionDiffusion
# NOTE(review): no random seed is set here; if the simulator samples its
# parameters randomly, the three splits below change on every run — confirm.
rd = AdvectionDiffusion(n=64, T=10, dt=0.1, return_timeseries=True)
# Training split (argument 6 is presumably the number of samples — verify).
data = rd.forward_samples_spatiotemporal(6)
# NOTE(review): `y` is not referenced by any other cell visible in this notebook.
y = data["data"]
# Separate draws for the validation and test splits.
data_valid = rd.forward_samples_spatiotemporal(2)
data_test = rd.forward_samples_spatiotemporal(2)

In [None]:
from autoemulate.experimental.data.spatiotemporal_dataset import AdvectionDiffusionDataset, AutoEmulateDataModule

# Location later passed to the emulators as their `output_path`.
output_path = "../data/the_well/runs/advection_diffusion_wip"

# Wrap the in-memory train/valid/test simulations in a datamodule; the data is
# supplied directly via `data`, so no `data_path` is given.
ae_data_module = AutoEmulateDataModule(
    data={"train": data, "valid": data_valid, "test": data_test},
    dataset_cls=AdvectionDiffusionDataset,
    data_path=None,
    n_steps_input=4,
    n_steps_output=1,
    verbose=False,
)

In [None]:
# Show one 2-D slice of the input fields from the first validation batch.
batch = next(iter(ae_data_module.val_dataloader()))
# NOTE(review): indices presumably select (sample 0, time step 0, channel 0)
# — confirm the axis order of "input_fields".
example_frame = batch["input_fields"][0, 0, :, :, 0]
plt.imshow(example_frame)
plt.show()

In [None]:
# Initialize the FNO emulator on the advection-diffusion datamodule.
import torch
from the_well.data.data_formatter import DefaultChannelsFirstFormatter
from the_well.benchmark.metrics import VRMSE
# NOTE(review): TheWellAFNO is imported but this cell instantiates TheWellFNO
# (imported in the first cell) — confirm which model is intended.
from autoemulate.experimental.emulators.the_well import TheWellAFNO, TrainerParams

# Prefer Apple's MPS backend when it is available and fall back to the CPU
# otherwise, so the notebook also runs on machines without MPS instead of
# failing on a hard-coded "mps" device.
device = "mps" if torch.backends.mps.is_available() else "cpu"

em = TheWellFNO(
    datamodule=ae_data_module,
    formatter_cls=DefaultChannelsFirstFormatter,
    loss_fn=VRMSE(),
    trainer_params=TrainerParams(
        device=device,
        output_path=output_path,
        max_rollout_steps=100,
        optimizer_params={"lr": 1e-3}
    )
)

In [None]:
# Fit the model (the FNO emulator constructed in the previous cell)
em.fit()

In [None]:
# Run the trainer's validation loop over the rollout validation dataloader
# and keep the returned results for inspection below.
valid_results = em.trainer.validation_loop(
    ae_data_module.rollout_val_dataloader(), valid_or_test="rollout_valid", full=True
)

In [None]:
# Same validation loop, this time over the rollout test dataloader.
test_results = em.trainer.validation_loop(
    ae_data_module.rollout_test_dataloader(), valid_or_test="rollout_test", full=True
)

In [None]:
# Pretty-print the results returned by the rollout validation loop.
from pprint import pprint
pprint(valid_results)

In [None]:
# Pretty-print the results returned by the rollout test loop.
pprint(test_results)

In [None]:
# Run prediction from the rollout test dataloader and display the shape of the result
em.predict(ae_data_module.rollout_test_dataloader()).shape

In [None]:
# Run prediction from the non-rollout test dataloader and display the shape of the result
em.predict(ae_data_module.test_dataloader()).shape


In [None]:
# Initialize a UNet emulator
from autoemulate.experimental.emulators.the_well import TheWellUNetClassic


# NOTE(review): unlike the TheWellFNO cell, `output_path` and `device` are passed
# directly here rather than through TrainerParams — confirm the constructor
# accepts these keywords. This rebinds `em` (replacing the FNO emulator) and
# reuses the same `output_path` as the FNO run.
em = TheWellUNetClassic(datamodule=ae_data_module, output_path=output_path, device="cpu")

In [None]:
# Fit the emulator currently bound to `em` (the UNet from the previous cell when run in order).
em.fit()

In [None]:
# Predict on the rollout test dataloader with the current `em` and display the shape of the result.
em.predict(ae_data_module.rollout_test_dataloader()).shape