In [None]:
from omegaconf import OmegaConf
from hydra.core.global_hydra import GlobalHydra
from hydra.utils import instantiate
from src.models.autoencoder import AutoencoderLitModule
from src.models.physics import PhysicsLitModule
from src.utils import animate

import os
import random
import sys
from pathlib import Path
from IPython.display import HTML
from tqdm import tqdm
import torch
from torch_geometric.data import Data
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.animation as animation
matplotlib.rcParams['animation.embed_limit'] = 400  # raise the embedded-animation size limit to 400 MB

import numpy as np
import einops
from functools import partial
from src.datasets.particle_datamodule import ParticleDataModule
from src.utils.metric import mean_iou

# Clear the global Hydra singleton so this cell can be re-run in the same kernel.
GlobalHydra.instance().clear()

# Expose the current working directory as PROJECT_ROOT.
# NOTE(review): presumably the Hydra configs loaded below interpolate this variable — confirm.
os.environ["PROJECT_ROOT"] = os.path.abspath(".")

SEED = 42

# Seed every RNG in use: Python's `random`, NumPy, and PyTorch (CPU and all CUDA devices).
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# Force deterministic algorithms (and fail if a nondeterministic op is hit)
torch.use_deterministic_algorithms(True)          # raises on nondeterministic ops
torch.set_deterministic_debug_mode("error")       # same setting via the debug-mode API ("error" = raise)
# cuDNN settings
torch.backends.cudnn.deterministic = True         # force deterministic conv algorithms
torch.backends.cudnn.benchmark = False            # disable autotuning, which may pick nondeterministic algorithms
# Keep math mode consistent: no TF32 for matmuls or cuDNN ops
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
# cuBLAS workspace configuration that PyTorch documents as required for deterministic CUDA matmuls.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

In [None]:
def plot_metric(mse, title: str, log: bool = False):
    """Plot a single per-timestep metric curve on a fresh figure and show it.

    Parameters
    ----------
    mse : sequence or 1-D array
        Metric values, one per timestep (plotted against their index).
    title : str
        Prefix of the figure title, rendered as "<title> Metric Over Time".
    log : bool, optional
        If True, use a logarithmic y-axis (default False).
    """
    # No plt.clf() here: called right before plt.subplots() it only clears
    # (or, when no figure is open, creates) an unrelated figure, which then
    # shows up as an extra blank figure in the notebook output.
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(mse, label='Metric')
    ax.set_xlabel('Timestep')
    ax.set_ylabel('Metric')
    ax.set_title(f"{title} Metric Over Time")
    ax.legend()
    if log:
        ax.set_yscale('log')
    fig.tight_layout()
    plt.show()

def plot_two_series(y1: np.ndarray,
                    y2: np.ndarray,
                    labels=('Series 1', 'Series 2'),
                    xlabel: str = 'Index',
                    ylabel: str = 'Value',
                    title: str = 'Two Series on One Plot',
                    log: bool = False) -> None:
    """
    Plots two same-length sequences/arrays on the same axes.

    Parameters
    ----------
    y1 : np.ndarray
        First data series (plotted in default style).
    y2 : np.ndarray
        Second data series (plotted in dashed style).
    labels : tuple of str, optional
        Labels for the two series (default ('Series 1', 'Series 2')).
    xlabel : str, optional
        Label for the x-axis (default 'Index').
    ylabel : str, optional
        Label for the y-axis (default 'Value').
    title : str, optional
        Title of the plot (default 'Two Series on One Plot').
    log : bool, optional
        If True, use a logarithmic y-axis (default False).

    Raises
    ------
    ValueError
        If y1 and y2 are not the same length.
    """
    if len(y1) != len(y2):
        raise ValueError(f"Input arrays must have the same length; got {len(y1)} and {len(y2)}")

    x = np.arange(len(y1))              # common x-axis: 0 .. len-1
    fig, ax = plt.subplots()            # new figure with a single axes
    ax.plot(x, y1, label=labels[0])     # first series, default (solid) style
    # Dashed as documented, so the second series stays visible even where the
    # two curves overlap.
    ax.plot(x, y2, linestyle='--', label=labels[1])
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    ax.grid(True)
    if log:
        ax.set_yscale('log')            # logarithmic y-axis on request
    fig.tight_layout()                  # fit labels/legend inside the figure
    plt.show()

def compute_mse(preds, targets):
    """Return the mean squared error per leading index as a NumPy array.

    Both inputs are 3-D tensors laid out as `(t, n, c)`; the last two axes are
    flattened together so the mean runs over all `n * c` values of each `t`.
    The result is moved to the CPU and converted to a 1-D NumPy array.
    """
    flat_preds, flat_targets = (
        einops.rearrange(tensor, 't n c -> t (n c)') for tensor in (preds, targets)
    )
    squared_error = (flat_preds - flat_targets) ** 2
    return torch.mean(squared_error, dim=1).cpu().numpy()

In [None]:
from __future__ import annotations
from typing import Sequence, Tuple, Optional
import numpy as np
import matplotlib.pyplot as plt
from itertools import cycle

# Optional torch support (avoid hard dependency)
try:
    import torch  # type: ignore
    _HAS_TORCH = True
except Exception:
    _HAS_TORCH = False


def _as_list(arrs: Sequence[np.ndarray] | np.ndarray) -> list[np.ndarray]:
    if isinstance(arrs, np.ndarray) and arrs.ndim == 1:
        return [arrs]
    return [np.asarray(a).ravel() for a in arrs]


def _as_list_optional(arrs: Optional[Sequence[np.ndarray] | np.ndarray]) -> Optional[list[Optional[np.ndarray]]]:
    if arrs is None:
        return None
    if isinstance(arrs, np.ndarray) and arrs.ndim == 1:
        return [arrs]
    out: list[Optional[np.ndarray]] = []
    for a in arrs:
        out.append(None if a is None else np.asarray(a).ravel())
    return out


def _get_kw(kws: dict | Sequence[dict] | None, i: int, base: dict) -> dict:
    if kws is None:
        return base.copy()
    if isinstance(kws, dict):
        out = base.copy(); out.update(kws); return out
    out = base.copy()
    if i < len(kws) and kws[i] is not None:
        out.update(kws[i])
    return out


def plot_means_variances(
        means: Sequence[np.ndarray] | np.ndarray,
        stds:  Optional[Sequence[np.ndarray] | np.ndarray] = None,
        xlabel: str = "",
        ylabel: str = "",
        ax: plt.Axes | None = None,
        *,
        labels: Sequence[str] | None = None,
        colors: Sequence[str] | None = None,
        upper_bound: float | None = None,
        lower_bound: float | None = None,
        line_kw: dict | Sequence[dict] | None = None,
        fill_kw: dict | Sequence[dict] | None = None,
        lower_value: float | None = None,
        figsize: Tuple[float, float] | None = None,
        dpi: int | None = None,
        start_idx: float = 0.0,
        stride: float = 1.0,
        # legend controls
        legend_outside: bool = False,
        legend_pos: str = "right",         # "right" | "left" | "bottom" | "top"
        legend_ncol: int = 1,
        legend_kw: Optional[dict] = None,
    ) -> plt.Axes:
    """
    Plot one or more mean time series, optionally with ±1 SD envelopes.

    Parameters
    ----------
    means : seq[1D array] or 1D array
        One series, or a sequence of series. A torch tensor, or a list/tuple
        containing torch tensors, is converted to NumPy first (tensors on any
        device, with or without grad).
    stds : seq[1D array] or 1D array or None
        If None (global) or stds[i] is None (per-series), no bounds are drawn.
        Accepts torch tensors in the same way as `means`.
    xlabel, ylabel : str
        Axis labels.
    ax : matplotlib Axes, optional
        Axes to draw on; a new figure/axes is created when omitted.
    labels : seq[str], optional
        Legend labels. A legend is drawn only when this is given; series
        beyond the end of `labels` are labelled "series <k>".
    colors : seq[str], optional
        Colors cycled over the series; defaults to the rcParams color cycle.
    upper_bound, lower_bound : float, optional
        Clip envelopes where stds are present (the mean line is not clipped).
    line_kw, fill_kw : dict or seq[dict], optional
        Extra kwargs for the mean line / the filled envelope, either shared
        (dict) or per series (sequence of dicts, entries may be None).
    lower_value : float, optional
        If given, the lower y-limit of the axes is set to this value.
    figsize, dpi : used only if `ax` is None (defaults (6, 4) and 200).
    start_idx : float, default 0.0
        X-value at the first point (global for all series).
    stride : float, default 1.0
        Increment added to x for each step (global for all series).
    legend_outside : bool, default False
        If True, places the legend outside the axes and adjusts margins when
        this function created the axes.
    legend_pos : {"right","left","bottom","top"}, default "right"
        Side on which to park an outside legend. Any value other than
        "right", "left" or "bottom" is treated as "top".
    legend_ncol : int, default 1
        Number of columns in the legend.
    legend_kw : dict, optional
        Additional kwargs forwarded to `ax.legend(...)`.

    Returns
    -------
    matplotlib Axes
        The axes that were drawn on.

    Raises
    ------
    ValueError
        If the number of mean and std series differ, or a mean/std pair has
        mismatching shapes.
    """
    def _tensor_to_numpy(value):
        # force=True detaches and copies to host memory when needed, so CUDA
        # and requires-grad tensors are accepted as well.
        if _HAS_TORCH and isinstance(value, torch.Tensor):
            return value.numpy(force=True)
        return value

    means = _tensor_to_numpy(means)
    stds = _tensor_to_numpy(stds)
    # Tensors nested inside a list/tuple would otherwise reach np.asarray,
    # which fails for CUDA or requires-grad tensors.
    if isinstance(means, (list, tuple)):
        means = [_tensor_to_numpy(m) for m in means]
    if isinstance(stds, (list, tuple)):
        stds = [_tensor_to_numpy(s) for s in stds]

    m_list = _as_list(means)
    s_list = _as_list_optional(stds)

    if s_list is not None and len(m_list) != len(s_list):
        raise ValueError(f"Got {len(m_list)} mean series but {len(s_list)} std series.")
    if s_list is not None:
        for i, (m, s) in enumerate(zip(m_list, s_list)):
            if s is not None and m.shape != s.shape:
                raise ValueError(f"Series {i} mean/std length mismatch: {m.shape} vs {s.shape}.")

    # Remember whether the figure is ours: margins are only adjusted for
    # figures created here, never for a caller-supplied axes.
    created_ax = False
    if ax is None:
        _figsize = figsize if figsize is not None else (6, 4)
        _dpi     = dpi if dpi is not None else 200
        _, ax = plt.subplots(figsize=_figsize, dpi=_dpi)
        created_ax = True

    # color handling: explicit colors win, else the rcParams cycle, else C0..C5
    if colors is None:
        prop = plt.rcParams.get("axes.prop_cycle", None)
        base_colors = [d.get("color") for d in (prop or [])] if prop else None
        if not base_colors:
            base_colors = ["C0", "C1", "C2", "C3", "C4", "C5"]
    else:
        base_colors = list(colors)
    color_cycler = cycle(base_colors)

    base_line_kw = dict(lw=2.2, zorder=3)
    base_fill_kw = dict(alpha=0.18, linewidth=0, zorder=1)

    for i, m in enumerate(m_list):
        s = None if s_list is None else s_list[i]
        x = start_idx + stride * np.arange(m.shape[0])  # x_k = start_idx + k * stride
        color = next(color_cycler)

        lk = _get_kw(line_kw, i, base_line_kw)
        fk = _get_kw(fill_kw, i, base_fill_kw)

        if labels is not None:
            lk = {**lk, "label": labels[i] if i < len(labels) else f"series {i+1}"}

        # mean line
        ax.plot(x, m, color=color, **lk)

        # optional envelopes: thin boundary lines plus a translucent fill
        if s is not None:
            upper = m + s
            lower = m - s
            if lower_bound is not None:
                lower = np.clip(lower, a_min=lower_bound, a_max=None)
            if upper_bound is not None:
                upper = np.clip(upper, a_min=None, a_max=upper_bound)
            ax.plot(x, upper, color=color, alpha=0.55, lw=1.2, zorder=2)
            ax.plot(x, lower, color=color, alpha=0.55, lw=1.2, zorder=2)
            ax.fill_between(x, lower, upper, color=color, **fk)

    # cosmetics
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    ax.margins(x=0.02)
    ax.grid(True, which="both", ls=":", lw=0.5, zorder=0)

    if lower_value is not None:
        # Only the lower limit changes; the upper limit stays where it is.
        ax.set_ylim(bottom=lower_value)

    if labels is not None:
        base_leg_kw = dict(frameon=False, ncol=legend_ncol)
        if legend_kw:
            base_leg_kw.update(legend_kw)

        if legend_outside:
            if legend_pos in ("right", "left"):
                loc = "center left" if legend_pos == "right" else "center right"
                x_anchor = 1.02 if legend_pos == "right" else -0.02
                ax.legend(loc=loc, bbox_to_anchor=(x_anchor, 0.5),
                          borderaxespad=0.0, **base_leg_kw)
                if created_ax:
                    # make room for the legend next to the axes
                    if legend_pos == "right":
                        ax.figure.subplots_adjust(right=0.78)
                    else:
                        ax.figure.subplots_adjust(left=0.22)
            else:
                loc = "lower center" if legend_pos == "bottom" else "upper center"
                y_anchor = -0.02 if legend_pos == "bottom" else 1.02
                ax.legend(loc=loc, bbox_to_anchor=(0.5, y_anchor),
                          borderaxespad=0.0, **base_leg_kw)
                if created_ax:
                    # make room for the legend below/above the axes
                    if legend_pos == "bottom":
                        ax.figure.subplots_adjust(bottom=0.20)
                    else:
                        ax.figure.subplots_adjust(top=0.88)
        else:
            ax.legend(**base_leg_kw)

    return ax

# n_fields = 2
## n_jumps = 1

### n_skip = 1 from practical work

In [None]:
# Load the Hydra config saved with this training run.
cfg          = OmegaConf.load(
    "logs/train/runs/2025-09-04_07-13-04/.hydra/config.yaml"
)

