In [None]:
import logging
from pathlib import Path
from pprint import pprint

import matplotlib.pyplot as plt
from the_well.data import WellDataModule
from the_well.data import WellDataset, DeltaWellDataset

from autoemulate.experimental.data.spatiotemporal_dataset import (
    AutoEmulateDataModule,
    ReactionDiffusionDataset,
)
from autoemulate.experimental.emulators.the_well import TheWellFNO
from autoemulate.simulations.reaction_diffusion import ReactionDiffusion

# Root logger (getLogger called without a name); not referenced by the visible cells.
logger = logging.getLogger(name=None)

In [None]:
# Generate reaction-diffusion simulation data for the training split.
# (The autoemulate datamodule wrapping this data is built in a later cell.)
N_GRID = 32          # passed to ReactionDiffusion as `n` (presumably grid points per side)
T_FINAL = 10         # passed as `T` (presumably final simulation time)
DT = 0.1             # passed as `dt` (presumably the solver time step)
N_TRAIN_SAMPLES = 3  # number of simulations drawn for the training split

rd = ReactionDiffusion(n=N_GRID, T=T_FINAL, dt=DT, return_timeseries=True)
# NOTE(review): sampling is presumably random and no seed is set here, so the
# generated data (and every downstream metric) may differ between runs — TODO confirm
# whether forward_samples_spatiotemporal supports a seed and set one.
data = rd.forward_samples_spatiotemporal(N_TRAIN_SAMPLES)
y = data["data"]  # raw simulation output; not referenced elsewhere in the visible notebook


In [None]:
# One held-out simulation each for validation and testing (drawn in that order).
data_valid, data_test = [rd.forward_samples_spatiotemporal(1) for _ in range(2)]

In [None]:
# Wrap the simulated train/valid/test splits in an autoemulate datamodule; later
# cells use its regular and rollout dataloaders for training, evaluation and prediction.
ae_data_module = AutoEmulateDataModule(
    data_path=None,  # presumably unused because the splits are passed in memory via `data`
    dataset_cls=ReactionDiffusionDataset,
    data={"train": data, "valid": data_valid, "test": data_test},
    verbose=False,
)

# Presumably where TheWellFNO writes its run artifacts (relative to the notebook's cwd).
output_path = Path("../data/the_well/runs")

In [None]:
# Plot one example input from the validation split as a sanity check on the data.
# NOTE(review): assumes `input_fields` is laid out as (batch, time, x, y, field),
# following The Well's convention — TODO confirm for ReactionDiffusionDataset.
batch = next(iter(ae_data_module.val_dataloader()))
example_field = batch["input_fields"][0, 0, :, :, 0]  # sample 0, time index 0, field 0

fig, ax = plt.subplots()
image = ax.imshow(example_field)
fig.colorbar(image, ax=ax)
ax.set(
    title="Example validation input (sample 0, field 0)",
    xlabel="column index",
    ylabel="row index",
)
plt.show()

In [None]:
# Construct the FNO emulator on top of the datamodule built above; it is pinned to
# the CPU, and `output_path` is presumably where its run artifacts are written.
em = TheWellFNO(
    device="cpu",
    datamodule=ae_data_module,
    output_path=output_path,
)

In [None]:
# Fit the model. fit() takes no data arguments here, so it presumably trains on the
# datamodule passed to TheWellFNO at construction.
em.fit()

In [None]:
# Validation metrics: run the trainer's validation loop over the rollout validation
# dataloader. NOTE(review): `full=True` presumably requests the full rollout — TODO confirm.
valid_results = em.trainer.validation_loop(
    ae_data_module.rollout_val_dataloader(),
    full=True,
    valid_or_test="rollout_valid",
)

In [None]:
# Test metrics: the same validation loop, run over the rollout test dataloader.
test_results = em.trainer.validation_loop(
    ae_data_module.rollout_test_dataloader(),
    full=True,
    valid_or_test="rollout_test",
)

In [None]:
# Pretty-print the rollout metrics returned for the validation split.
pprint(valid_results)

In [None]:
# Pretty-print the rollout metrics returned for the test split.
pprint(test_results)

In [None]:
# Run prediction from the rollout test dataloader and display the shape of the
# returned predictions (the predictions themselves are not kept).
em.predict(ae_data_module.rollout_test_dataloader()).shape

In [None]:
# Run prediction from the standard (non-rollout) test dataloader and display the
# shape of the returned predictions, for comparison with the rollout prediction.
em.predict(ae_data_module.test_dataloader()).shape