## AutoCast encoder-processor-decoder model API Exploration

This notebook aims to explore the end-to-end API.


### Example dataset

We use the `AdvectionDiffusion` dataset as an example dataset to illustrate training and evaluation of models. This dataset simulates the advection-diffusion equation in 2D.


In [1]:
from autoemulate.simulations.advection_diffusion import AdvectionDiffusion as Sim

# Simulator used to produce every split below. return_timeseries=True presumably keeps
# the time axis (needed for spatio-temporal samples); log_level="error" presumably
# silences lower-priority log output — confirm against the autoemulate docs.
sim = Sim(log_level="error", return_timeseries=True)


def generate_split(simulator: Sim, n_train: int = 10, n_valid: int = 2, n_test: int = 2):
    """Draw train/validation/test sample sets from ``simulator``.

    Each split comes from its own call to
    ``simulator.forward_samples_spatiotemporal``, made in the order train,
    valid, test.

    Args:
        simulator: Object exposing ``forward_samples_spatiotemporal(n)``.
        n_train: Number of samples requested for the training split.
        n_valid: Number of samples requested for the validation split.
        n_test: Number of samples requested for the test split.

    Returns:
        dict with keys "train", "valid" and "test" mapping to the samples
        returned by the simulator for each split.
    """
    split_sizes = (("train", n_train), ("valid", n_valid), ("test", n_test))
    return {
        split_name: simulator.forward_samples_spatiotemporal(size)
        for split_name, size in split_sizes
    }


# Simulate the train/valid/test splits using generate_split's default sizes.
combined_data = generate_split(sim)


### Load the combined data into a datamodule


In [4]:
from auto_cast.data.datamodule import SpatioTemporalDataModule

# Windowing: each sample pairs `n_steps_input` input frames with
# `n_steps_output` target frames (see the batch shapes printed below).
n_steps_input, n_steps_output = 1, 4
datamodule = SpatioTemporalDataModule(
    data=combined_data,
    data_path=None,  # no on-disk dataset; presumably the in-memory `data` splits are used
    batch_size=16,
    n_steps_input=n_steps_input,
    n_steps_output=n_steps_output,
)

### Example batch


In [7]:
# Peek at one training batch; fields are channels-last, presumably
# (batch, time, x, y, channel) — confirm against the datamodule.
train_loader = datamodule.train_dataloader()
batch = next(iter(train_loader))

(batch.input_fields.shape, batch.output_fields.shape)

(torch.Size([16, 1, 50, 50, 1]), torch.Size([16, 4, 50, 50, 1]))

In [8]:
import torch
from einops import rearrange

from auto_cast.encoders.base import Encoder
from auto_cast.types import Batch, Tensor, TensorBCWH


class IdentityEncoder(Encoder):
    """Encoder that passes ``batch.input_fields`` through unchanged.

    No permutation, concatenation or learned mapping is applied, so a processor
    paired with this encoder operates directly on the raw input fields.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, batch: Batch) -> Tensor:
        """Return ``batch.input_fields`` as-is (same tensor object, no reshaping)."""
        return batch.input_fields

    def encode(self, batch: Batch) -> TensorBCWH:
        """Encode ``batch`` by delegating to :meth:`forward`.

        NOTE(review): the returned tensor keeps the channels-last
        ``input_fields`` layout (5D in this notebook's example batch), which the
        ``TensorBCWH`` annotation does not describe — confirm the intended type.
        """
        return self.forward(batch)

In [10]:
from einops import rearrange

from auto_cast.decoders.base import Decoder
from auto_cast.types import TensorBCTSPlus, TensorBMStarL, TensorBTSPlusC


class IdentityDecoder(Decoder):
    """Decoder that returns its input unchanged.

    Intended to pair with ``IdentityEncoder`` so that the "latent" space is the
    channels-last data space itself.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: TensorBTSPlusC) -> TensorBTSPlusC:
        """Return ``x`` unchanged.

        The layout is preserved, so input and output share the same
        channels-last type (the previous channels-first input annotation
        contradicted this identity mapping).
        """
        return x

    def decode(self, z: TensorBTSPlusC) -> TensorBTSPlusC:
        """Decode the latent ``z`` by delegating to :meth:`forward`."""
        return self.forward(z)


In [None]:
import torch
import torch.nn as nn
from azula.nn.unet import UNet
from azula.nn.embedding import SineEncoding