# factories
# `_partial_=True` asks Hydra for a partial (a factory to call later) rather than a built object.
latent_model_factory = instantiate(cfg.model.latent_model, _partial_=True)
model_factory        = instantiate(cfg.model,            _partial_=True)
loss_function        = instantiate(cfg.model.loss_function)

# Restore the trained module from its checkpoint, forwarding the factories and loss as keyword arguments.
lit_model = PhysicsLitModule.load_from_checkpoint(
    "logs/train/runs/2025-09-04_07-13-04/waterdrop_physics/w55pl0j7/checkpoints/epoch=15-step=239520.ckpt",
    latent_model = latent_model_factory,
    model        = model_factory,
    loss_function= loss_function,
    strict       = True,
    map_location = "cuda"
)
lit_model.eval()  # evaluation mode
# Build the datamodule described by the same config and set it up for the "physics" stage.
# NOTE(review): unlike the other run-loading cells in this notebook, this one does not set
# `dataset_PH.test_dataset.rollout = True` — confirm whether that is intentional.
dataset_PH = instantiate(cfg.data)
dataset_PH.setup(stage="physics")

In [None]:
# Select every trajectory of the test split, with the index list `[0]` for each.
TRAJ_IDX = list(range(dataset_PH.get_dataset(split="test").n_traj))
n_per_traj_PH = dataset_PH.get_dataset(split="test").n_per_traj
# One independent `[0]` list per trajectory: `[[0]] * n` would alias a single
# inner list n times, so an in-place edit of one entry would change them all.
IDX_PH = [[0] for _ in TRAJ_IDX]

In [None]:
# Run the model's rollout on the test split for the selected trajectories/indices
# and unpack its seven outputs.
# NOTE(review): the meaning of each output is defined by `PhysicsLitModule.rollout_physics`; verify there.
rollout_physics, fields, GT_physics, GT_vel_fields_normalized_physics, MSE_fields_normalized_physics, MSE_Field_latent_physics, propagated_latents = lit_model.rollout_physics(
    particle_dm=dataset_PH,
    traj_idx=TRAJ_IDX,
    idx=IDX_PH,
    query_gt_pos=False,
    split="test",
    use_gt_field=False,
)

In [None]:
# Build one IoU tensor per (ground truth, rollout) pair by scoring the pair element by element.
ious = []
for (GT, rollout) in tqdm(zip(GT_physics, rollout_physics), total=len(GT_physics)):
    iou_traj = []
    for (GT_pos, rollout_pos) in zip(GT, rollout):
        # Calculate IoU for each pair of GT and rollout positions
        # unsqueeze(0) adds a leading axis of size 1 — presumably the batch axis `mean_iou` expects; confirm.
        iou_traj.append(mean_iou(rollout_pos.unsqueeze(0), GT_pos.unsqueeze(0), n_compartments=64, bounding_box=(-0.9, 1.9, -0.9, 1.9)))
    ious.append(torch.stack(iou_traj))

In [None]:
# Stack the per-rollout IoU tensors along a new leading axis (they must share one shape).
# Note: this rebinds `ious` from a list to a tensor.
ious = torch.stack(ious, dim=0)

In [None]:
# Inspect the shape of the stacked IoU tensor.
ious.shape

In [None]:
# Mean and standard deviation over axis 0, i.e. across the stacked rollouts.
ious_mean_2_1_1 = ious.mean(axis=0)
ious_std_2_1_1 = ious.std(axis=0)

In [None]:
# Persist the aggregated IoU curves under <PROJECT_ROOT>/data/ious rather than a
# machine-specific absolute path, creating the directory on first use.
# NOTE(review): assumes PROJECT_ROOT (set in the first cell) is the checkout that
# previously lived at D:/Projects/Master/UPT — confirm.
IOU_DIR = Path(os.environ["PROJECT_ROOT"]) / "data" / "ious"
IOU_DIR.mkdir(parents=True, exist_ok=True)
torch.save(ious_mean_2_1_1, IOU_DIR / "ious_mean_2_1_1.pth")
torch.save(ious_std_2_1_1, IOU_DIR / "ious_std_2_1_1.pth")
# To reload instead of recomputing:
# ious_mean_2_1_1 = torch.load(IOU_DIR / "ious_mean_2_1_1.pth")
# ious_std_2_1_1 = torch.load(IOU_DIR / "ious_std_2_1_1.pth")

In [None]:
# Mean IoU with a ±1 std envelope; the envelope is clipped to [0, 1].
plot_means_variances(
    ious_mean_2_1_1,
    ious_std_2_1_1,
    xlabel="Rollout Timestep",
    ylabel="IoU",
    upper_bound=1.0,
    lower_bound=0.0,
)
plt.show()

In [None]:
# Load the Hydra config saved with this training run.
cfg          = OmegaConf.load(
    Path("logs/train/runs/2025-09-04_10-37-42/.hydra/config.yaml")
)

# factories
# `_partial_=True` asks Hydra for a partial (a factory to call later) rather than a built object.
latent_model_factory = instantiate(cfg.model.latent_model, _partial_=True)
model_factory        = instantiate(cfg.model,            _partial_=True)
loss_function        = instantiate(cfg.model.loss_function)

# Restore the trained module from its checkpoint, forwarding the factories and loss as keyword arguments.
lit_model = PhysicsLitModule.load_from_checkpoint(
    Path("logs/train/runs/2025-09-04_10-37-42/waterdrop_physics/jo5ef1v0/checkpoints/epoch=22-step=343965.ckpt"),
    latent_model = latent_model_factory,
    model        = model_factory,
    loss_function= loss_function,
    strict       = True,
    map_location = "cuda"
)
lit_model.eval()  # evaluation mode
# Build the datamodule described by the same config and set it up for the "physics" stage.
dataset_PH = instantiate(cfg.data)
dataset_PH.setup(stage="physics")

# Turn on the test dataset's `rollout` flag (its effect is defined in the dataset class).
dataset_PH.test_dataset.rollout = True

In [None]:
# Select every trajectory of the test split, with the index list `[0]` for each.
TRAJ_IDX = list(range(dataset_PH.get_dataset(split="test").n_traj))
n_per_traj_PH = dataset_PH.get_dataset(split="test").n_per_traj
# One independent `[0]` list per trajectory: `[[0]] * n` would alias a single
# inner list n times, so an in-place edit of one entry would change them all.
IDX_PH = [[0] for _ in TRAJ_IDX]

In [None]:
# Run the model's rollout on the test split for the selected trajectories/indices
# and unpack its seven outputs.
# NOTE(review): the meaning of each output is defined by `PhysicsLitModule.rollout_physics`; verify there.
rollout_physics, fields, GT_physics, GT_vel_fields_normalized_physics, MSE_fields_normalized_physics, MSE_Field_latent_physics, propagated_latents = lit_model.rollout_physics(
    particle_dm=dataset_PH,
    traj_idx=TRAJ_IDX,
    idx=IDX_PH,
    query_gt_pos=False,
    split="test",
    use_gt_field=False,
)

In [None]:
# Build one IoU tensor per (ground truth, rollout) pair by scoring the pair element by element.
ious = []
for (GT, rollout) in tqdm(zip(GT_physics, rollout_physics), total=len(GT_physics)):
    iou_traj = []
    for (GT_pos, rollout_pos) in zip(GT, rollout):
        # Calculate IoU for each pair of GT and rollout positions
        # unsqueeze(0) adds a leading axis of size 1 — presumably the batch axis `mean_iou` expects; confirm.
        iou_traj.append(mean_iou(rollout_pos.unsqueeze(0), GT_pos.unsqueeze(0), n_compartments=64, bounding_box=(-0.9, 1.9, -0.9, 1.9)))
    ious.append(torch.stack(iou_traj))

In [None]:
# Stack the per-rollout IoU tensors along a new leading axis (they must share one shape).
# Note: this rebinds `ious` from a list to a tensor.
ious = torch.stack(ious, dim=0)

In [None]:
# Inspect the shape of the stacked IoU tensor.
ious.shape

In [None]:
# Mean and standard deviation over axis 0, i.e. across the stacked rollouts.
ious_mean_2_1 = ious.mean(axis=0)
ious_std_2_1 = ious.std(axis=0)

In [None]:
# Mean IoU with a ±1 std envelope; the envelope is clipped to [0, 1].
plot_means_variances(
    ious_mean_2_1,
    ious_std_2_1,
    xlabel="Rollout Timestep",
    ylabel="IoU",
    upper_bound=1.0,
    lower_bound=0.0,
)
plt.show()

## n_jumps = 2

In [None]:
# Load the Hydra config saved with this training run.
cfg          = OmegaConf.load(
    Path("logs/train/runs/2025-09-03_21-24-57/.hydra/config.yaml")
)

# factories
# `_partial_=True` asks Hydra for a partial (a factory to call later) rather than a built object.
latent_model_factory = instantiate(cfg.model.latent_model, _partial_=True)
model_factory        = instantiate(cfg.model,            _partial_=True)
loss_function        = instantiate(cfg.model.loss_function)

# Restore the trained module from its checkpoint, forwarding the factories and loss as keyword arguments.
lit_model = PhysicsLitModule.load_from_checkpoint(
    Path("logs/train/runs/2025-09-03_21-24-57/waterdrop_physics/lw2bv57w/checkpoints/epoch=27-step=417900.ckpt"),
    latent_model = latent_model_factory,
    model        = model_factory,
    loss_function= loss_function,
    strict       = True,
    map_location = "cuda"
)
lit_model.eval()  # evaluation mode
# Build the datamodule described by the same config and set it up for the "physics" stage.
dataset_PH = instantiate(cfg.data)
dataset_PH.setup(stage="physics")

# Turn on the test dataset's `rollout` flag (its effect is defined in the dataset class).
dataset_PH.test_dataset.rollout = True

In [None]:
# Select every trajectory of the test split, with the index list `[0]` for each.
TRAJ_IDX = list(range(dataset_PH.get_dataset(split="test").n_traj))
n_per_traj_PH = dataset_PH.get_dataset(split="test").n_per_traj
# One independent `[0]` list per trajectory: `[[0]] * n` would alias a single
# inner list n times, so an in-place edit of one entry would change them all.
IDX_PH = [[0] for _ in TRAJ_IDX]

In [None]:
# Run the model's rollout on the test split for the selected trajectories/indices
# and unpack its seven outputs.
# NOTE(review): the meaning of each output is defined by `PhysicsLitModule.rollout_physics`; verify there.
rollout_physics, fields, GT_physics, GT_vel_fields_normalized_physics, MSE_fields_normalized_physics, MSE_Field_latent_physics, propagated_latents = lit_model.rollout_physics(
    particle_dm=dataset_PH,
    traj_idx=TRAJ_IDX,
    idx=IDX_PH,
    query_gt_pos=False,
    split="test",
    use_gt_field=False,
)

In [None]:
# Build one IoU tensor per (ground truth, rollout) pair by scoring the pair element by element.
ious = []
for (GT, rollout) in tqdm(zip(GT_physics, rollout_physics), total=len(GT_physics)):
    iou_traj = []
    for (GT_pos, rollout_pos) in zip(GT, rollout):
        # Calculate IoU for each pair of GT and rollout positions
        # unsqueeze(0) adds a leading axis of size 1 — presumably the batch axis `mean_iou` expects; confirm.
        iou_traj.append(mean_iou(rollout_pos.unsqueeze(0), GT_pos.unsqueeze(0), n_compartments=64, bounding_box=(-0.9, 1.9, -0.9, 1.9)))
    ious.append(torch.stack(iou_traj))

In [None]:
# Stack the per-rollout IoU tensors along a new leading axis (they must share one shape).
# Note: this rebinds `ious` from a list to a tensor.
ious = torch.stack(ious, dim=0)

In [None]:
# Inspect the shape of the stacked IoU tensor.
ious.shape

In [None]:
# Mean and standard deviation over axis 0, i.e. across the stacked rollouts.
ious_mean_2_2 = ious.mean(axis=0)
ious_std_2_2 = ious.std(axis=0)

In [None]:
# Mean IoU with a ±1 std envelope; the envelope is clipped to [0, 1].
plot_means_variances(
    ious_mean_2_2,
    ious_std_2_2,
    xlabel="Rollout Timestep",
    ylabel="IoU",
    upper_bound=1.0,
    lower_bound=0.0,
)
plt.show()

## n_jumps = 4

In [None]:
# Load the Hydra config saved with this training run.
cfg          = OmegaConf.load(
    Path("logs/train/runs/2025-09-02_20-19-34/.hydra/config.yaml")
)

# factories
# `_partial_=True` asks Hydra for a partial (a factory to call later) rather than a built object.
latent_model_factory = instantiate(cfg.model.latent_model, _partial_=True)
model_factory        = instantiate(cfg.model,            _partial_=True)
loss_function        = instantiate(cfg.model.loss_function)

# Restore the trained module from its checkpoint, forwarding the factories and loss as keyword arguments.
lit_model = PhysicsLitModule.load_from_checkpoint(
    Path("logs/train/runs/2025-09-02_20-19-34/waterdrop_physics/ovvesbxq/checkpoints/epoch=23-step=356760.ckpt"),
    latent_model = latent_model_factory,
    model        = model_factory,
    loss_function= loss_function,
    strict       = True,
    map_location = "cuda"
)
lit_model.eval()  # evaluation mode
# Build the datamodule described by the same config and set it up for the "physics" stage.
dataset_PH = instantiate(cfg.data)
dataset_PH.setup(stage="physics")

# Turn on the test dataset's `rollout` flag (its effect is defined in the dataset class).
dataset_PH.test_dataset.rollout = True

In [None]:
# Select every trajectory of the test split, with the index list `[0]` for each.
TRAJ_IDX = list(range(dataset_PH.get_dataset(split="test").n_traj))
n_per_traj_PH = dataset_PH.get_dataset(split="test").n_per_traj
# One independent `[0]` list per trajectory: `[[0]] * n` would alias a single
# inner list n times, so an in-place edit of one entry would change them all.
IDX_PH = [[0] for _ in TRAJ_IDX]

