## AutoCast encoder-processor-decoder model: API exploration

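This notebook explores the end-to-end AutoCast API. It covers generating data, loading it into a datamodule, building an encoder-processor-decoder model, training and testing it with Lightning, and rolling it out over a full trajectory.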


### Example dataset

We use AutoEmulate's `AdvectionDiffusion` simulator to generate an example dataset for illustrating model training and evaluation. It simulates the advection-diffusion equation in 2D.
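For reference, the advection-diffusion equation for a scalar field $u$ carried by a velocity field $\mathbf{v}$ with diffusivity $D$ reads $\partial_t u + \mathbf{v} \cdot \nabla u = D \nabla^2 u$. The sketch below advances a field by one explicit Euler step using central differences. It is only an illustration: the constant velocity, periodic boundaries, parameter values, and the helper `advect_diffuse_step` are assumptions, not AutoEmulate's implementation, which may use other terms, schemes, or boundary conditions.

```python
import numpy as np


def advect_diffuse_step(u, vx, vy, D, dx, dt):
    """One explicit Euler step of u_t + v . grad(u) = D * lap(u).

    Hypothetical helper for illustration: constant velocity (vx, vy), equal
    grid spacing dx in both directions, and periodic boundaries (np.roll wraps
    values around the domain edges).
    """
    # Central differences for the first derivatives (axis 0 = x, axis 1 = y)
    du_dx = (np.roll(u, -1, axis=0) - np.roll(u, 1, axis=0)) / (2 * dx)
    du_dy = (np.roll(u, -1, axis=1) - np.roll(u, 1, axis=1)) / (2 * dx)
    # Five-point stencil for the Laplacian
    lap = (
        np.roll(u, -1, axis=0)
        + np.roll(u, 1, axis=0)
        + np.roll(u, -1, axis=1)
        + np.roll(u, 1, axis=1)
        - 4 * u
    ) / dx**2
    return u + dt * (-(vx * du_dx + vy * du_dy) + D * lap)


# Gaussian blob on a 50 x 50 grid over the unit square
n = 50
dx = 1.0 / n
coords = np.arange(n) * dx
X, Y = np.meshgrid(coords, coords, indexing="ij")
u = np.exp(-((X - 0.5) ** 2 + (Y - 0.5) ** 2) / 0.01)

u_next = advect_diffuse_step(u, vx=0.5, vy=0.0, D=1e-3, dx=dx, dt=1e-3)
print(u_next.shape)  # (50, 50)
```

Because the boundaries are periodic, both difference operators sum to zero over the grid, so the total of `u` is conserved up to rounding.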


In [3]:
from autoemulate.simulations.advection_diffusion import AdvectionDiffusion as Sim

sim = Sim(return_timeseries=True, log_level="error")


def generate_split(simulator: Sim, n_train: int = 10, n_valid: int = 2, n_test: int = 2):
    """Generate training, validation, and test splits from the simulator."""
    train = simulator.forward_samples_spatiotemporal(n_train)
    valid = simulator.forward_samples_spatiotemporal(n_valid)
    test = simulator.forward_samples_spatiotemporal(n_test)
    return {"train": train, "valid": valid, "test": test}


combined_data = generate_split(sim)

### Load the combined data into a datamodule


In [4]:
from auto_cast.data.datamodule import SpatioTemporalDataModule

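n_steps_input = 4  # past time steps the model sees per example
n_steps_output = 1  # future time steps predicted per forward pass
datamodule = SpatioTemporalDataModule(
    data=combined_data,  # in-memory splits keyed "train", "valid", "test"
    data_path=None,  # nothing loaded from disk; the splits come from `data`
    n_steps_input=n_steps_input,
    n_steps_output=n_steps_output,
    batch_size=16,
)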
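With `n_steps_input=4` and `n_steps_output=1`, each training example presumably pairs 4 consecutive frames with the frame that follows. A minimal sketch of such a sliding-window split, assuming stride 1 and a trajectory stored as a `(time, x, y, channels)` array; `make_windows` is a hypothetical helper, not the `SpatioTemporalDataModule` implementation. The dummy length of 320 steps matches the trajectory seen in the rollout section (4 input plus 316 target steps).

```python
import numpy as np


def make_windows(traj, n_in, n_out):
    """Split one trajectory of shape (T, ...) into (input, output) windows.

    Hypothetical helper for illustration; uses stride 1, so consecutive
    windows overlap in all but one frame.
    """
    n_windows = traj.shape[0] - n_in - n_out + 1
    inputs = np.stack([traj[i : i + n_in] for i in range(n_windows)])
    outputs = np.stack([traj[i + n_in : i + n_in + n_out] for i in range(n_windows)])
    return inputs, outputs


# Dummy trajectory: 320 time steps on a 50 x 50 grid with 1 channel
traj = np.random.default_rng(0).standard_normal((320, 50, 50, 1)).astype(np.float32)
inputs, outputs = make_windows(traj, n_in=4, n_out=1)
print(inputs.shape, outputs.shape)  # (316, 4, 50, 50, 1) (316, 1, 50, 50, 1)
```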

### Example batch


In [5]:
batch = next(iter(datamodule.train_dataloader()))

# batch

In [6]:
from auto_cast.decoders.channels_last import ChannelsLast
from auto_cast.encoders.permute_concat import PermuteConcat
from auto_cast.models.encoder_decoder import EncoderDecoder
from auto_cast.models.encoder_processor_decoder import EncoderProcessorDecoder
from auto_cast.nn.fno import FNOProcessor

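batch = next(iter(datamodule.train_dataloader()))
# Number of physical channels: the last axis of input_fields
n_channels = batch.input_fields.shape[-1]
# The processor sees all input steps stacked along the channel axis and
# predicts all output steps stacked the same way
processor = FNOProcessor(
    in_channels=n_channels * n_steps_input,
    out_channels=n_channels * n_steps_output,
    n_modes=(16, 16),  # low-frequency Fourier modes retained per spatial axis
    hidden_channels=64,
)
# Encoder folds time into channels; decoder restores a channels-last layout
encoder = PermuteConcat(with_constants=False)
decoder = ChannelsLast(output_channels=n_channels, time_steps=n_steps_output)

# Pair the encoder and decoder, then wrap them around the processor
model = EncoderProcessorDecoder.from_encoder_processor_decoder(
    encoder_decoder=EncoderDecoder.from_encoder_decoder(
        encoder=encoder, decoder=decoder
    ),
    processor=processor,
)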
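`n_modes=(16, 16)` in the `FNOProcessor` above roughly controls how many low-frequency Fourier modes are retained along each spatial axis (the exact counting convention depends on the FNO implementation). A minimal NumPy sketch of the spectral convolution at the core of a Fourier Neural Operator layer, assuming a single sample of shape `(channels, x, y)`; `spectral_conv2d` is a hypothetical helper, not the `FNOProcessor` internals.

```python
import numpy as np


def spectral_conv2d(x, w_pos, w_neg, m1, m2):
    """FNO-style spectral convolution for one sample x of shape (C_in, H, W).

    Keeps the lowest m1 positive and m1 negative frequencies along H and the
    lowest m2 frequencies along W, mixes channels there with complex weights
    of shape (C_in, C_out, m1, m2), and zeroes every other mode.
    Hypothetical helper for illustration.
    """
    c_out = w_pos.shape[1]
    _, h, w = x.shape
    x_ft = np.fft.rfft2(x, axes=(-2, -1))  # (C_in, H, W // 2 + 1)
    out_ft = np.zeros((c_out, h, w // 2 + 1), dtype=complex)
    out_ft[:, :m1, :m2] = np.einsum("ixy,ioxy->oxy", x_ft[:, :m1, :m2], w_pos)
    out_ft[:, -m1:, :m2] = np.einsum("ixy,ioxy->oxy", x_ft[:, -m1:, :m2], w_neg)
    return np.fft.irfft2(out_ft, s=(h, w), axes=(-2, -1))


# Identity weights on a 50 x 50 grid: low modes pass through, high modes are cut
m1 = m2 = 16
ones = np.ones((1, 1, m1, m2), dtype=complex)
j = np.arange(50)
low = np.tile(np.cos(2 * np.pi * 2 * j / 50), (50, 1))[None]  # mode 2 along y
high = np.tile(np.cos(2 * np.pi * 20 * j / 50), (50, 1))[None]  # mode 20 along y

print(np.allclose(spectral_conv2d(low, ones, ones, m1, m2), low))  # True
print(np.allclose(spectral_conv2d(high, ones, ones, m1, m2), 0))  # True
```

In a full FNO layer, this spectral term is added to a pointwise linear transform of the input and passed through a nonlinearity; the learned weights replace the identity weights used here.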

In [7]:
model(batch).shape

torch.Size([16, 1, 50, 50, 1])
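The output shape above, together with `in_channels=n_channels * n_steps_input`, suggests this data flow: the encoder folds the 4 input time steps into the channel axis, the FNO maps those to `n_channels * n_steps_output` channels, and the decoder unfolds the result back to `(batch, time, x, y, channels)`. The NumPy sketch below reproduces the shapes under that assumption. The exact axis ordering inside `PermuteConcat` and `ChannelsLast`, and the random 1x1 channel mix standing in for the FNO, are assumptions.

```python
import numpy as np

B, T_in, T_out, H, W, C = 16, 4, 1, 50, 50, 1
rng = np.random.default_rng(0)
x = rng.standard_normal((B, T_in, H, W, C))  # shaped like batch.input_fields

# Encoder (assumed): move channels before space, then merge time and channels
z = x.transpose(0, 1, 4, 2, 3).reshape(B, T_in * C, H, W)  # (16, 4, 50, 50)

# Processor stand-in: a pointwise (1x1) channel mix instead of the real FNO
weight = rng.standard_normal((T_out * C, T_in * C))
z_out = np.einsum("oi,bihw->bohw", weight, z)  # (16, 1, 50, 50)

# Decoder (assumed): split channels into (time, channels), move channels last
y = z_out.reshape(B, T_out, C, H, W).transpose(0, 1, 3, 4, 2)
print(y.shape)  # (16, 1, 50, 50, 1)
```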

### Train the model


In [15]:
import lightning as L

device = "mps"  # "cpu"
# device = "cpu"
trainer = L.Trainer(max_epochs=10, accelerator=device, log_every_n_steps=10)
trainer.fit(model, datamodule.train_dataloader(), datamodule.val_dataloader())

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores

  | Name            | Type           | Params | Mode 
-----------------------------------------------------------
0 | loss_func       | MSELoss        | 0      | train
1 | encoder_decoder | EncoderDecoder | 0      | train
2 | processor       | FNOProcessor   | 2.4 M  | train
-----------------------------------------------------------
2.4 M     Trainable params
0         Non-trainable params
2.4 M     Total params
9.642     Total estimated model params size (MB)
57        Modules in train mode
0         Modules in eval mode
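As far as we know, Lightning's "Total estimated model params size (MB)" is the parameter count multiplied by the bytes per parameter (4 for the default 32-bit precision), in units of 10^6 bytes. Working backwards from the reported 9.642 MB gives the count that the summary rounds to 2.4 M. Since the MB figure is itself rounded to three decimals, the result is only accurate to about ±125 parameters.

```python
size_mb = 9.642  # reported by the Lightning model summary above
bytes_per_param = 4  # float32, assuming default 32-bit precision
n_params = size_mb * 1e6 / bytes_per_param
print(round(n_params))  # 2410500
```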



`Trainer.fit` stopped: `max_epochs=10` reached.


### Run the evaluation


In [9]:
trainer.test(model, datamodule.test_dataloader())


──────────────────────────────────────────────────────────
       Test metric             DataLoader 0
──────────────────────────────────────────────────────────
        test_loss          0.0004150373861193657
──────────────────────────────────────────────────────────

[{'test_loss': 0.0004150373861193657}]

### Example rollout


In [10]:
# Each element of the rollout test dataloader is one full trajectory
batch = next(iter(datamodule.rollout_test_dataloader()))

In [11]:
# The first n_steps_input time steps are the initial condition
print(batch.input_fields.shape)
# All remaining time steps of the trajectory are rollout targets
print(batch.output_fields.shape)

torch.Size([1, 4, 50, 50, 1])
torch.Size([1, 316, 50, 50, 1])


In [12]:
# Run rollout on one trajectory
preds, trues = model.rollout(batch, free_running_only=True)
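Judging by the shapes printed below, `model.rollout` stacks its predictions as `(batch, rollout_steps, n_steps_output, x, y, channels)`. A minimal sketch of a free-running autoregressive rollout, assuming that after the initial window the model only ever sees its own predictions. `rollout_sketch` and the toy mean-of-window step function are hypothetical, not the AutoCast implementation.

```python
import numpy as np


def rollout_sketch(step_fn, init_window, n_rollout_steps):
    """Free-running rollout: each prediction is fed back into the input window.

    init_window: (B, n_in, H, W, C); step_fn maps it to (B, n_out, H, W, C).
    Returns predictions of shape (B, n_rollout_steps, n_out, H, W, C).
    Hypothetical helper for illustration.
    """
    window, outputs = init_window, []
    n_in = init_window.shape[1]
    for _ in range(n_rollout_steps):
        pred = step_fn(window)
        outputs.append(pred)
        # Slide the window: append the new prediction, keep the latest n_in frames
        window = np.concatenate([window, pred], axis=1)[:, -n_in:]
    return np.stack(outputs, axis=1)


def toy_step_fn(window):
    # Toy "model": predict the mean of the input window as the next frame
    return window.mean(axis=1, keepdims=True)


toy_init = np.random.default_rng(0).standard_normal((1, 4, 50, 50, 1))
toy_preds = rollout_sketch(toy_step_fn, toy_init, n_rollout_steps=10)
print(toy_preds.shape)  # (1, 10, 1, 50, 50, 1)
```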

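In [22]:
from types import SimpleNamespace

from the_well.benchmark.metrics import MSE

assert preds.shape == trues.shape
# The third argument is metadata, not a tensor. Assumption: the metric only
# reads `n_spatial_dims`, so a SimpleNamespace stand-in suffices for 2D fields
mse_error = MSE.eval(preds, trues, SimpleNamespace(n_spatial_dims=2))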

In [13]:
print(preds.shape)

torch.Size([1, 10, 1, 50, 50, 1])


In [14]:
assert trues is not None
print(trues.shape)


torch.Size([1, 10, 1, 50, 50, 1])
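As a dependency-free alternative to the_well metrics, per-step errors can be computed by averaging squared or absolute differences over the two spatial axes only, which leaves one value per sample, rollout step, output step, and channel. A minimal sketch on toy arrays shaped like `preds` and `trues`; the helper name `per_step_errors` is hypothetical.

```python
import numpy as np


def per_step_errors(preds, trues):
    """MSE, RMSE, and MAE averaged over the spatial axes (x, y) only.

    Inputs: (batch, rollout_steps, n_steps_output, x, y, channels).
    Outputs: (batch, rollout_steps, n_steps_output, channels).
    Hypothetical helper for illustration.
    """
    spatial_axes = (-3, -2)
    diff = preds - trues
    mse = np.mean(diff**2, axis=spatial_axes)
    mae = np.mean(np.abs(diff), axis=spatial_axes)
    return {"mse": mse, "rmse": np.sqrt(mse), "mae": mae}


rng = np.random.default_rng(0)
toy_trues = rng.standard_normal((1, 10, 1, 50, 50, 1))
toy_preds = toy_trues + 0.1  # constant offset: RMSE and MAE are 0.1 up to rounding
errors = per_step_errors(toy_preds, toy_trues)
print(errors["rmse"].shape)  # (1, 10, 1, 1)
```

For the real PyTorch tensors, the same reduction is `((preds - trues) ** 2).mean(dim=(-3, -2))`.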


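In [27]:
from types import SimpleNamespace

from the_well.benchmark.metrics import RMSE, MAE, MSE
from the_well.data.datasets import WellMetadata

# Pass a metadata object with `n_spatial_dims`, not the `WellMetadata` class.
# Assumption: the metric only needs that field, so SimpleNamespace suffices
rmse_error = RMSE.eval(preds, trues, SimpleNamespace(n_spatial_dims=2))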

In [25]:
# The class itself has no `n_spatial_dims`; metrics need a populated metadata object
WellMetadata.n_spatial_dims

AttributeError: type object 'WellMetadata' has no attribute 'n_spatial_dims'
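The `AttributeError` above is consistent with `WellMetadata` being a dataclass whose `n_spatial_dims` field has no default value. Such fields exist only on instances, not on the class, so a metric needs a populated metadata object. The toy dataclass below is a hypothetical stand-in, not the real `WellMetadata`, and shows the difference:

```python
from dataclasses import dataclass


@dataclass
class ToyMetadata:
    """Hypothetical stand-in for a metadata dataclass."""

    n_spatial_dims: int


# No default value, so the class itself has no such attribute
print(hasattr(ToyMetadata, "n_spatial_dims"))  # False
toy_meta = ToyMetadata(n_spatial_dims=2)
print(toy_meta.n_spatial_dims)  # 2
```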