In [None]:
import matplotlib.pyplot as plt
import logging
from autoemulate.experimental.emulators.the_well import TheWellFNO
from the_well.benchmark.metrics.plottable_data import make_video

# Configure the root logger once for the whole notebook: INFO level and a
# timestamped "<time> <logger> <level>: <message>" line format.
logging.basicConfig(
    format="%(asctime)s %(name)s %(levelname)s: %(message)s",
    level=logging.INFO,
)

In [None]:
from autoemulate.experimental.data.spatiotemporal_dataset import AutoEmulateDataModule, BOUTDataset

# The BOUT++ data must live under ../data/bout/, one data.pt file per split:
#   ../data/bout/train/data.pt
#   ../data/bout/valid/data.pt
#   ../data/bout/test/data.pt
ae_data_module = AutoEmulateDataModule(
    data_path="../data/bout/",
    dataset_cls=BOUTDataset,
    n_steps_input=5,
    n_steps_output=5,
    verbose=False,
)

# Run directory; passed to TrainerParams(output_path=...) when the emulator is built.
output_path = "../data/the_well/runs/bout_wip_new"

In [None]:
from autoemulate.experimental.data.spatiotemporal_dataset import AutoEmulateDataset
from torch.utils.data import DataLoader

# Standalone training-split dataset, currently referenced only by the
# commented-out video cell. NOTE(review): 51 input steps with 0 output steps
# presumably spans a whole trajectory as a single input window — TODO confirm
# against the BOUT++ trajectory length.
ds = BOUTDataset(
    n_steps_output=0,
    n_steps_input=51,
    data_path="../data/bout/train/data.pt",
)


In [None]:
# # Render a video of one trajectory from `ds` with make_video (the same tensor is passed as both of its first two arguments)
# it = iter(ds)
# traj = next(it)["input_fields"]
# _ = make_video(
#     traj,
#     traj,
#     output_dir=output_path,
#     metadata=ds.metadata,
# )

In [None]:
# Plot one example frame from the first validation batch.
# input_fields is indexed as [sample, timestep, :, :, channel]; presumably the
# layout is (batch, time, x, y, field) — TODO confirm against BOUTDataset.
batch = next(iter(ae_data_module.val_dataloader()))
frame = batch["input_fields"][0, 0, :, :, 0]

# Use the explicit Figure/Axes interface so the plot has a title, axis labels
# and a colour scale when the notebook is skimmed.
fig, ax = plt.subplots()
im = ax.imshow(frame)
fig.colorbar(im, ax=ax)
ax.set_title('Validation batch: input_fields[0, 0, :, :, 0]')
ax.set_xlabel("grid index (axis 3)")
ax.set_ylabel("grid index (axis 2)")
plt.show()

In [None]:
# Initialize the emulator
import torch
from the_well.benchmark.metrics import VRMSE, RMSE
from the_well.data.data_formatter import DefaultChannelsFirstFormatter, DefaultChannelsLastFormatter
from autoemulate.experimental.emulators.the_well import TheWellUNetClassic, TheWellUNetConvNext, TrainerParams


def select_device() -> str:
    """Return the best available torch device string.

    Prefers Apple-silicon "mps" (the device that was previously hard-coded
    here), then "cuda", and falls back to "cpu". This lets the notebook run on
    machines without an MPS backend instead of failing at training time.
    """
    if torch.backends.mps.is_available():
        return "mps"
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"


DEVICE = select_device()

# Alternatives that can be swapped in: TheWellUNetConvNext / TheWellUNetClassic
# instead of TheWellFNO, and RMSE() instead of VRMSE() as the loss.
em = TheWellFNO(
    formatter_cls=DefaultChannelsFirstFormatter,
    datamodule=ae_data_module,
    loss_fn=VRMSE(),
    trainer_params=TrainerParams(
        max_rollout_steps=100,
        output_path=output_path,
        device=DEVICE,
        optimizer_params={"lr": 1e-4},
        enable_tf_schedule=True,
    ),
)

In [None]:
# Fit the model.
# fit() is called with no arguments, so it presumably trains on the datamodule
# and TrainerParams given to the TheWellFNO constructor — TODO confirm.
em.fit()

In [None]:
# Full autoregressive-rollout evaluation on the validation split; the returned
# results are pretty-printed further down.
valid_results = em.trainer.validation_loop(
    ae_data_module.rollout_val_dataloader(),
    full=True,
    valid_or_test="rollout_valid",
)

In [None]:
# Same rollout evaluation as for validation, but on the held-out test split.
test_results = em.trainer.validation_loop(
    ae_data_module.rollout_test_dataloader(),
    full=True,
    valid_or_test="rollout_test",
)

In [None]:
from pprint import pprint
# Pretty-print whatever the rollout validation loop returned.
pprint(valid_results)

In [None]:
# Pretty-print whatever the rollout test loop returned.
pprint(test_results)

In [None]:
# Predict over the rollout test dataloader and display the shape of the result.
rollout_test_loader = ae_data_module.rollout_test_dataloader()
em.predict(rollout_test_loader).shape

In [None]:
# Same as above, but fed from the ordinary (non-rollout) test dataloader.
test_loader = ae_data_module.test_dataloader()
em.predict(test_loader).shape