In [None]:
# Run the model's rollout on the test split for the selected trajectories/indices
# and unpack its seven outputs.
# Note: the second output is bound to `retrieved_fields` here, whereas the other rollout cells call it `fields`.
# NOTE(review): the meaning of each output is defined by `PhysicsLitModule.rollout_physics`; verify there.
rollout_physics, retrieved_fields, GT_physics, GT_vel_fields_normalized_physics, MSE_fields_normalized_physics, MSE_Field_latent_physics, propagated_latents = lit_model.rollout_physics(
    particle_dm=dataset_PH,
    traj_idx=TRAJ_IDX,
    idx=IDX_PH,
    query_gt_pos=False,
    split="test",
    use_gt_field=False,
)

In [None]:
# Build one IoU tensor per (ground truth, rollout) pair by scoring the pair element by element.
ious = []
for (GT, rollout) in tqdm(zip(GT_physics, rollout_physics), total=len(GT_physics)):
    iou_traj = []
    for (GT_pos, rollout_pos) in zip(GT, rollout):
        # Calculate IoU for each pair of GT and rollout positions
        # unsqueeze(0) adds a leading axis of size 1 — presumably the batch axis `mean_iou` expects; confirm.
        iou_traj.append(mean_iou(rollout_pos.unsqueeze(0), GT_pos.unsqueeze(0), n_compartments=64, bounding_box=(-0.9, 1.9, -0.9, 1.9)))
    ious.append(torch.stack(iou_traj))

In [None]:
# Stack the per-rollout IoU tensors along a new leading axis (they must share one shape).
# Note: this rebinds `ious` from a list to a tensor.
ious = torch.stack(ious, dim=0)

In [None]:
# Inspect the shape of the stacked IoU tensor.
ious.shape

In [None]:
# Mean and standard deviation over axis 0, i.e. across the stacked rollouts.
ious_mean_2_4 = ious.mean(axis=0)
ious_std_2_4 = ious.std(axis=0)

In [None]:
# Mean IoU with a ±1 std envelope; the envelope is clipped to [0, 1].
plot_means_variances(
    ious_mean_2_4,
    ious_std_2_4,
    xlabel="Rollout Timestep",
    ylabel="IoU",
    upper_bound=1.0,
    lower_bound=0.0,
)
plt.show()

In [None]:
# Hand-picked subset for closer inspection: three trajectories with two index entries each.
# This overrides the full-test-set selection made earlier for the currently loaded model.
TRAJ_IDX = [0, 1, 2]
IDX_PH = [[0, 50], [0, 80], [0, 70]]

In [None]:
# Run the model's rollout on the test split for the selected trajectories/indices
# and unpack its seven outputs.
# NOTE(review): the meaning of each output is defined by `PhysicsLitModule.rollout_physics`; verify there.
rollout_physics, fields, GT_physics, GT_vel_fields_normalized_physics, MSE_fields_normalized_physics, MSE_Field_latent_physics, propagated_latents = lit_model.rollout_physics(
    particle_dm=dataset_PH,
    traj_idx=TRAJ_IDX,
    idx=IDX_PH,
    query_gt_pos=False,
    split="test",
    use_gt_field=False,
)

In [None]:
# Plot the first entry of the normalized field MSE on a log scale.
# NOTE(review): both series below are the same expression (`MSE_fields_normalized_physics[0]`),
# so the two curves coincide — presumably one was meant to be a different quantity; confirm.
plot_two_series(
    MSE_fields_normalized_physics[0].numpy(force=True),
    MSE_fields_normalized_physics[0].numpy(force=True),
    log=True
)

In [None]:
# Animate rollout entry 0 together with its ground truth.
ani = animate(
    rollout_physics[0],
    ground_truth=GT_physics[0],
    ref_frame=((0, 1), (0, 1)),
    start_idx=0,
    n_skip_ahead_timesteps=2
)
# Last expression: embed the animation in the cell output as interactive JS/HTML.
HTML(ani.to_jshtml())

In [None]:
# Animate rollout entry 2 together with its ground truth.
ani = animate(
    rollout_physics[2],
    ground_truth=GT_physics[2],
    ref_frame=((0, 1), (0, 1)),
    start_idx=0,
    n_skip_ahead_timesteps=2
)
# Last expression: embed the animation in the cell output as interactive JS/HTML.
HTML(ani.to_jshtml())

In [None]:
# Animate rollout entry 4 together with its ground truth.
ani = animate(
    rollout_physics[4],
    ground_truth=GT_physics[4],
    ref_frame=((0, 1), (0, 1)),
    start_idx=0,
    n_skip_ahead_timesteps=2
)
# Last expression: embed the animation in the cell output as interactive JS/HTML.
HTML(ani.to_jshtml())

In [None]:
from matplotlib.animation import PillowWriter

# Export the animation currently bound to `ani` (whichever animation cell ran last) as a GIF.
fps = 40

ani.save(
    "rollout.gif",
    writer=PillowWriter(fps=fps),
    dpi=150,  # lower for smaller files, e.g., 100
    savefig_kwargs={"facecolor": "white"}  # avoid dark/transparent backgrounds
)

In [None]:
# Print the shapes of the entry-4 outputs side by side for a quick consistency check.
print(GT_vel_fields_normalized_physics[4].sum(dim=-1).shape)
print(rollout_physics[4].shape)
print(fields[4].shape)
print(GT_physics[4].shape)

In [None]:
# Check whether slices 0 and 1 (along axis 0) of the unnormalized entry-4 fields,
# each summed over axis -2, are numerically equal.
torch.allclose(dataset_PH.test_dataset.unnormalize(fields[4]).sum(dim=-2)[0], dataset_PH.test_dataset.unnormalize(fields[4]).sum(dim=-2)[1])

In [None]:
# Inspect the shape of the unnormalized entry-4 fields after summing over axis -2.
dataset_PH.test_dataset.unnormalize(fields[4]).sum(dim=-2).shape

In [None]:
# Animate rollout entry 4 with vectors drawn at the rollout positions;
# the vectors are the entry-4 fields summed over axis -2. No ground truth is passed here.
ani = animate(
    rollout_physics[4],
    vectors=fields[4].sum(dim=-2),
    vector_positions=rollout_physics[4],
    ref_frame=((0, 0.6), (0, 0.6)),
    start_idx=0,
    n_skip_ahead_timesteps=2,
    dpi=150
)

In [None]:
# Embed the animation bound to `ani` in the cell output as interactive JS/HTML.
HTML(ani.to_jshtml())

In [None]:
from itertools import islice

def grab_frames(ani, N):
    """Render the first `N` frames of an animation into RGB image arrays.

    Parameters
    ----------
    ani : animation object
        Must expose `_fig`, `new_frame_seq()` and `_draw_next_frame(...)`
        (private Matplotlib `Animation` API).
    N : int
        Maximum number of frames to render; fewer are returned when the
        animation's frame sequence is shorter.

    Returns
    -------
    list of np.ndarray
        One `(height, width, 3)` uint8 RGB array per rendered frame, each an
        independent copy of the canvas contents.
    """
    fig = ani._fig
    frames = []
    # iterate the same sequence the animation would use
    for framedata in islice(ani.new_frame_seq(), N):
        # draw one frame (avoid blit complications)
        ani._draw_next_frame(framedata=framedata, blit=False)   # private API
        fig.canvas.draw()
        # Read back pixels through buffer_rgba(): tostring_rgb() was deprecated
        # in Matplotlib 3.8 and later removed. The buffer already has the
        # (height, width, 4) shape, so no manual reshape is needed.
        rgba = np.asarray(fig.canvas.buffer_rgba())
        # Drop alpha and copy, because the buffer is reused by the next draw.
        frames.append(rgba[..., :3].copy())
    return frames

In [None]:
from itertools import islice
from matplotlib.animation import PillowWriter

# Write only the first N frames of `ani` to a GIF by driving the writer manually.
# NOTE(review): the `media/` directory is not created here — confirm it exists before running.
N = 120
fps = 40

writer = PillowWriter(fps=fps)
with writer.saving(ani._fig, "media/boundary.gif", dpi=150):
    for f in islice(ani.new_frame_seq(), N):
        ani._draw_next_frame(f, blit=False)      # draw one frame (private Animation API)
        writer.grab_frame(facecolor="white")     # same as savefig_kwargs={"facecolor":"white"}

# n_fields = 4
## n_jumps = 1

In [None]:
# Load the Hydra config saved with this training run.
cfg          = OmegaConf.load(
    Path("logs/train/runs/2025-09-03_20-37-06/.hydra/config.yaml")
)

# factories
# `_partial_=True` asks Hydra for a partial (a factory to call later) rather than a built object.
latent_model_factory = instantiate(cfg.model.latent_model, _partial_=True)
model_factory        = instantiate(cfg.model,            _partial_=True)
loss_function        = instantiate(cfg.model.loss_function)

# Restore the trained module from its checkpoint, forwarding the factories and loss as keyword arguments.
lit_model = PhysicsLitModule.load_from_checkpoint(
    Path("logs/train/runs/2025-09-03_20-37-06/waterdrop_physics/zcl1xvms/checkpoints/epoch=24-step=372375.ckpt"),
    latent_model = latent_model_factory,
    model        = model_factory,
    loss_function= loss_function,
    strict       = True,
    map_location = "cuda"
)
lit_model.eval()  # evaluation mode
# Build the datamodule described by the same config and set it up for the "physics" stage.
dataset_PH = instantiate(cfg.data)
dataset_PH.setup(stage="physics")

# Turn on the test dataset's `rollout` flag (its effect is defined in the dataset class).
dataset_PH.test_dataset.rollout = True

In [None]:
# Select every trajectory of the test split, with the index list `[0]` for each.
TRAJ_IDX = list(range(dataset_PH.get_dataset(split="test").n_traj))
n_per_traj_PH = dataset_PH.get_dataset(split="test").n_per_traj
# One independent `[0]` list per trajectory: `[[0]] * n` would alias a single
# inner list n times, so an in-place edit of one entry would change them all.
IDX_PH = [[0] for _ in TRAJ_IDX]

In [None]:
# Run the model's rollout on the test split for the selected trajectories/indices
# and unpack its seven outputs.
# NOTE(review): the meaning of each output is defined by `PhysicsLitModule.rollout_physics`; verify there.
rollout_physics, fields, GT_physics, GT_vel_fields_normalized_physics, MSE_fields_normalized_physics, MSE_Field_latent_physics, propagated_latents = lit_model.rollout_physics(
    particle_dm=dataset_PH,
    traj_idx=TRAJ_IDX,
    idx=IDX_PH,
    query_gt_pos=False,
    split="test",
    use_gt_field=False,
)

In [None]:
# Build one IoU tensor per (ground truth, rollout) pair by scoring the pair element by element.
ious = []
for (GT, rollout) in tqdm(zip(GT_physics, rollout_physics), total=len(GT_physics)):
    iou_traj = []
    for (GT_pos, rollout_pos) in zip(GT, rollout):
        # Calculate IoU for each pair of GT and rollout positions
        # unsqueeze(0) adds a leading axis of size 1 — presumably the batch axis `mean_iou` expects; confirm.
        iou_traj.append(mean_iou(rollout_pos.unsqueeze(0), GT_pos.unsqueeze(0), n_compartments=64, bounding_box=(-0.9, 1.9, -0.9, 1.9)))
    ious.append(torch.stack(iou_traj))

In [None]:
# Stack the per-rollout IoU tensors along a new leading axis (they must share one shape).
# Note: this rebinds `ious` from a list to a tensor.
ious = torch.stack(ious, dim=0)

In [None]:
# Inspect the shape of the stacked IoU tensor.
ious.shape

In [None]:
# Mean and standard deviation over axis 0, i.e. across the stacked rollouts.
ious_mean_4_1 = ious.mean(axis=0)
ious_std_4_1 = ious.std(axis=0)

In [None]:
# Mean IoU with a ±1 std envelope; the envelope is clipped to [0, 1].
plot_means_variances(
    ious_mean_4_1,
    ious_std_4_1,
    xlabel="Rollout Timestep",
    ylabel="IoU",
    upper_bound=1.0,
    lower_bound=0.0,
)
plt.show()

## n_jumps = 2

In [None]:
# Load the Hydra config saved with this training run.
cfg          = OmegaConf.load(
    Path("logs/train/runs/2025-09-07_14-23-41/.hydra/config.yaml")
)

# factories
# `_partial_=True` asks Hydra for a partial (a factory to call later) rather than a built object.
latent_model_factory = instantiate(cfg.model.latent_model, _partial_=True)
model_factory        = instantiate(cfg.model,            _partial_=True)
loss_function        = instantiate(cfg.model.loss_function)

# Restore the trained module from its checkpoint, forwarding the factories and loss as keyword arguments.
lit_model = PhysicsLitModule.load_from_checkpoint(
    Path("logs/train/runs/2025-09-07_14-23-41/waterdrop_physics/nvczsxwy/checkpoints/epoch=27-step=415380.ckpt"),
    latent_model = latent_model_factory,
    model        = model_factory,
    loss_function= loss_function,
    strict       = True,
    map_location = "cuda"
)
lit_model.eval()  # evaluation mode
# Build the datamodule described by the same config and set it up for the "physics" stage.
dataset_PH = instantiate(cfg.data)
dataset_PH.setup(stage="physics")

# Turn on the test dataset's `rollout` flag (its effect is defined in the dataset class).
dataset_PH.test_dataset.rollout = True

In [None]:
# Select every trajectory of the test split, with the index list `[0]` for each.
TRAJ_IDX = list(range(dataset_PH.get_dataset(split="test").n_traj))
n_per_traj_PH = dataset_PH.get_dataset(split="test").n_per_traj
# One independent `[0]` list per trajectory: `[[0]] * n` would alias a single
# inner list n times, so an in-place edit of one entry would change them all.
IDX_PH = [[0] for _ in TRAJ_IDX]

In [None]:
# Run the model's rollout on the test split for the selected trajectories/indices
# and unpack its seven outputs.
# NOTE(review): the meaning of each output is defined by `PhysicsLitModule.rollout_physics`; verify there.
rollout_physics, fields, GT_physics, GT_vel_fields_normalized_physics, MSE_fields_normalized_physics, MSE_Field_latent_physics, propagated_latents = lit_model.rollout_physics(
    particle_dm=dataset_PH,
    traj_idx=TRAJ_IDX,
    idx=IDX_PH,
    query_gt_pos=False,
    split="test",
    use_gt_field=False,
)