class TemporalUNetBackbone(nn.Module):
    """Azula UNet with proper time embedding."""
    
    def __init__(
        self,
        in_channels: int = 1,
        out_channels: int = 1,
        cond_channels: int = 1,
        mod_features: int = 256,
        hid_channels: tuple = (32, 64, 128),
        hid_blocks: tuple = (2, 2, 2),
        spatial: int = 2,
        periodic: bool = False,
    ):
        super().__init__()
        
        # Time embedding
        self.time_embedding = nn.Sequential(
            SineEncoding(mod_features),
            nn.Linear(mod_features, mod_features),
            nn.SiLU(),
            nn.Linear(mod_features, mod_features),
        )
        
        self.unet = UNet(
            in_channels=in_channels + cond_channels,
            out_channels=out_channels,
            cond_channels=0,
            mod_features=mod_features,
            hid_channels=hid_channels,
            hid_blocks=hid_blocks,
            kernel_size=3,
            stride=2,
            spatial=spatial,
            periodic=periodic,
        )

    def forward(self, x_out: torch.Tensor, t: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x_out: Noisy data (B, T, C, H, W) - channels first from Azula
            t: Time steps (B,)
            cond: Conditioning input (B, T_cond, C, H, W) - channels first
        Returns:
            Denoised output (B, T, C, H, W)
        """
        B, T, W, H, C = x_out.shape
        _, T_cond, W_cond, H_cond , C_cond = cond.shape
        assert W == W_cond and H == H_cond
        print("x_out.shape", x_out.shape)
        print("cond.shape", cond.shape)
        # Embed time (once per batch)
        t_emb = self.time_embedding(t)  # (B, mod_features)
        mod_for_unet = t_emb
        print(t_emb.shape)
        t_emb = rearrange(t_emb, "b m -> b  1 1 1 m")
        t_emb = t_emb.expand(B, T_cond, W, H, -1)  # (B, mod_features, H, W)

        print("t_emb.shape", t_emb.shape)
        # Concatenate along channel dimension
        x_cond = torch.cat([cond, t_emb], dim=-1)  # (B, T, C+C_cond, H, W)
        print("x_cond.shape", x_cond.shape)
        
        x_cond = rearrange(x_cond, "b t w h c -> b (t c) w h")
        print("x_cond reshaped", x_cond.shape)
        # Process through UNet
        out_flat = self.unet(x_cond, mod=mod_for_unet)
        print("out",out_flat.shape)
        # Reshape back to (B, T, C, H, W)
        return out_flat.reshape(B, T, W, H, C)


In [79]:
from auto_cast.decoders.channels_last import ChannelsLast
from auto_cast.encoders.permute_concat import PermuteConcat
from auto_cast.models.encoder_decoder import EncoderDecoder
from auto_cast.models.encoder_processor_decoder import EncoderProcessorDecoder
from auto_cast.processors.diffusion import DiffusionProcessor
from azula.noise import CosineSchedule

batch = next(iter(datamodule.train_dataloader()))
n_channels = batch.input_fields.shape[-1]
# Create schedule
schedule = CosineSchedule()

backbone = TemporalUNetBackbone(
    in_channels=(n_channels+128)*n_steps_input,          # 1
    out_channels=n_channels,         # 1
    cond_channels=0,        # 1
    mod_features=128,
    hid_channels=(16, 32, 64),
    hid_blocks=(2, 2, 2),
    spatial=2,
    periodic=False,
)


processor = DiffusionProcessor(
    backbone=backbone,
    schedule=schedule,
    denoiser_type='karras',
    learning_rate=1e-4,
    n_steps_output=n_steps_output,  # 4
    stride=1,
    max_rollout_steps=10,
    teacher_forcing_ratio=0.0,
)
encoder = IdentityEncoder()
decoder = IdentityDecoder()

model = EncoderProcessorDecoder.from_encoder_processor_decoder(
    encoder_decoder=EncoderDecoder.from_encoder_decoder(
        encoder=encoder, decoder=decoder
    ),
    processor=processor,
)

In [80]:
prediction = model(batch)
prediction.shape

x_out.shape torch.Size([16, 4, 50, 50, 1])
cond.shape torch.Size([16, 1, 50, 50, 1])
torch.Size([16, 128])
t_emb.shape torch.Size([16, 1, 50, 50, 128])
x_cond.shape torch.Size([16, 1, 50, 50, 129])
x_cond reshaped torch.Size([16, 129, 50, 50])
out torch.Size([16, 1, 50, 50])


RuntimeError: shape '[16, 4, 50, 50, 1]' is invalid for input of size 40000

### Run trainer


In [None]:
import lightning as L

# "auto" lets Lightning choose whichever accelerator is available (CUDA, Apple MPS
# or CPU); hard-coding "mps" fails on machines without Apple silicon.
device = "auto"  # set e.g. "cpu" or "mps" to force a specific accelerator
trainer = L.Trainer(max_epochs=1, accelerator=device, log_every_n_steps=10)
trainer.fit(model, datamodule.train_dataloader(), datamodule.val_dataloader())

### Run the evaluation


In [None]:
trainer.test(model, dataloaders=datamodule.test_dataloader())

### Example rollout


In [None]:
# The rollout loader yields one full trajectory per element.
rollout_loader = datamodule.rollout_test_dataloader()
batch = next(iter(rollout_loader))

In [None]:
# input_fields: the first n_steps_input frames of the trajectory.
# output_fields: the frames after them — presumably the rest of the trajectory
# rather than just n_steps_output frames; confirm from the shapes shown.
batch.input_fields.shape, batch.output_fields.shape

In [None]:
# Roll the model out over the single trajectory loaded above
# (free_running_only=True; presumably no teacher forcing — confirm in DiffusionProcessor).
rollout_result = model.rollout(batch, free_running_only=True)
preds, trues = rollout_result

In [None]:
preds.shape

In [None]:
# `rollout` can presumably return no ground truth; this trajectory should have it.
assert trues is not None
trues.shape
