## AutoCast Processor Evaluation

This notebook evaluates a pre-trained processor model on the MiniWell dataset.
It loads the model configuration and weights from a specified run directory.

In [None]:
import os

import hydra
import lightning as L
import matplotlib.pyplot as plt
import torch
from hydra.utils import instantiate
from IPython.display import HTML
from omegaconf import OmegaConf

from autocast.external.lola.lola_autoencoder import get_autoencoder
from autocast.models.processor import ProcessorModel
from autocast.utils.plots import plot_spatiotemporal_video


In [None]:
# Directory of the training run whose artifacts are evaluated in this notebook
run_path = "../outputs/rayleigh_benard/2026-01-14_diffusion_vit_small"

# The config file and the checkpoint both sit directly inside the run directory
config_path, ckpt_path = (
    os.path.join(run_path, artifact)
    for artifact in ("resolved_processor_config.yaml", "processor.ckpt")
)

# Read the configuration that was saved with the run
cfg = OmegaConf.load(config_path)
# print(OmegaConf.to_yaml(cfg))

In [None]:
# Build the DataModule from the `data` section of the loaded config, then
# prepare its datasets.
datamodule = instantiate(cfg.data)
# setup() is called without a stage argument; presumably this prepares every
# split (train/val/test) — confirm against the DataModule implementation.
datamodule.setup() # Setup all stages (fit for train/val, test for test)

In [None]:
# Build the processor network from the `model.processor` config entry
processor = instantiate(cfg.model.processor)

# Wrap the processor in a ProcessorModel, passing along the configured
# learning rate
model = ProcessorModel(processor=processor, learning_rate=cfg.model.learning_rate)

In [None]:
# Read the checkpoint onto the CPU and restore the weights stored under its
# "state_dict" key
checkpoint = torch.load(ckpt_path, map_location="cpu")
state_dict = checkpoint["state_dict"]
model.load_state_dict(state_dict)

# Put the module into evaluation mode before running predictions
model.eval()
print("Model loaded successfully")

In [None]:
# Uncomment to display the model architecture:
# model

In [None]:
# Run test set evaluation

# Pick the accelerator in order of preference: CUDA GPU, Apple MPS, then CPU
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# Trainer without a logger; the actual test run is left disabled below
trainer = L.Trainer(logger=False, accelerator=device)
# trainer.test(model, datamodule=datamodule)

In [None]:
# Get a batch for visualization from validation set: take the first batch
# yielded by a fresh validation dataloader.
dl = datamodule.val_dataloader()
batch = next(iter(dl))

# Move to device (disabled: the batch and model stay wherever they were created)
# batch = batch.to(device)
# model = model.to(device)

# Sanity-check the shape of the encoded inputs carried by the batch
print(f"Batch shape inputs: {batch.encoded_inputs.shape}")

In [None]:
# Create rollout: run the model on the first few validation batches and
# collect its predictions next to the matching ground-truth fields.
max_batches = 5  # Limit to 5 batches for speed

with torch.no_grad():
    preds, trues = [], []
    for i, batch in enumerate(datamodule.val_dataloader()):
        pred = model(batch.encoded_inputs, batch.global_cond)
        preds.append(pred)
        trues.append(batch.encoded_output_fields)
        # `i` is zero-based, so stop once `max_batches` batches have been
        # processed (the previous `i >= 5` check let a sixth batch through).
        if i + 1 >= max_batches:
            break
    # Join the per-batch results along the leading (batch) dimension
    preds = torch.cat(preds, dim=0)
    trues = torch.cat(trues, dim=0)
print(f"Predictions shape: {preds.shape}")
print(f"Ground Truth shape: {trues.shape}")

### Decoded Evaluation

Load the corresponding AutoEncoder to decode the latent predictions back to pixel space and visualize them.

In [None]:
# # Plot latent space predictions
# anim = plot_spatiotemporal_video(
#     true=trues,
#     pred=preds,
#     batch_idx=0,
#     save_path=None,
#     title="Latent Space Prediction",
#     colorbar_mode="row",
# )
# HTML(anim.to_jshtml())

### Latent Space Evaluation

Visualize the predictions in the latent space (before decoding).

In [None]:
# Display sample indices taken with a stride of 4 along the first axis of `preds`.
# NOTE(review): the end bound is preds.shape[0]-1 (exclusive), so this can omit
# a final index that the `preds[::4]` slice in the next cell does select.
torch.arange(0, preds.shape[0]-1, 4)

In [None]:
from einops import rearrange

# Keep every 4th sample along the first axis, then fold the kept samples into
# the second axis so each tensor becomes a single sequence with a leading
# dimension of size 1.
# NOTE(review): presumably the stride of 4 matches the per-sample window
# length so consecutive windows do not overlap — confirm.
preds_plot, trues_plot = (
    rearrange(latents[::4], "b t ... -> 1 (b t) ...") for latents in (preds, trues)
)

In [None]:
# Animate ground truth vs. prediction in latent space, restricted to the
# first four entries of the last axis
anim = plot_spatiotemporal_video(
    pred=preds_plot[..., :4],
    true=trues_plot[..., :4],
    title="Latent Prediction",
    colorbar_mode="row",
    batch_idx=0,
    save_path=None,
)
# Render the animation inline as interactive HTML
HTML(anim.to_jshtml())