In [None]:
# Build one IoU tensor per (ground truth, rollout) pair by scoring the pair element by element.
ious = []
for (GT, rollout) in tqdm(zip(GT_physics, rollout_physics), total=len(GT_physics)):
    iou_traj = []
    for (GT_pos, rollout_pos) in zip(GT, rollout):
        # Calculate IoU for each pair of GT and rollout positions
        # unsqueeze(0) adds a leading axis of size 1 — presumably the batch axis `mean_iou` expects; confirm.
        iou_traj.append(mean_iou(rollout_pos.unsqueeze(0), GT_pos.unsqueeze(0), n_compartments=64, bounding_box=(-0.9, 1.9, -0.9, 1.9)))
    ious.append(torch.stack(iou_traj))

In [None]:
# Stack the per-rollout IoU tensors along a new leading axis (they must share one shape).
# Note: this rebinds `ious` from a list to a tensor.
ious = torch.stack(ious, dim=0)

In [None]:
# Inspect the shape of the stacked IoU tensor.
ious.shape

In [None]:
# Mean and standard deviation over axis 0, i.e. across the stacked rollouts.
ious_mean_4_2 = ious.mean(axis=0)
ious_std_4_2 = ious.std(axis=0)

In [None]:
# Mean IoU with a ±1 std envelope; the envelope is clipped to [0, 1].
plot_means_variances(
    ious_mean_4_2,
    ious_std_4_2,
    xlabel="Rollout Timestep",
    ylabel="IoU",
    upper_bound=1.0,
    lower_bound=0.0,
)
plt.show()

## n_jumps = 4

In [None]:
# Load the Hydra config saved with this training run.
cfg          = OmegaConf.load(
    Path("logs/train/runs/2025-09-06_23-29-42/.hydra/config.yaml")
)

# factories
# `_partial_=True` asks Hydra for a partial (a factory to call later) rather than a built object.
latent_model_factory = instantiate(cfg.model.latent_model, _partial_=True)
model_factory        = instantiate(cfg.model,            _partial_=True)
loss_function        = instantiate(cfg.model.loss_function)

# Restore the trained module from its checkpoint, forwarding the factories and loss as keyword arguments.
lit_model = PhysicsLitModule.load_from_checkpoint(
    Path("logs/train/runs/2025-09-06_23-29-42/waterdrop_physics/eygag9o2/checkpoints/epoch=27-step=412020.ckpt"),
    latent_model = latent_model_factory,
    model        = model_factory,
    loss_function= loss_function,
    strict       = True,
    map_location = "cuda"
)
lit_model.eval()  # evaluation mode
# Build the datamodule described by the same config and set it up for the "physics" stage.
dataset_PH = instantiate(cfg.data)
dataset_PH.setup(stage="physics")

# Turn on the test dataset's `rollout` flag (its effect is defined in the dataset class).
dataset_PH.test_dataset.rollout = True

In [None]:
# Select every trajectory of the test split, with the index list `[0]` for each.
TRAJ_IDX = list(range(dataset_PH.get_dataset(split="test").n_traj))
n_per_traj_PH = dataset_PH.get_dataset(split="test").n_per_traj
# One independent `[0]` list per trajectory: `[[0]] * n` would alias a single
# inner list n times, so an in-place edit of one entry would change them all.
IDX_PH = [[0] for _ in TRAJ_IDX]

In [None]:
# Run the model's rollout on the test split for the selected trajectories/indices
# and unpack its seven outputs.
# NOTE(review): the meaning of each output is defined by `PhysicsLitModule.rollout_physics`; verify there.
rollout_physics, fields, GT_physics, GT_vel_fields_normalized_physics, MSE_fields_normalized_physics, MSE_Field_latent_physics, propagated_latents = lit_model.rollout_physics(
    particle_dm=dataset_PH,
    traj_idx=TRAJ_IDX,
    idx=IDX_PH,
    query_gt_pos=False,
    split="test",
    use_gt_field=False,
)

In [None]:
# Build one IoU tensor per (ground truth, rollout) pair by scoring the pair element by element.
ious = []
for (GT, rollout) in tqdm(zip(GT_physics, rollout_physics), total=len(GT_physics)):
    iou_traj = []
    for (GT_pos, rollout_pos) in zip(GT, rollout):
        # Calculate IoU for each pair of GT and rollout positions
        # unsqueeze(0) adds a leading axis of size 1 — presumably the batch axis `mean_iou` expects; confirm.
        iou_traj.append(mean_iou(rollout_pos.unsqueeze(0), GT_pos.unsqueeze(0), n_compartments=64, bounding_box=(-0.9, 1.9, -0.9, 1.9)))
    ious.append(torch.stack(iou_traj))

In [None]:
# Stack the per-rollout IoU tensors along a new leading axis (they must share one shape).
# Note: this rebinds `ious` from a list to a tensor.
ious = torch.stack(ious, dim=0)

In [None]:
# Inspect the shape of the stacked IoU tensor.
ious.shape

In [None]:
# Mean and standard deviation over axis 0, i.e. across the stacked rollouts.
ious_mean_4_4 = ious.mean(axis=0)
ious_std_4_4 = ious.std(axis=0)

In [None]:
# Mean IoU with a ±1 std envelope; the envelope is clipped to [0, 1].
plot_means_variances(
    ious_mean_4_4,
    ious_std_4_4,
    xlabel="Rollout Timestep",
    ylabel="IoU",
    upper_bound=1.0,
    lower_bound=0.0,
)
plt.show()

# n_fields = 8
## n_jumps = 1

In [None]:
# Load the Hydra config saved with this training run.
cfg          = OmegaConf.load(
    Path("logs/train/runs/2025-09-05_13-50-41/.hydra/config.yaml")
)

# factories
# `_partial_=True` asks Hydra for a partial (a factory to call later) rather than a built object.
latent_model_factory = instantiate(cfg.model.latent_model, _partial_=True)
model_factory        = instantiate(cfg.model,            _partial_=True)
loss_function        = instantiate(cfg.model.loss_function)

# Restore the trained module from its checkpoint, forwarding the factories and loss as keyword arguments.
lit_model = PhysicsLitModule.load_from_checkpoint(
    Path("logs/train/runs/2025-09-05_13-50-41/waterdrop_physics/k8sfxcss/checkpoints/epoch=22-step=339825.ckpt"),
    latent_model = latent_model_factory,
    model        = model_factory,
    loss_function= loss_function,
    strict       = True,
    map_location = "cuda"
)
lit_model.eval()  # evaluation mode
# Build the datamodule described by the same config and set it up for the "physics" stage.
dataset_PH = instantiate(cfg.data)
dataset_PH.setup(stage="physics")

# Turn on the test dataset's `rollout` flag (its effect is defined in the dataset class).
dataset_PH.test_dataset.rollout = True

In [None]:
# Select every trajectory of the test split, with the index list `[0]` for each.
TRAJ_IDX = list(range(dataset_PH.get_dataset(split="test").n_traj))
n_per_traj_PH = dataset_PH.get_dataset(split="test").n_per_traj
# One independent `[0]` list per trajectory: `[[0]] * n` would alias a single
# inner list n times, so an in-place edit of one entry would change them all.
IDX_PH = [[0] for _ in TRAJ_IDX]

In [None]:
# Run the model's rollout on the test split for the selected trajectories/indices
# and unpack its seven outputs.
# NOTE(review): the meaning of each output is defined by `PhysicsLitModule.rollout_physics`; verify there.
rollout_physics, fields, GT_physics, GT_vel_fields_normalized_physics, MSE_fields_normalized_physics, MSE_Field_latent_physics, propagated_latents = lit_model.rollout_physics(
    particle_dm=dataset_PH,
    traj_idx=TRAJ_IDX,
    idx=IDX_PH,
    query_gt_pos=False,
    split="test",
    use_gt_field=False,
)

In [None]:
# Build one IoU tensor per (ground truth, rollout) pair by scoring the pair element by element.
ious = []
for (GT, rollout) in tqdm(zip(GT_physics, rollout_physics), total=len(GT_physics)):
    iou_traj = []
    for (GT_pos, rollout_pos) in zip(GT, rollout):
        # Calculate IoU for each pair of GT and rollout positions
        # unsqueeze(0) adds a leading axis of size 1 — presumably the batch axis `mean_iou` expects; confirm.
        iou_traj.append(mean_iou(rollout_pos.unsqueeze(0), GT_pos.unsqueeze(0), n_compartments=64, bounding_box=(-0.9, 1.9, -0.9, 1.9)))
    ious.append(torch.stack(iou_traj))

In [None]:
# Stack the per-rollout IoU tensors along a new leading axis (they must share one shape).
# Note: this rebinds `ious` from a list to a tensor.
ious = torch.stack(ious, dim=0)

In [None]:
# Inspect the shape of the stacked IoU tensor.
ious.shape

In [None]:
# Mean and standard deviation over axis 0, i.e. across the stacked rollouts.
ious_mean_8_1 = ious.mean(axis=0)
ious_std_8_1 = ious.std(axis=0)

In [None]:
# Mean IoU with a ±1 std envelope; the envelope is clipped to [0, 1].
plot_means_variances(
    ious_mean_8_1,
    ious_std_8_1,
    xlabel="Rollout Timestep",
    ylabel="IoU",
    upper_bound=1.0,
    lower_bound=0.0,
)
plt.show()

## n_jumps = 2

In [None]:
# Load the Hydra config saved with this training run.
cfg          = OmegaConf.load(
    Path("logs/train/runs/2025-09-05_22-24-52/.hydra/config.yaml")
)

# factories
# `_partial_=True` asks Hydra for a partial (a factory to call later) rather than a built object.
latent_model_factory = instantiate(cfg.model.latent_model, _partial_=True)
model_factory        = instantiate(cfg.model,            _partial_=True)
loss_function        = instantiate(cfg.model.loss_function)

# Restore the trained module from its checkpoint, forwarding the factories and loss as keyword arguments.
lit_model = PhysicsLitModule.load_from_checkpoint(
    Path("logs/train/runs/2025-09-05_22-24-52/waterdrop_physics/7jay9bnh/checkpoints/epoch=26-step=395685.ckpt"),
    latent_model = latent_model_factory,
    model        = model_factory,
    loss_function= loss_function,
    strict       = True,
    map_location = "cuda"
)
lit_model.eval()  # evaluation mode
# Build the datamodule described by the same config and set it up for the "physics" stage.
dataset_PH = instantiate(cfg.data)
dataset_PH.setup(stage="physics")

# Turn on the test dataset's `rollout` flag (its effect is defined in the dataset class).
dataset_PH.test_dataset.rollout = True

In [None]:
# Select every trajectory of the test split, with the index list `[0]` for each.
TRAJ_IDX = list(range(dataset_PH.get_dataset(split="test").n_traj))
n_per_traj_PH = dataset_PH.get_dataset(split="test").n_per_traj
# One independent `[0]` list per trajectory: `[[0]] * n` would alias a single
# inner list n times, so an in-place edit of one entry would change them all.
IDX_PH = [[0] for _ in TRAJ_IDX]

In [None]:
# Run lit_model.rollout_physics on the test split for the trajectories in TRAJ_IDX
# (start indices in IDX_PH) and keep all seven returned values.
(
    rollout_physics,
    fields,
    GT_physics,
    GT_vel_fields_normalized_physics,
    MSE_fields_normalized_physics,
    MSE_Field_latent_physics,
    propagated_latents,
) = lit_model.rollout_physics(
    particle_dm=dataset_PH,
    traj_idx=TRAJ_IDX,
    idx=IDX_PH,
    query_gt_pos=False,
    split="test",
    use_gt_field=False,
)

In [None]:
# One stacked tensor of per-step IoU values for every (ground truth, rollout) trajectory pair.
ious = []
for GT, rollout in tqdm(zip(GT_physics, rollout_physics), total=len(GT_physics)):
    # mean_iou is evaluated one step at a time; unsqueeze(0) adds a leading axis to each input.
    iou_traj = [
        mean_iou(
            rollout_pos.unsqueeze(0),
            GT_pos.unsqueeze(0),
            n_compartments=64,
            bounding_box=(-0.9, 1.9, -0.9, 1.9),
        )
        for GT_pos, rollout_pos in zip(GT, rollout)
    ]
    ious.append(torch.stack(iou_traj))

In [None]:
ious = torch.stack(ious, dim=0)

In [None]:
ious.shape

In [None]:
ious_mean_8_2 = ious.mean(axis=0)
ious_std_8_2 = ious.std(axis=0)

In [None]:
plot_means_variances(
    ious_mean_8_2,
    ious_std_8_2,
    xlabel="Rollout Timestep",
    ylabel="IoU",
    upper_bound=1.0,
    lower_bound=0.0,
)
plt.show()

## n_jumps = 4

In [None]:
# Hydra config that was saved next to the training run being evaluated.
cfg          = OmegaConf.load(
    Path("logs/train/runs/2025-09-06_08-42-16/.hydra/config.yaml")
)

# Constructor arguments for the checkpoint: partial factories for the latent model
# and the full model, plus an instantiated loss function.
latent_model_factory = instantiate(cfg.model.latent_model, _partial_=True)
model_factory        = instantiate(cfg.model,            _partial_=True)
loss_function        = instantiate(cfg.model.loss_function)

lit_model = PhysicsLitModule.load_from_checkpoint(
    Path("logs/train/runs/2025-09-06_08-42-16/waterdrop_physics/nu7i8rej/checkpoints/epoch=27-step=403620.ckpt"),
    latent_model = latent_model_factory,
    model        = model_factory,
    loss_function= loss_function,
    strict       = True,
    # Prefer the GPU, but still load on machines without CUDA.
    map_location = "cuda" if torch.cuda.is_available() else "cpu",
)
lit_model.eval()  # evaluation mode: no dropout / batch-statistics updates

# Data module built from the same config as the model.
dataset_PH = instantiate(cfg.data)
dataset_PH.setup(stage="physics")

# NOTE(review): presumably switches the test dataset to rollout sampling — confirm in the dataset class.
dataset_PH.test_dataset.rollout = True

In [None]:
# Parameter counts of the model loaded above: the whole module, its `latent_model`
# sub-module, and the remainder (everything that is not part of `latent_model`).
num_params = sum(p.numel() for p in lit_model.parameters())
num_params_latent = sum(p.numel() for p in lit_model.latent_model.parameters())
num_params_AE = num_params - num_params_latent

