In [None]:
from the_well.data import WellDataset, DeltaWellDataset
from torchinfo import summary

from pathlib import Path
from omegaconf import DictConfig
from the_well.benchmark.trainer import Trainer
from the_well.data.data_formatter import DefaultChannelsFirstFormatter, DefaultChannelsLastFormatter
from the_well.benchmark import models
import torch
from the_well.data import WellDataModule
from the_well.benchmark.utils.experiment_utils import configure_experiment
import matplotlib.pyplot as plt
import logging
from the_well.benchmark.metrics import MSE, VRMSE

# `logging.getLogger()` called without a name returns the root logger.
logger = logging.getLogger()



In [None]:
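# Reference only: the keyword arguments supplied to `Trainer(...)` at the end of
# this cell (presumably copied from its signature; verify against the installed version).
# checkpoint_folder: str,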
# artifact_folder: str,
# viz_folder: str,
# formatter: str,
# model: torch.nn.Module,
# datamodule: AbstractDataModule,
# optimizer: torch.optim.Optimizer,
# loss_fn: Callable,
# # validation_suite: list,
# epochs: int,
# checkpoint_frequency: int,
# val_frequency: int,
# rollout_val_frequency: int,
# max_rollout_steps: int,
# short_validation_length: int,
# make_rollout_videos: bool,
# num_time_intervals: int,
# lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
# device=torch.device("cuda"),
# is_distributed: bool = False,
# enable_amp: bool = False,
# amp_type: str = "float16",  # bfloat not supported in FFT
# checkpoint_path: str = "",


# Data locations are named once so the dataset and the datamodule below always
# read from the same root and cannot drift apart when the data is moved.
WELL_BASE_PATH = "../data/the_well/datasets"
WELL_RUNS_PATH = Path("../data/the_well/runs")

# Number of time steps fed to the model and predicted by it.
n_steps_input = 1
n_steps_output = 1
well_dataset_name = "turbulent_radiative_layer_2D"

# Stand-alone dataset on the "test" split; in the visible code it is only used
# as the source of `.metadata` for sizing the model.
the_well_ds = WellDataset(
    well_base_path=WELL_BASE_PATH,
    well_dataset_name=well_dataset_name,
    well_split_name="test",
    n_steps_input=n_steps_input,
    n_steps_output=n_steps_output,
)

# Per-run output folders, keyed by dataset name and input/output window sizes.
well_base_path_run = WELL_RUNS_PATH / f"{well_dataset_name}_in{n_steps_input}_out{n_steps_output}"
checkpoint_path = well_base_path_run / "checkpoints"
artifact_path = well_base_path_run / "artifacts"

# Make a datamodule
datamodule = WellDataModule(
    well_base_path=WELL_BASE_PATH,
    well_dataset_name=well_dataset_name,
    n_steps_input=n_steps_input,
    n_steps_output=n_steps_output,
    batch_size=4,
    train_dataset=WellDataset,
)

# Sizes derived from the dataset metadata.
# NOTE: these assignments previously ended with a trailing comma, which silently
# wrapped each value in a 1-tuple (e.g. `dim_in == (n_input_fields,)`).
dset_metadata = the_well_ds.metadata
n_spatial_dims = dset_metadata.n_spatial_dims
spatial_resolution = dset_metadata.spatial_resolution

# Every input time step contributes all `n_fields`; the constant fields are
# counted once on top of that.
n_input_fields = (
    n_steps_input * dset_metadata.n_fields
    + dset_metadata.n_constant_fields
)
n_output_fields = dset_metadata.n_fields

# Aliases matching the model's keyword names.
dim_in = n_input_fields
dim_out = n_output_fields

# Build the FNO model; channel counts and spatial layout come from the dataset
# metadata computed earlier in this cell.
fno_kwargs = dict(
    modes1=16,
    modes2=16,
    dim_in=n_input_fields,
    dim_out=n_output_fields,
    n_spatial_dims=dset_metadata.n_spatial_dims,
    spatial_resolution=dset_metadata.spatial_resolution,
)
model = models.FNO(**fno_kwargs)

# Print the layer-by-layer summary, expanding nested modules up to depth 5.
print(summary(model, depth=5))

# Training setup: CPU device, Adam optimizer, no learning-rate scheduler.
lr_scheduler = None
device = torch.device("cpu")
optimizer: torch.optim.Optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

trainer = Trainer(
    # --- what is trained ---
    model=model,
    datamodule=datamodule,
    formatter="channels_first_default",  # "channels_last_default" for other models
    # --- optimisation ---
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    loss_fn=VRMSE(),  # alternative that is also imported: MSE()
    epochs=10,
    # --- validation / rollout cadence ---
    val_frequency=1,
    short_validation_length=20,
    rollout_val_frequency=5,
    max_rollout_steps=10,
    num_time_intervals=n_steps_output,
    make_rollout_videos=False,
    # --- output locations ---
    checkpoint_folder=str(checkpoint_path),
    checkpoint_frequency=5,
    artifact_folder=str(artifact_path),
    viz_folder=str(artifact_path / "viz"),
    # --- runtime ---
    device=device,
    is_distributed=False,
    enable_amp=False,
    amp_type="float16",  # bfloat not supported in FFT
    checkpoint_path="",  # Path to a checkpoint to resume from, if any
)

In [None]:
# Plot example: a single 2-D slice taken from the first validation batch.
batch = next(iter(datamodule.val_dataloader()))
# Index [0, 0, :, :, 0] is presumably (sample, time step, x, y, field); TODO confirm layout.
example_slice = batch["input_fields"][0, 0, :, :, 0]
fig, ax = plt.subplots()
ax.imshow(example_slice)
plt.show()

In [None]:
# Run the training loop on the `trainer` configured above (this takes a long time).
# Presumably writes checkpoints/artifacts to the folders passed to Trainer; TODO confirm.
trainer.train()


In [None]:
# Roll the trained model out on one validation batch and show the shape of the
# first returned item (presumably the predictions; confirm against Trainer.rollout_model).
batch = next(iter(datamodule.val_dataloader()))
rollout_outputs = trainer.rollout_model(
    trainer.model,
    batch,
    trainer.formatter,
    train=False,
)
rollout_outputs[0].shape