In [None]:
# Load AutoEncoder to decode predictions
ae_path = "../datasets/rayleigh_benard/1e3z5x2c_rayleigh_benard_dcae_f32c64_large"
ae_config_path = os.path.join(ae_path, "config.yaml")
ae_ckpt_path = os.path.join(ae_path, "state.pth")

print(f"Loading AutoEncoder from: {ae_path}")
ae_cfg = OmegaConf.load(ae_config_path)

# Convert to dictionary to avoid OmegaConf/beartype conflicts for most args (like attention_heads)
ae_config_dict = OmegaConf.to_container(ae_cfg.ae, resolve=True)

# However, get_autoencoder specifically types 'loss' as DictConfig, so we must preserve it
if "loss" in ae_cfg.ae:
    ae_config_dict["loss"] = ae_cfg.ae.loss

# Instantiate AutoEncoder
# We pass **ae_config_dict to unpack arguments
autoencoder = get_autoencoder(**ae_config_dict)

# Restore the trained weights. Without this step the autoencoder keeps its
# freshly initialized parameters and the decoded frames are meaningless.
# NOTE(review): assumes state.pth stores a plain state_dict for this
# architecture — confirm against how the autoencoder run was saved.
ae_state = torch.load(ae_ckpt_path, map_location="cpu")
autoencoder.load_state_dict(ae_state)

# Evaluation mode for deterministic decoding
autoencoder.eval()

In [None]:
# Inspect the shape of the stacked latent predictions before decoding
preds.shape

In [None]:
# Same shape check as the cell directly above (duplicate cell)
preds.shape

In [None]:
# Decode predictions and ground truth
from einops import rearrange

# Per-chunk decoder outputs, concatenated after the loop
preds_decodeds, trues_decodeds = [], []
with torch.no_grad():
    bs = 4  # number of samples decoded per autoencoder call
    # Only the first 4 samples are decoded to keep this cell fast; switch to
    # the commented-out range below to decode every sample.
    # for i in range(0, preds.shape[0], bs):
    for i in range(0, 4, bs):
        print(f"Decoding batch {i} to {i+bs} / {preds.shape[0]}")

        # `i` already advances in steps of `bs`, so it is the chunk's start
        # offset. Slicing with i*bs would skip samples as soon as the loop
        # runs for more than one iteration.
        preds_subset = preds[i : i + bs]
        trues_subset = trues[i : i + bs]

        # preds shape is likely (B, T, H_lat, W_lat, C_lat)
        # We need to flatten B and T, and move C to the second dimension for the
        # decoder:
        #   - (B*T, C, H, W)
        preds_flat = rearrange(preds_subset, "b t h w c -> (b t) c h w")
        trues_flat = rearrange(trues_subset, "b t h w c -> (b t) c h w")

        print(f"Decoding shape: {preds_flat.shape}")

        # Pass noisy=False to deterministically decode
        preds_decoded = autoencoder.decode(preds_flat, noisy=False)
        trues_decoded = autoencoder.decode(trues_flat, noisy=False)
        preds_decodeds.append(preds_decoded)
        trues_decodeds.append(trues_decoded)

# Join the decoded chunks along the flattened (batch * time) axis
preds_decodeds = torch.cat(preds_decodeds, dim=0)
trues_decodeds = torch.cat(trues_decodeds, dim=0)

In [None]:
# Shape of the concatenated decoded predictions
preds_decodeds.shape

In [None]:
# Shape of the latent predictions, for comparison with the decoded shape
preds.shape

In [None]:
# Shape of `trues_decoded`, the ground-truth output of the last decoded chunk
# (this name is assigned inside the decoding loop)
trues_decoded.shape

In [None]:
# Quick look at one decoded ground-truth image: index 0 on the first axis and
# index 3 on the second.
# NOTE(review): presumably frame 0, channel 3 of a (B*T, C, H, W) tensor — confirm.
plt.imshow(trues_decoded[0, 3, ...])

In [None]:

# decode output is (B*T, C_out, H_out, W_out)
# Split the flattened leading axis back into (batch, time) and move channels
# last for plotting. Only `t` is given; einops infers `b` from the tensor, so
# the number of decoded samples does not need to equal preds.shape[0].
#
# The inputs are the concatenated `preds_decodeds` / `trues_decodeds`, which
# this cell never rebinds. Reading them keeps the cell re-runnable and covers
# every decoded chunk, not only the last one left over from the decoding loop.
#
# NOTE(review): the output pattern lists `w` before `h`, i.e. the two spatial
# axes are swapped relative to the decoder output — confirm this is the layout
# plot_spatiotemporal_video expects.
preds_decoded = rearrange(
    preds_decodeds, "(b t) c h w -> b t w h c", t=preds.shape[1]
)
trues_decoded = rearrange(
    trues_decodeds, "(b t) c h w -> b t w h c", t=trues.shape[1]
)

print(f"Decoded Predictions shape: {preds_decoded.shape}")

# Plot decoded
anim = plot_spatiotemporal_video(
    true=trues_decoded,
    pred=preds_decoded,
    batch_idx=0,
    save_path=None,
    title="Decoded Prediction",
    colorbar_mode="row",
)
HTML(anim.to_jshtml())

In [None]:
## load decoded videos for better visualization