In [None]:
# Evaluate every trajectory of the test split.
TRAJ_IDX = list(range(dataset_PH.get_dataset(split="test").n_traj))
n_per_traj_PH = dataset_PH.get_dataset(split="test").n_per_traj
# One independent [0] list per trajectory. `[[0]] * n` would repeat the *same* inner
# list n times, so an in-place change to one entry would show up in all of them.
IDX_PH = [[0] for _ in TRAJ_IDX]

In [None]:
# Run lit_model.rollout_physics on the test split for the trajectories in TRAJ_IDX
# (start indices in IDX_PH) and keep all seven returned values.
(
    rollout_physics,
    fields,
    GT_physics,
    GT_vel_fields_normalized_physics,
    MSE_fields_normalized_physics,
    MSE_Field_latent_physics,
    propagated_latents,
) = lit_model.rollout_physics(
    particle_dm=dataset_PH,
    traj_idx=TRAJ_IDX,
    idx=IDX_PH,
    query_gt_pos=False,
    split="test",
    use_gt_field=False,
)

In [None]:
# One stacked tensor of per-step IoU values for every (ground truth, rollout) trajectory pair.
ious = []
for GT, rollout in tqdm(zip(GT_physics, rollout_physics), total=len(GT_physics)):
    # mean_iou is evaluated one step at a time; unsqueeze(0) adds a leading axis to each input.
    iou_traj = [
        mean_iou(
            rollout_pos.unsqueeze(0),
            GT_pos.unsqueeze(0),
            n_compartments=64,
            bounding_box=(-0.9, 1.9, -0.9, 1.9),
        )
        for GT_pos, rollout_pos in zip(GT, rollout)
    ]
    ious.append(torch.stack(iou_traj))

In [None]:
ious = torch.stack(ious, dim=0)

In [None]:
ious.shape

### a look at the worst trajectories

In [None]:
ious[:, 50].min()

In [None]:
ious[:, 52].argmin()

In [None]:
ani = animate( # bad iou around timestep 400
    rollout_physics[11],
    ground_truth=GT_physics[11],
    ref_frame=((0, 1), (0, 1)),
    start_idx=0,
    n_skip_ahead_timesteps=8
)
HTML(ani.to_jshtml())

In [None]:
ani = animate( # bad boundaries still happen
    rollout_physics[10],
    ground_truth=GT_physics[10],
    ref_frame=((0, 1), (0, 1)),
    start_idx=0,
    n_skip_ahead_timesteps=8
)
HTML(ani.to_jshtml())

In [None]:
ani = animate( # bad iou around timestep 416
    rollout_physics[1],
    ground_truth=GT_physics[1],
    ref_frame=((0, 1), (0, 1)),
    start_idx=0,
    n_skip_ahead_timesteps=8
)
HTML(ani.to_jshtml())

In [None]:
ious_mean_8_4 = ious.mean(axis=0)
ious_std_8_4 = ious.std(axis=0)

In [None]:
plot_means_variances(
    ious_mean_8_4,
    ious_std_8_4,
    xlabel="Rollout Timestep",
    ylabel="IoU",
    upper_bound=1.0,
    lower_bound=0.0,
)
plt.show()

In [None]:
ani = animate(
    rollout_physics[5],
    ground_truth=GT_physics[5],
    ref_frame=((0, 1), (0, 1)),
    start_idx=16,
    n_skip_ahead_timesteps=8,
    dpi=150
)
HTML(ani.to_jshtml())

In [None]:
ani = animate(
    rollout_physics[6],
    ground_truth=GT_physics[6],
    ref_frame=((0, 1), (0, 1)),
    start_idx=16,
    n_skip_ahead_timesteps=8,
    dpi=150
)
HTML(ani.to_jshtml())

In [None]:
ani = animate(
    rollout_physics[7],
    ground_truth=GT_physics[7],
    ref_frame=((0, 1), (0, 1)),
    start_idx=16,
    n_skip_ahead_timesteps=8,
    dpi=150
)
HTML(ani.to_jshtml())

In [None]:
ani = animate(
    rollout_physics[8],
    ground_truth=GT_physics[8],
    ref_frame=((0, 1), (0, 1)),
    start_idx=16,
    n_skip_ahead_timesteps=8,
    dpi=150
)
HTML(ani.to_jshtml())

In [None]:
ani = animate(
    rollout_physics[0],
    ground_truth=GT_physics[0],
    ref_frame=((0, 1), (0, 1)),
    start_idx=16,
    n_skip_ahead_timesteps=8
)
HTML(ani.to_jshtml())

In [None]:
ani = animate(
    rollout_physics[1],
    ground_truth=GT_physics[1],
    ref_frame=((0, 1), (0, 1)),
    start_idx=16,
    n_skip_ahead_timesteps=8
)
HTML(ani.to_jshtml())

In [None]:
ani = animate(
    rollout_physics[2],
    ground_truth=GT_physics[2],
    ref_frame=((0, 1), (0, 1)),
    start_idx=0,
    n_skip_ahead_timesteps=8
)
HTML(ani.to_jshtml())

In [None]:
from matplotlib.animation import PillowWriter

fps = 40

ani.save(
    "rollout.gif",
    writer=PillowWriter(fps=fps),
    dpi=150,  # lower for smaller files, e.g., 100
    savefig_kwargs={"facecolor": "white"}  # avoid dark/transparent backgrounds
)

In [None]:
print(GT_vel_fields_normalized_physics[2].sum(dim=-1).shape)
print(rollout_physics[2].shape)
print(fields[2].shape)
print(GT_physics[2].shape)

In [None]:
ani = animate(
    rollout_physics[2],
    vectors=fields[2].sum(dim=-2),
    vector_positions=rollout_physics[2],
    ref_frame=((0, 0.6), (0, 0.6)),
    start_idx=0,
    n_skip_ahead_timesteps=2,
    dpi=150
)

In [None]:
HTML(ani.to_jshtml())

In [None]:
from itertools import islice
from matplotlib.animation import PillowWriter

def grab_frames(ani, N):
    """Render the first ``N`` frames of an animation into RGB image arrays.

    Args:
        ani: animation object. Its figure (``ani._fig``), its frame sequence
            (``ani.new_frame_seq()``) and ``ani._draw_next_frame`` are used.
        N: maximum number of frames to render; fewer are returned when the
            frame sequence is shorter.

    Returns:
        A list with one ``uint8`` array of shape ``(height, width, 3)`` per
        rendered frame, in frame order.
    """
    fig = ani._fig
    frames = []
    # Iterate the same sequence the animation itself would use.
    for f in islice(ani.new_frame_seq(), N):
        # Draw one frame without blitting (private API, as in the rest of this notebook).
        ani._draw_next_frame(framedata=f, blit=False)
        fig.canvas.draw()
        # Read the pixels back through buffer_rgba(): tostring_rgb() was deprecated in
        # Matplotlib 3.8 and later removed. The RGBA buffer already has the canvas'
        # real (height, width, 4) shape, so no manual reshape is needed.
        rgba = np.asarray(fig.canvas.buffer_rgba())
        # Drop alpha and copy so the frame does not alias the canvas buffer.
        frames.append(rgba[..., :3].copy())
    return frames

In [None]:
N = 120   # number of frames to export
fps = 40

# Create the output directory up front; otherwise the export fails when "media/" is missing.
Path("media").mkdir(parents=True, exist_ok=True)

writer = PillowWriter(fps=fps)
with writer.saving(ani._fig, "media/boundary.gif", dpi=150):
    # Export only the first N frames of the animation's own frame sequence.
    for f in islice(ani.new_frame_seq(), N):
        ani._draw_next_frame(f, blit=False)      # draw one frame
        writer.grab_frame(facecolor="white")     # same as savefig_kwargs={"facecolor":"white"}

In [None]:
# Collect the per-configuration IoU curves in one dict. Each value is a
# (mean curve, std curve, integer) tuple; the integer is used further down as the
# divisor in `8 // value` when the curves are subsampled. The plot cells read the
# first and last character of each key as n_fields and n_jumps.
# NOTE(review): this rebinds `ious` (previously a tensor) to a dict, and the
# ious_mean_*/ious_std_* names for the 2_* and 4_* runs are not defined in this part
# of the notebook — presumably in earlier cells; confirm they exist on a fresh run.
ious = {
    "2_1_1": (ious_mean_2_1_1, ious_std_2_1_1, 1),
    "2_1": (ious_mean_2_1, ious_std_2_1, 2),
    "2_2": (ious_mean_2_2, ious_std_2_2, 2),
    "2_4": (ious_mean_2_4, ious_std_2_4, 2),
    "4_1": (ious_mean_4_1, ious_std_4_1, 4),
    "4_2": (ious_mean_4_2, ious_std_4_2, 4),
    "4_4": (ious_mean_4_4, ious_std_4_4, 4),
    "8_1": (ious_mean_8_1, ious_std_8_1, 8),
    "8_2": (ious_mean_8_2, ious_std_8_2, 8),
    "8_4": (ious_mean_8_4, ious_std_8_4, 8)
}

In [None]:
mean_rollout_iou = [iou[0].mean() for iou in ious.values()]

In [None]:
mean_rollout_iou

In [None]:
# torch.save(ious, "D:/Projects/Master/UPT/data/ious/ious.pth")
# torch.save(mean_rollout_iou, "D:/Projects/Master/UPT/data/ious/mean_rollout_iou.pth")
# Replace the in-memory `ious` with the copy stored on disk.
# NOTE(review): absolute, machine-specific path — this cell only runs where that file
# exists. Presumably it was written by the commented-out torch.save above; confirm.
ious = torch.load("D:/Projects/Master/UPT/data/ious/ious.pth")

In [None]:
# Subsample every mean/std curve: with k = 8 // entry[2], keep index k-1 and then every k-th entry.
mean_iou_plot_means = [
    entry[0][slice(8 // entry[2] - 1, None, 8 // entry[2])]
    for entry in ious.values()
]
mean_iou_plot_stds = [
    entry[1][slice(8 // entry[2] - 1, None, 8 // entry[2])]
    for entry in ious.values()
]

In [None]:
for iou_plot_list in mean_iou_plot_means:
    print(iou_plot_list.shape)

In [None]:
# Mean ± std IoU curves of all configurations in `ious`.
plot_means_variances(
    mean_iou_plot_means,
    mean_iou_plot_stds,
    xlabel="timestep",
    ylabel="IoU",
    # The first entry gets a hand-written label; the rest are derived from the dict keys.
    # Keys are split on "_" so multi-digit values (e.g. a "16_4" key) are read whole
    # instead of by single character; labels for the current keys are unchanged.
    labels=
    [r"$n_{\mathrm{fields}}=2,\, n_{\mathrm{jumps}}=1,\, t_{\mathcal{A}}=1$"]
    + [
        rf"$n_{{\rm fields}}={int(g.split('_')[0])},\, n_{{\rm jumps}}={int(g.split('_')[-1])}$"
        for g in ious.keys()
    ][1:],
    figsize=(14, 8),
    upper_bound=1.0,
    dpi=250,
    start_idx=15,
    stride=8
)

Increasing n_fields does help with the first drop, but overall the results get worse!

In [None]:
# Mean IoU curves only (no std passed), compact labels.
plot_means_variances(
    mean_iou_plot_means,
    xlabel="timestep",
    ylabel="IoU",
    # The first entry gets a hand-written label; the rest are derived from the dict keys.
    # Keys are split on "_" so multi-digit values (e.g. a "16_4" key) are read whole
    # instead of by single character; labels for the current keys are unchanged.
    labels=
    [r"$n=2,\, N=1,\, t_{\mathcal{A}}=1$"]
    + [
        rf"$n={int(g.split('_')[0])},\, N={int(g.split('_')[-1])}$"
        for g in ious.keys()
    ][1:],
    figsize=(7, 6),
    dpi=250,
    start_idx=15,
    stride=8
)

In [None]:
# Mean IoU curves only (no std passed), with the legend placed outside on the right.
plot_means_variances(
    mean_iou_plot_means,
    xlabel="timestep",
    ylabel="IoU",
    # The first entry gets a hand-written label; the rest are derived from the dict keys.
    # Keys are split on "_" so multi-digit values (e.g. a "16_4" key) are read whole
    # instead of by single character; labels for the current keys are unchanged.
    labels=
    [r"$n=2,\, N=1,\, t_{\mathcal{A}}=1$"]
    + [
        rf"$n={int(g.split('_')[0])},\, N={int(g.split('_')[-1])}$"
        for g in ious.keys()
    ][1:],
    figsize=(7, 5),
    dpi=250,
    start_idx=15,
    stride=8,
    legend_outside=True,
    legend_pos="right",
    legend_ncol=1
)

In [None]:
# Mean ± std IoU for the first three entries of `ious`.
plot_means_variances(
    mean_iou_plot_means[:3],
    mean_iou_plot_stds[:3],
    xlabel="timestep",
    ylabel="IoU",
    # Keys are split on "_" so multi-digit values (e.g. a "16_4" key) are read whole
    # instead of by single character; labels for the current keys are unchanged.
    labels=[
        rf"$n_{{\rm fields}}={int(g.split('_')[0])},\, n_{{\rm jumps}}={int(g.split('_')[-1])}$"
        for g in ious.keys()
    ][:3],
    figsize=(12, 8),
    upper_bound=1.0,
    dpi=250,
    start_idx=0,
    stride=8
)

In [None]:
# Mean ± std IoU for entries 3..5 of `ious`.
plot_means_variances(
    mean_iou_plot_means[3:6],
    mean_iou_plot_stds[3:6],
    xlabel="timestep",
    ylabel="IoU",
    # Keys are split on "_" so multi-digit values (e.g. a "16_4" key) are read whole
    # instead of by single character; labels for the current keys are unchanged.
    labels=[
        rf"$n_{{\rm fields}}={int(g.split('_')[0])},\, n_{{\rm jumps}}={int(g.split('_')[-1])}$"
        for g in ious.keys()
    ][3:6],
    figsize=(12, 8),
    upper_bound=1.0,
    dpi=250,
    start_idx=0,
    stride=8
)

In [None]:
# Mean ± std IoU for entries 6 onwards of `ious`.
plot_means_variances(
    mean_iou_plot_means[6:],
    mean_iou_plot_stds[6:],
    xlabel="timestep",
    ylabel="IoU",
    # Keys are split on "_" so multi-digit values (e.g. a "16_4" key) are read whole
    # instead of by single character; labels for the current keys are unchanged.
    labels=[
        rf"$n_{{\rm fields}}={int(g.split('_')[0])},\, n_{{\rm jumps}}={int(g.split('_')[-1])}$"
        for g in ious.keys()
    ][6:],
    figsize=(12, 8),
    upper_bound=1.0,
    dpi=250,
    start_idx=0,
    stride=8
)

In [None]:
# Mean ± std IoU for every third entry of `ious`, starting at index 2.
plot_means_variances(
    mean_iou_plot_means[2::3],
    mean_iou_plot_stds[2::3],
    xlabel="timestep",
    ylabel="IoU",
    # Keys are split on "_" so multi-digit values (e.g. a "16_4" key) are read whole
    # instead of by single character; labels for the current keys are unchanged.
    labels=[
        rf"$n_{{\rm fields}}={int(g.split('_')[0])},\, n_{{\rm jumps}}={int(g.split('_')[-1])}$"
        for g in ious.keys()
    ][2::3],
    figsize=(10, 8),
    upper_bound=1.0,
    dpi=250,
    start_idx=0,
    stride=8
)

# Look at compartments

In [None]:
n_compartments=64
bounding_box=(-0.9, 1.9, -0.9, 1.9)

In [None]:
a_x, b_x, a_y, b_y = bounding_box

step_x = (b_x - a_x) / n_compartments
step_y = (b_y - a_y) / n_compartments

In [None]:
from matplotlib.patches import Rectangle

def plot_step_grid(
    bounds: tuple[float, float, float, float],
    step_x: float,
    step_y: float,
    ax: plt.Axes | None = None,
    *,
    line_kw: dict | None = None,
    boundary_kw: dict | None = None,
    draw_axes_at_zero: bool = False,
    equal_aspect: bool = True,
    # Highlight options
    highlight: tuple[float, float, float, float] | None = None,
    highlight_units: str = "cells",  # "cells" or "data"
    highlight_kw: dict | None = None,
    clip_highlight_to_bounds: bool = True,
):
    """
    Draw a regular grid inside a rectangular region and optionally shade one rectangle.

    Vertical lines are placed every ``step_x`` and horizontal lines every ``step_y``
    between the given bounds, and the region is framed by a boundary box.

    Args:
        bounds: ``(a_x, b_x, a_y, b_y)`` with ``a_x <= b_x`` and ``a_y <= b_y``.
        step_x, step_y: grid spacings; both must be positive.
        ax: Axes to draw on. A new figure (dpi 200) is created when omitted.
        line_kw: overrides for the grid-line style (defaults: colour ``'0.8'``, width 1.0).
        boundary_kw: overrides for the boundary-box patch style.
        draw_axes_at_zero: also draw ``x = 0`` / ``y = 0`` when they lie inside the bounds.
        equal_aspect: force an equal aspect ratio on the Axes.
        highlight: rectangle to shade, or ``None`` for no highlight.
            With ``highlight_units == "cells"`` it is ``(i, j, w_cells, h_cells)``:
            integer lower-left cell indices (0 at ``a_x`` / ``a_y``) and size in cells.
            With ``highlight_units == "data"`` it is ``(x0, y0, width, height)`` in
            data coordinates.
        highlight_units: ``"cells"`` or ``"data"``.
        highlight_kw: overrides for the highlight patch style.
        clip_highlight_to_bounds: clip an overflowing highlight to the bounds instead
            of rejecting it.

    Returns:
        The Axes the grid was drawn on.

    Raises:
        ValueError: for non-positive steps, inverted bounds, an unknown
            ``highlight_units``, non-integer cell values, or (when clipping is off)
            a highlight that is out of bounds / has non-positive size.
    """
    if step_x <= 0 or step_y <= 0:
        raise ValueError("step_x and step_y must be positive.")
    a_x, b_x, a_y, b_y = bounds
    if a_x > b_x or a_y > b_y:
        raise ValueError("Bounds must satisfy a_x <= b_x and a_y <= b_y.")

    def _styled(defaults: dict, overrides: dict | None) -> dict:
        """Return a copy of ``defaults`` with the caller's overrides applied on top."""
        merged = dict(defaults)
        if overrides:
            merged.update(overrides)
        return merged

    grid_style = _styled({'color': '0.8', 'linewidth': 1.0}, line_kw)
    frame_style = _styled({'edgecolor': 'k', 'linewidth': 1.5, 'fill': False}, boundary_kw)
    patch_style = _styled(
        {'facecolor': 'tab:orange', 'alpha': 0.25, 'edgecolor': 'tab:orange', 'linewidth': 2.0},
        highlight_kw,
    )

    if ax is None:
        ax = plt.subplots(dpi=200)[1]

    # Half a step of slack on the upper end so the top/right edge line is not lost
    # to floating-point rounding in arange.
    x_lines = np.arange(a_x, b_x + step_x * 0.5, step_x)
    y_lines = np.arange(a_y, b_y + step_y * 0.5, step_y)
    ax.vlines(x_lines, a_y, b_y, **grid_style)
    ax.hlines(y_lines, a_x, b_x, **grid_style)

    # Frame around the whole region.
    ax.add_patch(Rectangle((a_x, a_y), b_x - a_x, b_y - a_y, **frame_style))

    if draw_axes_at_zero:
        zero_width = grid_style.get('linewidth', 1.0)
        if a_x <= 0 <= b_x:
            ax.axvline(0, color='0.4', linewidth=zero_width)
        if a_y <= 0 <= b_y:
            ax.axhline(0, color='0.4', linewidth=zero_width)

    if highlight is not None:
        if highlight_units not in {"cells", "data"}:
            raise ValueError("highlight_units must be 'cells' or 'data'.")

        if highlight_units == "data":
            x0, y0, w, h = highlight
            if clip_highlight_to_bounds:
                # Clip both corners into the bounds, then recompute the size.
                x_far = np.clip(x0 + w, a_x, b_x)
                y_far = np.clip(y0 + h, a_y, b_y)
                x0 = np.clip(x0, a_x, b_x)
                y0 = np.clip(y0, a_y, b_y)
                w = max(0.0, x_far - x0)
                h = max(0.0, y_far - y0)
            elif w <= 0 or h <= 0:
                raise ValueError("Highlight width/height must be positive in data units.")
        else:
            col, row, n_cols, n_rows = highlight
            cell_spec = {"i": col, "j": row, "w_cells": n_cols, "h_cells": n_rows}
            # All four values must be whole numbers before they are converted to int.
            for label, value in cell_spec.items():
                if int(value) != value:
                    raise ValueError(f"{label} must be an integer when highlight_units='cells'.")
            col, row, n_cols, n_rows = (int(value) for value in cell_spec.values())

            # Number of whole cells per axis (floored when the range is not an exact multiple).
            nx = int(np.floor((b_x - a_x) / step_x + 1e-12))
            ny = int(np.floor((b_y - a_y) / step_y + 1e-12))

            if clip_highlight_to_bounds:
                col_lo = max(0, min(col, nx))
                row_lo = max(0, min(row, ny))
                col_hi = max(0, min(col + n_cols, nx))
                row_hi = max(0, min(row + n_rows, ny))
                n_cols = max(0, col_hi - col_lo)
                n_rows = max(0, row_hi - row_lo)
                col, row = col_lo, row_lo
            else:
                fits = (
                    0 <= col < nx
                    and 0 <= row < ny
                    and n_cols > 0
                    and n_rows > 0
                    and col + n_cols <= nx
                    and row + n_rows <= ny
                )
                if not fits:
                    raise ValueError("Highlight rectangle in cells is out of bounds.")

            # Convert cell indices/sizes to data coordinates.
            x0 = a_x + col * step_x
            y0 = a_y + row * step_y
            w = n_cols * step_x
            h = n_rows * step_y

        # A highlight that was clipped away entirely is simply not drawn.
        if w > 0 and h > 0:
            ax.add_patch(Rectangle((x0, y0), w, h, zorder=3, **patch_style))

    ax.set_xlim(a_x, b_x)
    ax.set_ylim(a_y, b_y)
    if equal_aspect:
        ax.set_aspect('equal', adjustable='box')

    return ax

In [None]:
plot_step_grid(
    bounding_box,
    step_x,
    step_y,
    highlight=(0.1, 0.1, 0.8, 0.8),
    highlight_units="data",
)

# n_fields = 16

## n_jumps = 4

In [None]:
# Hydra config that was saved next to the training run being evaluated.
cfg          = OmegaConf.load(
    Path("logs/train/runs/2025-09-09_07-42-20/.hydra/config.yaml")
)

# Constructor arguments for the checkpoint: partial factories for the latent model
# and the full model, plus an instantiated loss function.
latent_model_factory = instantiate(cfg.model.latent_model, _partial_=True)
model_factory        = instantiate(cfg.model,            _partial_=True)
loss_function        = instantiate(cfg.model.loss_function)

lit_model = PhysicsLitModule.load_from_checkpoint(
    Path("logs/train/runs/2025-09-09_07-42-20/waterdrop_physics/s54sjmpc/checkpoints/epoch=28-step=400635.ckpt"),
    latent_model = latent_model_factory,
    model        = model_factory,
    loss_function= loss_function,
    strict       = True,
    # Prefer the GPU, but still load on machines without CUDA.
    map_location = "cuda" if torch.cuda.is_available() else "cpu",
)
lit_model.eval()  # evaluation mode: no dropout / batch-statistics updates

# Data module built from the same config as the model.
dataset_PH = instantiate(cfg.data)
dataset_PH.setup(stage="physics")

# NOTE(review): presumably switches the test dataset to rollout sampling — confirm in the dataset class.
dataset_PH.test_dataset.rollout = True

In [None]:
# Evaluate every trajectory of the test split.
TRAJ_IDX = list(range(dataset_PH.get_dataset(split="test").n_traj))
n_per_traj_PH = dataset_PH.get_dataset(split="test").n_per_traj
# One independent [0] list per trajectory. `[[0]] * n` would repeat the *same* inner
# list n times, so an in-place change to one entry would show up in all of them.
IDX_PH = [[0] for _ in TRAJ_IDX]

In [None]:
# Run lit_model.rollout_physics on the test split for the trajectories in TRAJ_IDX
# (start indices in IDX_PH) and keep all seven returned values.
(
    rollout_physics,
    fields,
    GT_physics,
    GT_vel_fields_normalized_physics,
    MSE_fields_normalized_physics,
    MSE_Field_latent_physics,
    propagated_latents,
) = lit_model.rollout_physics(
    particle_dm=dataset_PH,
    traj_idx=TRAJ_IDX,
    idx=IDX_PH,
    query_gt_pos=False,
    split="test",
    use_gt_field=False,
)

In [None]:
# One stacked tensor of per-step IoU values for every (ground truth, rollout) trajectory pair.
ious = []
for GT, rollout in tqdm(zip(GT_physics, rollout_physics), total=len(GT_physics)):
    # mean_iou is evaluated one step at a time; unsqueeze(0) adds a leading axis to each input.
    iou_traj = [
        mean_iou(
            rollout_pos.unsqueeze(0),
            GT_pos.unsqueeze(0),
            n_compartments=64,
            bounding_box=(-0.9, 1.9, -0.9, 1.9),
        )
        for GT_pos, rollout_pos in zip(GT, rollout)
    ]
    ious.append(torch.stack(iou_traj))

In [None]:
ious = torch.stack(ious, dim=0)

In [None]:
ious.shape

In [None]:
ious_mean_16_4 = ious.mean(axis=0)
ious_std_16_4 = ious.std(axis=0)
print(ious_mean_16_4.shape)

In [None]:
plot_means_variances(
    ious_mean_16_4.numpy(force=True),
    ious_std_16_4.numpy(force=True),
    xlabel="Rollout Timestep",
    ylabel="IoU",
    upper_bound=1.0,
)
plt.show()

In [None]:
ious_mean_16_4.mean()

In [None]:
ious_mean_8_4.mean()

In [None]:
mean_rollout_iou_16_4 = ious_mean_16_4.mean()

In [None]:
# torch.save(ious_mean_16_4, "D:/Projects/Master/UPT/data/ious/ious_mean_16_4.pth")
# torch.save(ious_std_16_4, "D:/Projects/Master/UPT/data/ious/ious_std_16_4.pth")
# Persist the scalar mean rollout IoU of the 16-field run.
# NOTE(review): absolute, machine-specific path — this write fails on any machine
# without that directory and overwrites the file on every run of the cell.
torch.save(mean_rollout_iou_16_4, "D:/Projects/Master/UPT/data/ious/mean_rollout_iou_16_4.pth")

In [None]:
# Compare the 8-field and 16-field runs on one axis.
# NOTE(review): the 8-field curves are subsampled with [2::2] — presumably so that their
# entries (plotted elsewhere with start_idx=15, stride=8) line up with the 16-field
# entries at start_idx=31, stride=16; confirm both curves have matching lengths.
plot_means_variances(
    [ious_mean_8_4[2::2], ious_mean_16_4],
    [ious_std_8_4[2::2], ious_std_16_4],
    xlabel="timestep",
    ylabel="IoU",
    labels=
    [r"$n=8,\, N=4$", r"$n=16,\, N=4$"],
    figsize=(10, 4),
    dpi=250,
    start_idx=31,
    stride=16
)

# Pushforward
## 8_2

In [None]:
# Hydra config that was saved next to the training run being evaluated.
cfg          = OmegaConf.load(
    Path("logs/train/runs/2025-09-11_15-36-08/.hydra/config.yaml")
)

# Constructor arguments for the checkpoint: partial factories for the latent model
# and the full model, plus an instantiated loss function.
latent_model_factory = instantiate(cfg.model.latent_model, _partial_=True)
model_factory        = instantiate(cfg.model,            _partial_=True)
loss_function        = instantiate(cfg.model.loss_function)

lit_model = PhysicsLitModule.load_from_checkpoint(
    Path("logs/train/runs/2025-09-11_15-36-08/waterdrop_physics/40ts13d5/checkpoints/epoch=26-step=395685.ckpt"),
    latent_model = latent_model_factory,
    model        = model_factory,
    loss_function= loss_function,
    strict       = True,
    # Prefer the GPU, but still load on machines without CUDA.
    map_location = "cuda" if torch.cuda.is_available() else "cpu",
)
lit_model.eval()  # evaluation mode: no dropout / batch-statistics updates

# Data module built from the same config as the model.
dataset_PH = instantiate(cfg.data)
dataset_PH.setup(stage="physics")

# NOTE(review): presumably switches the test dataset to rollout sampling — confirm in the dataset class.
dataset_PH.test_dataset.rollout = True

In [None]:
# Evaluate every trajectory of the test split.
TRAJ_IDX = list(range(dataset_PH.get_dataset(split="test").n_traj))
n_per_traj_PH = dataset_PH.get_dataset(split="test").n_per_traj
# One independent [0] list per trajectory. `[[0]] * n` would repeat the *same* inner
# list n times, so an in-place change to one entry would show up in all of them.
IDX_PH = [[0] for _ in TRAJ_IDX]

In [None]:
# Run lit_model.rollout_physics on the test split for the trajectories in TRAJ_IDX
# (start indices in IDX_PH) and keep all seven returned values.
(
    rollout_physics,
    fields,
    GT_physics,
    GT_vel_fields_normalized_physics,
    MSE_fields_normalized_physics,
    MSE_Field_latent_physics,
    propagated_latents,
) = lit_model.rollout_physics(
    particle_dm=dataset_PH,
    traj_idx=TRAJ_IDX,
    idx=IDX_PH,
    query_gt_pos=False,
    split="test",
    use_gt_field=False,
)

In [None]:
# One stacked tensor of per-step IoU values for every (ground truth, rollout) trajectory pair.
ious = []
for GT, rollout in tqdm(zip(GT_physics, rollout_physics), total=len(GT_physics)):
    # mean_iou is evaluated one step at a time; unsqueeze(0) adds a leading axis to each input.
    iou_traj = [
        mean_iou(
            rollout_pos.unsqueeze(0),
            GT_pos.unsqueeze(0),
            n_compartments=64,
            bounding_box=(-0.9, 1.9, -0.9, 1.9),
        )
        for GT_pos, rollout_pos in zip(GT, rollout)
    ]
    ious.append(torch.stack(iou_traj))

In [None]:
ious = torch.stack(ious, dim=0)

In [None]:
ious.shape

In [None]:
ious_mean_8_2_pushforward = ious.mean(axis=0)
ious_std_8_2_pushforward = ious.std(axis=0)

In [None]:
plot_means_variances(
    [ious_mean_8_2, ious_mean_8_2_pushforward],
    [ious_std_8_2, ious_std_8_2_pushforward],
    xlabel="timestep",
    ylabel="IoU",
    labels=
    ["normal", "pushforward"],
    figsize=(10, 4),
    dpi=250,
    start_idx=15,
    stride=8
)

In [None]:
print(ious_mean_8_2_pushforward.mean())
print(ious_mean_8_2.mean())

## 8_4 pushforward

In [None]:
# Hydra config that was saved next to the training run being evaluated.
cfg          = OmegaConf.load(
    Path("logs/train/runs/2025-09-13_03-52-00/.hydra/config.yaml")
)

# Constructor arguments for the checkpoint: partial factories for the latent model
# and the full model, plus an instantiated loss function.
latent_model_factory = instantiate(cfg.model.latent_model, _partial_=True)
model_factory        = instantiate(cfg.model,            _partial_=True)
loss_function        = instantiate(cfg.model.loss_function)

lit_model = PhysicsLitModule.load_from_checkpoint(
    Path("logs/train/runs/2025-09-13_03-52-00/waterdrop_physics/4zj1z3xo/checkpoints/epoch=25-step=374790.ckpt"),
    latent_model = latent_model_factory,
    model        = model_factory,
    loss_function= loss_function,
    strict       = True,
    # Prefer the GPU, but still load on machines without CUDA.
    map_location = "cuda" if torch.cuda.is_available() else "cpu",
)
lit_model.eval()  # evaluation mode: no dropout / batch-statistics updates

# Data module built from the same config as the model.
dataset_PH = instantiate(cfg.data)
dataset_PH.setup(stage="physics")

# NOTE(review): presumably switches the test dataset to rollout sampling — confirm in the dataset class.
dataset_PH.test_dataset.rollout = True

In [None]:
# Run lit_model.rollout_physics on the test split for the trajectories in TRAJ_IDX
# (start indices in IDX_PH) and keep all seven returned values.
(
    rollout_physics,
    fields,
    GT_physics,
    GT_vel_fields_normalized_physics,
    MSE_fields_normalized_physics,
    MSE_Field_latent_physics,
    propagated_latents,
) = lit_model.rollout_physics(
    particle_dm=dataset_PH,
    traj_idx=TRAJ_IDX,
    idx=IDX_PH,
    query_gt_pos=False,
    split="test",
    use_gt_field=False,
)

In [None]:
# One stacked tensor of per-step IoU values for every (ground truth, rollout) trajectory pair.
ious = []
for GT, rollout in tqdm(zip(GT_physics, rollout_physics), total=len(GT_physics)):
    # mean_iou is evaluated one step at a time; unsqueeze(0) adds a leading axis to each input.
    iou_traj = [
        mean_iou(
            rollout_pos.unsqueeze(0),
            GT_pos.unsqueeze(0),
            n_compartments=64,
            bounding_box=(-0.9, 1.9, -0.9, 1.9),
        )
        for GT_pos, rollout_pos in zip(GT, rollout)
    ]
    ious.append(torch.stack(iou_traj))

In [None]:
ious = torch.stack(ious, dim=0)

In [None]:
ious_mean_8_4_pushforward = ious.mean(axis=0)
ious_std_8_4_pushforward = ious.std(axis=0)

In [None]:
print(ious_mean_8_2_pushforward.mean())
print(ious_mean_8_2.mean())
print(ious_mean_8_4_pushforward.mean())
print(ious_mean_8_4.mean())

# Displacement

In [None]:
# Hydra config that was saved next to the training run being evaluated.
cfg          = OmegaConf.load(
    Path("logs/train/runs/2025-09-10_23-31-52/.hydra/config.yaml")
)

# Constructor arguments for the checkpoint: partial factories for the latent model
# and the full model, plus an instantiated loss function.
latent_model_factory = instantiate(cfg.model.latent_model, _partial_=True)
model_factory        = instantiate(cfg.model,            _partial_=True)
loss_function        = instantiate(cfg.model.loss_function)

lit_model = PhysicsLitModule.load_from_checkpoint(
    Path("logs/train/runs/2025-09-10_23-31-52/waterdrop_physics/ft1skpra/checkpoints/epoch=24-step=360375.ckpt"),
    latent_model = latent_model_factory,
    model        = model_factory,
    loss_function= loss_function,
    strict       = True,
    # Prefer the GPU, but still load on machines without CUDA.
    map_location = "cuda" if torch.cuda.is_available() else "cpu",
)
lit_model.eval()  # evaluation mode: no dropout / batch-statistics updates

# Data module built from the same config as the model.
dataset_PH = instantiate(cfg.data)
dataset_PH.setup(stage="physics")

# NOTE(review): presumably switches the test dataset to rollout sampling — confirm in the dataset class.
dataset_PH.test_dataset.rollout = True

In [None]:
# Evaluate every trajectory of the test split.
TRAJ_IDX = list(range(dataset_PH.get_dataset(split="test").n_traj))
n_per_traj_PH = dataset_PH.get_dataset(split="test").n_per_traj
# One independent [0] list per trajectory. `[[0]] * n` would repeat the *same* inner
# list n times, so an in-place change to one entry would show up in all of them.
IDX_PH = [[0] for _ in TRAJ_IDX]

In [None]:
# Run lit_model.rollout_physics on the test split for the trajectories in TRAJ_IDX
# (start indices in IDX_PH) and keep all seven returned values.
(
    rollout_physics,
    fields,
    GT_physics,
    GT_vel_fields_normalized_physics,
    MSE_fields_normalized_physics,
    MSE_Field_latent_physics,
    propagated_latents,
) = lit_model.rollout_physics(
    particle_dm=dataset_PH,
    traj_idx=TRAJ_IDX,
    idx=IDX_PH,
    query_gt_pos=False,
    split="test",
    use_gt_field=False,
)

In [None]:
# One stacked tensor of per-step IoU values for every (ground truth, rollout) trajectory pair.
ious = []
for GT, rollout in tqdm(zip(GT_physics, rollout_physics), total=len(GT_physics)):
    # mean_iou is evaluated one step at a time; unsqueeze(0) adds a leading axis to each input.
    iou_traj = [
        mean_iou(
            rollout_pos.unsqueeze(0),
            GT_pos.unsqueeze(0),
            n_compartments=64,
            bounding_box=(-0.9, 1.9, -0.9, 1.9),
        )
        for GT_pos, rollout_pos in zip(GT, rollout)
    ]
    ious.append(torch.stack(iou_traj))

In [None]:
ious = torch.stack(ious, dim=0)

In [None]:
ious.shape

In [None]:
ious_mean_8_4_displacement = ious.mean(axis=0)
ious_std_8_4_displacement = ious.std(axis=0)
print(ious_mean_8_4_displacement.shape)

In [None]:
plot_means_variances(
    ious_mean_8_4_displacement.numpy(force=True),
    ious_std_8_4_displacement.numpy(force=True),
    xlabel="Rollout Timestep",
    ylabel="IoU",
    upper_bound=1.0,
)
plt.show()

In [None]:
plot_means_variances(
    [ious_mean_8_4, ious_mean_8_4_displacement],
    [ious_std_8_4, ious_std_8_4_displacement],
    xlabel="timestep",
    ylabel="IoU",
    labels=
    ["velocity", "displacement"],
    figsize=(10, 4),
    dpi=250,
    start_idx=15,
    stride=8
)

In [None]:
print(ious_mean_8_4_displacement.mean())
print(ious_mean_8_4.mean())

In [None]:
ani = animate(
    rollout_physics[0],
    ground_truth=GT_physics[0],
    ref_frame=((0, 1), (0, 1)),
    start_idx=16,
    n_skip_ahead_timesteps=8
)
HTML(ani.to_jshtml())

In [None]:
ani = animate(
    rollout_physics[1],
    ground_truth=GT_physics[1],
    ref_frame=((0, 1), (0, 1)),
    start_idx=16,
    n_skip_ahead_timesteps=8
)
HTML(ani.to_jshtml())

In [None]:
ani = animate(
    rollout_physics[2],
    ground_truth=GT_physics[2],
    ref_frame=((0, 1), (0, 1)),
    start_idx=16,
    n_skip_ahead_timesteps=8
)
HTML(ani.to_jshtml())

In [None]:
ani = animate(
    rollout_physics[3],
    ground_truth=GT_physics[3],
    ref_frame=((0, 1), (0, 1)),
    start_idx=16,
    n_skip_ahead_timesteps=8
)
HTML(ani.to_jshtml())

In [None]:
ani = animate(
    rollout_physics[4],
    ground_truth=GT_physics[4],
    ref_frame=((0, 1), (0, 1)),
    start_idx=16,
    n_skip_ahead_timesteps=8
)
HTML(ani.to_jshtml())

# Model size

In [None]:
# Hydra config that was saved next to the training run being evaluated.
cfg          = OmegaConf.load(
    Path("logs/train/runs/2025-09-12_18-29-07/.hydra/config.yaml")
)

# Constructor arguments for the checkpoint: partial factories for the latent model
# and the full model, plus an instantiated loss function.
latent_model_factory = instantiate(cfg.model.latent_model, _partial_=True)
model_factory        = instantiate(cfg.model,            _partial_=True)
loss_function        = instantiate(cfg.model.loss_function)

lit_model = PhysicsLitModule.load_from_checkpoint(
    Path("logs/train/runs/2025-09-12_18-29-07/waterdrop_physics/9odrn4lu/checkpoints/epoch=27-step=403620.ckpt"),
    latent_model = latent_model_factory,
    model        = model_factory,
    loss_function= loss_function,
    strict       = True,
    # Prefer the GPU, but still load on machines without CUDA.
    map_location = "cuda" if torch.cuda.is_available() else "cpu",
)
lit_model.eval()  # evaluation mode: no dropout / batch-statistics updates

# Data module built from the same config as the model.
dataset_PH = instantiate(cfg.data)
dataset_PH.setup(stage="physics")

# NOTE(review): presumably switches the test dataset to rollout sampling — confirm in the dataset class.
dataset_PH.test_dataset.rollout = True

In [None]:
# Evaluate every trajectory of the test split.
TRAJ_IDX = list(range(dataset_PH.get_dataset(split="test").n_traj))
n_per_traj_PH = dataset_PH.get_dataset(split="test").n_per_traj
# One independent [0] list per trajectory. `[[0]] * n` would repeat the *same* inner
# list n times, so an in-place change to one entry would show up in all of them.
IDX_PH = [[0] for _ in TRAJ_IDX]

In [None]:
# Run lit_model.rollout_physics on the test split for the trajectories in TRAJ_IDX
# (start indices in IDX_PH) and keep all seven returned values.
(
    rollout_physics,
    fields,
    GT_physics,
    GT_vel_fields_normalized_physics,
    MSE_fields_normalized_physics,
    MSE_Field_latent_physics,
    propagated_latents,
) = lit_model.rollout_physics(
    particle_dm=dataset_PH,
    traj_idx=TRAJ_IDX,
    idx=IDX_PH,
    query_gt_pos=False,
    split="test",
    use_gt_field=False,
)

In [None]:
# One stacked tensor of per-step IoU values for every (ground truth, rollout) trajectory pair.
ious = []
for GT, rollout in tqdm(zip(GT_physics, rollout_physics), total=len(GT_physics)):
    # mean_iou is evaluated one step at a time; unsqueeze(0) adds a leading axis to each input.
    iou_traj = [
        mean_iou(
            rollout_pos.unsqueeze(0),
            GT_pos.unsqueeze(0),
            n_compartments=64,
            bounding_box=(-0.9, 1.9, -0.9, 1.9),
        )
        for GT_pos, rollout_pos in zip(GT, rollout)
    ]
    ious.append(torch.stack(iou_traj))

In [None]:
ious = torch.stack(ious, dim=0)

In [None]:
ious.shape

In [None]:
ious_mean_8_4_tiny = ious.mean(axis=0)
ious_std_8_4_tiny = ious.std(axis=0)
print(ious_mean_8_4_tiny.shape)

In [None]:
plot_means_variances(
    ious_mean_8_4_tiny,
    ious_std_8_4_tiny,
    xlabel="Rollout Timestep",
    ylabel="IoU",
    upper_bound=1.0,
)
plt.show()

In [None]:
plot_means_variances(
    [ious_mean_8_4, ious_mean_8_4_tiny],
    [ious_std_8_4, ious_std_8_4_tiny],
    xlabel="timestep",
    ylabel="IoU",
    labels=
    ["baseline", "tiny"],
    figsize=(10, 4),
    dpi=250,
    start_idx=15,
    stride=8
)

In [None]:
import matplotlib.ticker as mtick

# Parameter count of the currently loaded (tiny) model. It is computed here
# because the cell that defines `num_params_tiny` sits further down the
# notebook, so a fresh Restart & Run All would hit a NameError at this point.
# `num_params` (baseline model) is expected from an earlier cell.
num_params_tiny = sum(p.numel() for p in lit_model.parameters())

# One bubble per model, ordered [tiny, baseline].
x = np.array([0, 1], dtype=float)
# float() pulls out plain Python scalars, which also works when the means are
# 0-d tensors living on the GPU (np.array cannot convert those directly).
y = np.array([float(ious_mean_8_4_tiny.mean()), float(ious_mean_8_4.mean())], dtype=float)
params = np.array([num_params_tiny, num_params], dtype=float)

# Bubble area scaling (points^2): the larger model gets target_max_area.
target_max_area = 1500.0
scale = params.max() / target_max_area
sizes = params / scale

fig, ax = plt.subplots(figsize=(7.5, 5), dpi=200)

# Scatter as true circles
ax.scatter(x, y, s=sizes, marker='o', alpha=0.75)

# Axes formatting
ax.set_xticks(x)
ax.set_xticklabels([f"Tiny (~{num_params_tiny/1e6:.1f}M)", f"Big (~{num_params/1e6:.1f}M)"])
ax.set_ylabel("Mean IoU")
ax.set_title("Model Size vs. Performance (mIoU)")
ax.grid(True, axis="y", linestyle="--", linewidth=0.6, alpha=0.7)
ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%.3f'))
ax.set_ylim(min(y)-0.02, max(y)+0.02)

# Annotate each bubble with its mean IoU value
for xi, yi in zip(x, y):
    ax.annotate(f"{yi:.3f}", (xi, yi), xytext=(0, 10),
                textcoords="offset points", ha="center", va="bottom")

# ---- Legend: scale bubbles to fit neatly ----
# Legend bubbles use the plot's area scale shrunk by legend_max_area / s_max,
# so they are smaller than plotted bubbles of the same parameter count.
legend_params = np.array([5e6, 25e6])
# Desired max area inside legend (points^2)
legend_max_area = 400.0
s_max = sizes.max()  # equals target_max_area
legend_scale = np.sqrt(legend_max_area / s_max)
legend_sizes = (legend_params / scale) * (legend_scale**2)
legend_labels = [f"{p/1e6:.0f}M params" for p in legend_params]

# Empty scatters act as legend handles only; they draw nothing on the axes.
legend_handles = [ax.scatter([], [], s=s, marker='o', alpha=0.75, label=lab)
                  for s, lab in zip(legend_sizes, legend_labels)]

ax.legend(handles=legend_handles, title="Parameters (circle area)",
          frameon=True, loc="lower right", borderpad=0.8, handletextpad=0.8)

plt.tight_layout()

In [None]:
# Overall mean IoU: tiny model first, then baseline, one value per line.
print(ious_mean_8_4_tiny.mean(), ious_mean_8_4.mean(), sep="\n")

In [None]:
# Parameter counts of the loaded model: its latent model, the whole module,
# and the remainder (stored under the `_AE` suffix).
num_params_tiny_latent = sum(param.numel() for param in lit_model.latent_model.parameters())
num_params_tiny = sum(param.numel() for param in lit_model.parameters())
num_params_tiny_AE = num_params_tiny - num_params_tiny_latent

In [None]:
# Baseline parameter counts: total, latent model, remainder (one per line).
print(num_params, num_params_latent, num_params_AE, sep="\n")

In [None]:
# Tiny-model parameter counts: total, latent model, remainder (one per line).
print(num_params_tiny, num_params_tiny_latent, num_params_tiny_AE, sep="\n")

In [None]:
# Animate test trajectory 0: rollout together with its ground truth,
# embedded in the notebook as JS/HTML.
ani = animate(rollout_physics[0], ground_truth=GT_physics[0],
              ref_frame=((0, 1), (0, 1)), start_idx=16, n_skip_ahead_timesteps=8)
HTML(ani.to_jshtml())

In [None]:
# Animate test trajectory 1: rollout together with its ground truth,
# embedded in the notebook as JS/HTML.
ani = animate(rollout_physics[1], ground_truth=GT_physics[1],
              ref_frame=((0, 1), (0, 1)), start_idx=16, n_skip_ahead_timesteps=8)
HTML(ani.to_jshtml())

In [None]:
# Animate test trajectory 2: rollout together with its ground truth,
# embedded in the notebook as JS/HTML.
ani = animate(rollout_physics[2], ground_truth=GT_physics[2],
              ref_frame=((0, 1), (0, 1)), start_idx=16, n_skip_ahead_timesteps=8)
HTML(ani.to_jshtml())

In [None]:
# Speed / accuracy / size comparison of the two models, ordered [tiny, baseline].
# x and xerr are hard-coded rollout-time statistics (seconds and ±1 SD, per the
# x-axis label set below).
xerr = np.array([0.1796, 0.2687], dtype=float)
x = np.array([0.7183, 1.4725], dtype=float)
y = np.array([ious_mean_8_4_tiny.mean(), ious_mean_8_4.mean()], dtype=float)
params = np.array([num_params_tiny, num_params], dtype=float)
labels = [
    f"Tiny (~{num_params_tiny/1e6:.1f}M)",
    f"Baseline (~{num_params/1e6:.1f}M)",
]

# Marker area proportional to parameter count; the larger model gets target_max_area.
target_max_area = 1500.0
scale = params.max() / target_max_area
sizes = params / scale

fig, ax = plt.subplots(figsize=(7.5, 5), dpi=200)

# Horizontal error bars only (no marker), drawn beneath the bubbles.
ax.errorbar(x, y, xerr=xerr, fmt='none',
            capsize=7, capthick=1.2, elinewidth=1.2,
            alpha=0.9, zorder=1)

# Bubbles on top of the error bars.
ax.scatter(x, y, s=sizes, marker='o', alpha=0.8, zorder=2)

# Model name below each bubble, mIoU value above it.
for xi, yi, lab in zip(x, y, labels):
    ax.annotate(lab, (xi, yi), xytext=(0, -14),
                textcoords="offset points", ha="center", va="top")
    ax.annotate(f"mIoU={yi:.3f}", (xi, yi), xytext=(0, 12),
                textcoords="offset points", ha="center", va="bottom")

# Labels, title and tick formatting.
ax.set(
    xlabel="Rollout time (s) — error bars show ±1 SD",
    ylabel="Mean IoU",
    title="Speed–Accuracy–Size Trade-off",
)
ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%.3f'))
ax.set_ylim(y.min() - 0.02, y.max() + 0.02)

# x-limits: full error-bar extent plus 6 % padding on each side.
xmin = np.min(x - xerr)
xmax = np.max(x + xerr)
pad = 0.06 * (xmax - xmin)
ax.set_xlim(xmin - pad, xmax + pad)

ax.grid(True, axis="both", linestyle="--", linewidth=0.6, alpha=0.7)

# Size legend: reference bubbles shrunk by legend_max_area / s_max so they fit
# inside the legend box.
legend_max_area = 400.0
legend_params = np.array([5e6, 25e6])
s_max = sizes.max()
legend_scale = np.sqrt(legend_max_area / s_max)
legend_sizes = (legend_params / scale) * (legend_scale**2)
legend_labels = [f"{count/1e6:.0f}M params" for count in legend_params]
# Empty scatters act purely as legend handles.
legend_handles = [
    ax.scatter([], [], s=area, marker='o', alpha=0.8, label=text)
    for area, text in zip(legend_sizes, legend_labels)
]
ax.legend(handles=legend_handles, title="Parameters (circle area)",
          frameon=True, loc="lower right", borderpad=0.8, handletextpad=0.8)

plt.tight_layout()

# Acceleration field

In [None]:
# Hydra config stored with the training run of the acceleration-field model.
cfg          = OmegaConf.load(
    Path("logs/train/runs/2025-09-16_22-54-49/.hydra/config.yaml")
)

# factories
latent_model_factory = instantiate(cfg.model.latent_model, _partial_=True)
model_factory        = instantiate(cfg.model,            _partial_=True)
loss_function        = instantiate(cfg.model.loss_function)

# Restore the trained physics model from its checkpoint. This rebinds
# `lit_model`, replacing the model loaded earlier in the notebook.
lit_model = PhysicsLitModule.load_from_checkpoint(
    Path("logs/train/runs/2025-09-16_22-54-49/waterdrop_physics/h436l13d/checkpoints/epoch=24-step=360000.ckpt"),
    latent_model = latent_model_factory,
    model        = model_factory,
    loss_function= loss_function,
    strict       = True,
    # Use the GPU when one is present, otherwise fall back to the CPU so the
    # notebook still loads the checkpoint on machines without CUDA.
    map_location = "cuda" if torch.cuda.is_available() else "cpu"
)
lit_model.eval()  # evaluation mode

# Build the data module from this run's config and prepare it for the physics stage.
dataset_PH = instantiate(cfg.data)
dataset_PH.setup(stage="physics")

# NOTE(review): presumably switches the test split to rollout mode —
# confirm against the dataset class.
dataset_PH.test_dataset.rollout = True

In [None]:
# Select every trajectory of the test split for the rollout.
test_dataset_PH = dataset_PH.get_dataset(split="test")  # look the split up once
TRAJ_IDX = list(range(test_dataset_PH.n_traj))
n_per_traj_PH = test_dataset_PH.n_per_traj
# One independent [0] list per trajectory. `[[0]] * n` would repeat a single
# shared inner list, so an in-place change to one entry would show up in all.
IDX_PH = [[0] for _ in TRAJ_IDX]

In [None]:
# Run the physics rollout on the selected test trajectories and unpack the
# seven values it returns.
(
    rollout_physics,
    fields,
    GT_physics,
    GT_vel_fields_normalized_physics,
    MSE_fields_normalized_physics,
    MSE_Field_latent_physics,
    propagated_latents,
) = lit_model.rollout_physics(
    particle_dm=dataset_PH, traj_idx=TRAJ_IDX, idx=IDX_PH,
    query_gt_pos=False, split="test", use_gt_field=False,
)

In [None]:
ious = []
# One (ground truth, rollout) pair per trajectory, with a progress bar.
for GT, rollout in tqdm(zip(GT_physics, rollout_physics), total=len(GT_physics)):
    # IoU of each rollout frame against the matching ground-truth frame;
    # unsqueeze(0) adds a leading dimension of size 1 before the call.
    iou_traj = [
        mean_iou(
            rollout_pos.unsqueeze(0),
            GT_pos.unsqueeze(0),
            n_compartments=64,
            bounding_box=(-0.9, 1.9, -0.9, 1.9),
        )
        for GT_pos, rollout_pos in zip(GT, rollout)
    ]
    ious.append(torch.stack(iou_traj))

In [None]:
# Stack the per-trajectory IoU tensors into one tensor, trajectory index first.
# `ious` is overwritten with its own stacked version, so the guard makes a
# second run of this cell a no-op: torch.stack expects a sequence of tensors
# and rejects an already-stacked tensor.
if not isinstance(ious, torch.Tensor):
    ious = torch.stack(ious, dim=0)

In [None]:
# Display the shape of the stacked IoU tensor.
ious.size()

In [None]:
# Reduce over dim 0 (the stacked trajectories): mean and standard deviation
# of the IoU at every remaining position.
ious_mean_8_4_acceleration = torch.mean(ious, dim=0)
ious_std_8_4_acceleration = torch.std(ious, dim=0)
print(ious_mean_8_4_acceleration.shape)

In [None]:
# IoU mean and standard deviation of the acceleration-field model over the rollout.
plot_means_variances(ious_mean_8_4_acceleration, ious_std_8_4_acceleration,
                     xlabel="Rollout Timestep", ylabel="IoU", upper_bound=1.0)
plt.show()

In [None]:
# Overall mean IoU of the acceleration-field model.
print(torch.mean(ious_mean_8_4_acceleration))

In [None]:
# Animate test trajectory 1: rollout together with its ground truth,
# embedded in the notebook as JS/HTML.
ani = animate(rollout_physics[1], ground_truth=GT_physics[1],
              ref_frame=((0, 1), (0, 1)), start_idx=16, n_skip_ahead_timesteps=8)
HTML(ani.to_jshtml())

In [None]:
# Animate test trajectory 2: rollout together with its ground truth,
# embedded in the notebook as JS/HTML.
ani = animate(rollout_physics[2], ground_truth=GT_physics[2],
              ref_frame=((0, 1), (0, 1)), start_idx=16, n_skip_ahead_timesteps=8)
HTML(ani.to_jshtml())

In [None]:
# Velocity (baseline) vs. acceleration model: IoU mean and standard deviation
# on shared axes.
plot_means_variances(
    [ious_mean_8_4, ious_mean_8_4_acceleration],
    [ious_std_8_4, ious_std_8_4_acceleration],
    labels=["velocity", "acceleration"],
    xlabel="timestep", ylabel="IoU",
    figsize=(10, 4), dpi=250,
    start_idx=15, stride=8,
)

In [None]:
# Number of entries along the first dimension of the mean-IoU curve.
ious_mean_8_4_acceleration.shape[0]