In [None]:
from omegaconf import OmegaConf
from hydra.core.global_hydra import GlobalHydra
from hydra.utils import instantiate
from src.models.autoencoder import AutoencoderLitModule
from src.models.physics import PhysicsLitModule
from src.utils import animate

import os
import sys
from IPython.display import HTML
from tqdm import tqdm
import torch
from torch_geometric.data import Data
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.animation as animation
matplotlib.rcParams['animation.embed_limit'] = 400  # raise the embedded-animation size limit to 400 MB (this rcParam is expressed in MB)

import numpy as np
import einops
from functools import partial
from src.datasets.particle_datamodule import ParticleDataModule

# Clear the process-wide GlobalHydra singleton. Presumably this lets the cell be
# re-run in the same kernel without Hydra complaining about an earlier
# initialisation — TODO confirm.
GlobalHydra.instance().clear()

# Export the absolute path of the current working directory as PROJECT_ROOT
# (presumably interpolated by the Hydra configs loaded further down — confirm).
os.environ["PROJECT_ROOT"] = os.path.abspath(".")

In [None]:
def plot_metric(mse, title: str, log: bool = False):
    """
    Plots a single metric series against its index on a new figure.

    Parameters
    ----------
    mse : sequence or np.ndarray
        Metric values, one per timestep; plotted against their position.
    title : str
        Prefix of the plot title; the title reads "<title> Metric Over Time".
    log : bool, optional
        If True, the y-axis uses a logarithmic scale (default False).
    """
    # `plt.subplots` always creates a fresh figure, so nothing needs clearing
    # beforehand. Calling `plt.clf()` first would clear (or, when no figure is
    # open, create) a different "current" figure and leave an empty one behind.
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(mse, label='Metric')
    ax.set_xlabel('Timestep')
    ax.set_ylabel('Metric')
    ax.set_title(f"{title} Metric Over Time")
    ax.legend()
    if log:
        ax.set_yscale('log')
    fig.tight_layout()
    plt.show()

def plot_two_series(y1: np.ndarray,
                    y2: np.ndarray,
                    labels=('Series 1', 'Series 2'),
                    xlabel: str = 'Index',
                    ylabel: str = 'Value',
                    title: str = 'Two Series on One Plot',
                    log: bool = False,
                    y_limit: tuple = None,
                    figsize: tuple = None,
                    dpi: int = 100,
                    gridlines: tuple[bool, bool] = (False, False)) -> None:
    """
    Plots two same-length sequences/arrays on the same axes of a new figure.

    Both series are drawn as solid lines in matplotlib's default colour
    cycle and share an integer x-axis 0..len(y1)-1.

    Parameters
    ----------
    y1 : np.ndarray
        First data series.
    y2 : np.ndarray
        Second data series; must have the same length as ``y1``.
    labels : tuple of str, optional
        Legend labels for the two series (default ('Series 1', 'Series 2')).
    xlabel : str, optional
        Label for the x-axis (default 'Index').
    ylabel : str, optional
        Label for the y-axis (default 'Value').
    title : str, optional
        Title of the plot (default 'Two Series on One Plot').
    log : bool, optional
        If True, the y-axis uses a logarithmic scale (default False).
    y_limit : tuple, optional
        Value forwarded to ``ax.set_ylim``, e.g. ``(low, high)``. Ignored when
        None or empty (default None).
    figsize : tuple, optional
        Figure size forwarded to ``plt.subplots``; None keeps matplotlib's
        default (default None).
    dpi : int, optional
        Figure resolution forwarded to ``plt.subplots`` (default 100).
    gridlines : tuple of (bool, bool), optional
        Whether to draw gridlines for the (x-axis, y-axis) respectively
        (default (False, False)).

    Raises
    ------
    ValueError
        If y1 and y2 are not the same length.
    """
    if len(y1) != len(y2):
        raise ValueError(f"Input arrays must have the same length; got {len(y1)} and {len(y2)}")

    x = np.arange(len(y1))  # common x-axis shared by both series
    fig, ax = plt.subplots(figsize=figsize, dpi=dpi)

    ax.plot(x, y1, label=labels[0])
    ax.plot(x, y2, label=labels[1])

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()

    # Gridlines are toggled per axis: index 0 -> x-axis, index 1 -> y-axis.
    ax.xaxis.grid(gridlines[0])
    ax.yaxis.grid(gridlines[1])

    if log:
        ax.set_yscale('log')
    if y_limit:
        ax.set_ylim(y_limit)

    fig.tight_layout()
    plt.show()

def compute_mse(preds, targets):
    """
    Computes the mean squared error per leading-axis entry (timestep).

    Parameters
    ----------
    preds : torch.Tensor
        Predictions of shape (t, n, c); all axes after the first are
        flattened and averaged over.
    targets : torch.Tensor
        Reference values, same shape as ``preds``.

    Returns
    -------
    np.ndarray
        Array of shape (t,) holding one MSE value per timestep.
    """
    # Collapse every axis after the first one; for (t, n, c) input this is the
    # same layout as the einops pattern 't n c -> t (n c)'.
    preds = torch.flatten(preds, start_dim=1)
    targets = torch.flatten(targets, start_dim=1)
    squared_error = (preds - targets) ** 2
    # detach() first: Tensor.numpy() raises on tensors that require grad.
    return squared_error.mean(dim=1).detach().cpu().numpy()

In [None]:
# Hydra config saved alongside the autoencoder training run.
cfg = OmegaConf.load("logs/train/runs/2025-09-02_16-10-21/.hydra/config.yaml")

# Build only what the checkpoint loader needs: the inner network and the loss.
# (Instantiating the whole `cfg.model` here as well would be wasted work,
# because `model` is bound to the checkpoint-restored module just below.)
net = instantiate(cfg.model.model)
loss_function = instantiate(cfg.model.loss_function)

# Use the GPU when one is visible, otherwise fall back to CPU so the notebook
# still loads; `map_location` remaps the checkpoint tensors accordingly.
ae_device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoencoderLitModule.load_from_checkpoint(
    checkpoint_path="logs/train/runs/2025-09-02_16-10-21/waterdrop/pu1byyxz/checkpoints/epoch=26-step=404595.ckpt",
    map_location=ae_device,
    model=net,
    loss_function=loss_function
)
model.eval()
model.to(ae_device)

# Data module described by the same config, prepared for the "autoencoder" stage.
dataset_AE = instantiate(cfg.data)
dataset_AE.setup(stage="autoencoder")

# Override loader settings for inspection: fixed order, one sample per batch,
# no worker processes, no pinned memory.
dataset_AE.shuffle = False
dataset_AE.batch_size = 1
dataset_AE.num_workers = 0
dataset_AE.pin_memory = False
dataset_AE.persistent_workers = False

# Switch the train/val datasets to rollout mode.
# NOTE(review): later cells request split="test", but only the train and val
# datasets get `rollout = True` here — confirm the test dataset does not need it.
dataset_AE.train_dataset.rollout = True
dataset_AE.val_dataset.rollout = True

In [None]:
# Hydra config saved alongside the physics training run.
# Note: this rebinds `cfg` (and `loss_function` below), which the autoencoder
# cell also assigns.
cfg = OmegaConf.load(
    "logs/train/runs/2025-09-04_07-13-04/.hydra/config.yaml"
)

# `_partial_=True` makes `instantiate` hand back factories instead of built
# objects; the checkpoint loader receives those factories.
latent_model_factory = instantiate(cfg.model.latent_model, _partial_=True)
model_factory = instantiate(cfg.model, _partial_=True)
loss_function = instantiate(cfg.model.loss_function)

# Restore the physics module from its checkpoint, mapped onto the GPU, and
# require an exact state-dict match (strict=True).
lit_model = PhysicsLitModule.load_from_checkpoint(
    "logs/train/runs/2025-09-04_07-13-04/waterdrop_physics/w55pl0j7/checkpoints/epoch=15-step=239520.ckpt",
    latent_model=latent_model_factory,
    model=model_factory,
    loss_function=loss_function,
    strict=True,
    map_location="cuda",
)
lit_model.eval()

# Data module from the physics config, prepared for the "physics" stage.
dataset_PH = instantiate(cfg.data)
dataset_PH.setup(stage="physics")

In [None]:
# Trajectories evaluated by the cells below.
TRAJ_IDX = list(range(10))

# One single-element index list per trajectory (presumably the start timestep
# of each rollout — confirm against the model methods that consume `idx`).
# Built with a comprehension so every inner list is its own object;
# `[[0]] * 10` would repeat the *same* inner list ten times, so mutating one
# entry would change all of them.
IDX_PH = [[0] for _ in TRAJ_IDX]
IDX_AE = [[0] for _ in TRAJ_IDX]

# Autoencoder (AE)

In [None]:
# Encode/decode the selected test trajectories with query_gt_pos=False.
# The comparison plot further down labels this run 'AE-int'.
(
    rollouts,
    GT_positions,
    GT_vel_fields_normalized,
    MSE_Fields_normalized,
    latents,
) = model.GT_encode_decode(
    dataset_AE,
    traj_idx=TRAJ_IDX,
    idx=IDX_AE,
    query_gt_pos=False,
    split="test",
)

In [None]:
# Animate the first returned rollout next to its ground-truth positions.
# ref_frame is passed as ((0, 1), (0, 1)) — presumably the x/y view limits;
# confirm in src.utils.animate.
ani = animate(
    rollouts[0],
    ground_truth=GT_positions[0],
    ref_frame=((0, 1), (0, 1))
)

In [None]:
# Display the animation inline as an HTML/JavaScript player
# (`ani` is presumably a matplotlib Animation — confirm in src.utils.animate).
HTML(ani.to_jshtml())

In [None]:
# Same encode/decode pass as before, but with query_gt_pos=True.
# The comparison plot below labels this run 'AE'.
(
    rollouts_AE,
    GT_positions_AE,
    GT_vel_fields_normalized_AE,
    MSE_Fields_normalized_AE,
    latents_AE,
) = model.GT_encode_decode(
    dataset_AE,
    traj_idx=TRAJ_IDX,
    idx=IDX_AE,
    query_gt_pos=True,
    split="test",
)

In [None]:
# Compare the first 200 entries of the two MSE series for trajectory 0:
# the query_gt_pos=False run ('AE-int') against the query_gt_pos=True run ('AE').
plot_two_series(
    MSE_Fields_normalized[0][:200],
    MSE_Fields_normalized_AE[0][:200],
    labels=('AE-int', 'AE'),
    xlabel='Timestep',
    ylabel='MSE on particle positions',
    title='Particle position error between AE and AE-int',
    log=True,  # logarithmic y-axis
    y_limit=(5*1e-7, 1e0),  # fixed y-range passed to set_ylim
    figsize=(8, 5),
    dpi=200,
    gridlines=(True, False)  # x-axis gridlines only
)

# Physics

In [None]:
# Encode/decode the test trajectories with query_gt_pos=True.
# NOTE(review): the arguments match the earlier query_gt_pos=True call in the
# AE section; this rebinds GT_vel_fields_normalized_AE and latents_AE and
# stores the MSE under `MSE_fields_normalized_AE` (lower-case "fields").
# Confirm whether the recomputation is intentional.
(
    rollout_AE,
    GT_AE,
    GT_vel_fields_normalized_AE,
    MSE_fields_normalized_AE,
    latents_AE,
) = model.GT_encode_decode(
    particle_dm=dataset_AE,
    traj_idx=TRAJ_IDX,
    idx=IDX_AE,
    query_gt_pos=True,
    split="test",
)

In [None]:
# Encode/decode the test trajectories with query_gt_pos=False ("AE-int").
# NOTE(review): the arguments match the first GT_encode_decode call in the
# AE section — confirm whether the recomputation is intentional.
(
    rollout_AE_int,
    GT_AE_int,
    GT_vel_fields_normalized_AE_int,
    MSE_fields_normalized_AE_int,
    latents_AE_int,
) = model.GT_encode_decode(
    particle_dm=dataset_AE,
    traj_idx=TRAJ_IDX,
    idx=IDX_AE,
    query_gt_pos=False,
    split="test",
)

In [None]:
# Plot, per trajectory, the AE reconstruction MSE against the physics MSE.
# FIXME(review): `MSE_fields_normalized_physics` is not assigned by any cell
# visible in this notebook (only the `_0_50`, `_0_0` and `_250` variants are,
# and those come later), so on a fresh kernel this cell raises NameError.
# Confirm which rollout_physics call is meant to produce it.
for (MSE_field_normalized_AE, MSE_field_normalized_physics) in zip(MSE_fields_normalized_AE, MSE_fields_normalized_physics):
    plot_two_series(
        # The first AE entry is dropped — presumably because the physics
        # series is one step shorter; plot_two_series raises ValueError if
        # the two lengths differ.
        MSE_field_normalized_AE[1:],
        MSE_field_normalized_physics,
        labels=('Reconstruction MSE', 'One-step oracle MSE'),
        xlabel='Timestep',
        ylabel='MSE',
        title='MSE Comparison Between AE and Physics Models',
        log=True
    )
    # No plt.clf() here: plot_two_series creates and shows its own figure, so
    # clearing afterwards would only open a new, empty figure per iteration.

## 0th trajectory

In [None]:
# Animate the query_gt_pos=False autoencoder rollout of trajectory 0 against
# its ground truth and display it inline.
ani = animate(
    rollout=rollout_AE_int[0],
    ground_truth=GT_AE_int[0],
)
HTML(ani.to_jshtml())

In [None]:
# Physics rollout for test trajectory 0 with idx=[[50]] (the animation cell
# that follows passes start_idx=50), without ground-truth query positions
# and without the ground-truth field.
(
    rollout_physics_0_50,
    GT_physics_0_50,
    GT_vel_fields_normalized_physics_0_50,
    MSE_fields_normalized_physics_0_50,
    MSE_Field_latent_physics_0_50,
    latents_physics_0_50,
) = lit_model.rollout_physics(
    particle_dm=dataset_PH,
    traj_idx=[0],
    idx=[[50]],
    query_gt_pos=False,
    split="test",
    use_gt_field=False,
)

In [None]:
# Animate the physics rollout that starts at index 50 and display it inline.
# NOTE(review): the rollout/ground-truth objects are passed whole here, while
# the multi-trajectory cells index them per trajectory (e.g. `[1]`, `[6]`) —
# confirm whether `[0]` is needed for this single-trajectory result.
ani = animate(
    rollout=rollout_physics_0_50,
    ground_truth=GT_physics_0_50,
    start_idx=50,
)
HTML(ani.to_jshtml())

In [None]:
# Physics rollout for test trajectory 0 with idx=[[0]], without ground-truth
# query positions and without the ground-truth field.
(
    rollout_physics_0_0,
    GT_physics_0_0,
    GT_vel_fields_normalized_physics_0_0,
    MSE_fields_normalized_physics_0_0,
    MSE_Field_latent_physics_0_0,
    latents_physics_0_0,
) = lit_model.rollout_physics(
    particle_dm=dataset_PH,
    traj_idx=[0],
    idx=[[0]],
    query_gt_pos=False,
    split="test",
    use_gt_field=False,
)

In [None]:
# Animate the physics rollout that starts at index 0 and display it inline.
# ref_frame is ((0, 1), (0, 0.6)) here — presumably a shorter y-range than
# the ((0, 1), (0, 1)) used in other cells; confirm in src.utils.animate.
# NOTE(review): as with the index-50 rollout, the result is passed without a
# per-trajectory `[0]` index — confirm this is what `animate` expects.
ani = animate(
    rollout=rollout_physics_0_0,
    ground_truth=GT_physics_0_0,
    ref_frame=((0, 1), (0, 0.6))
)
HTML(ani.to_jshtml())

## Trajectories starting at 250

In [None]:
# Same ten trajectories as before, now with index 250 for each
# (the animation cells that follow pass start_idx=250).
TRAJ_IDX = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
# A comprehension gives every trajectory its own inner list; `[[250]] * 10`
# would repeat one shared list object ten times.
IDX_PH = [[250] for _ in TRAJ_IDX]

In [None]:
# Physics rollout for every trajectory in TRAJ_IDX using the indices in
# IDX_PH, without ground-truth query positions and without the ground-truth
# field.
(
    rollout_physics_250,
    GT_physics_250,
    GT_vel_fields_normalized_physics_250,
    MSE_fields_normalized_physics_250,
    MSE_Field_latent_physics_250,
    latents_physics_250,
) = lit_model.rollout_physics(
    particle_dm=dataset_PH,
    traj_idx=TRAJ_IDX,
    idx=IDX_PH,
    query_gt_pos=False,
    split="test",
    use_gt_field=False,
)

In [None]:
# Animate entry 1 of the index-250 physics rollouts against its ground truth
# and display it inline.
ani = animate(
    rollout=rollout_physics_250[1],
    ground_truth=GT_physics_250[1],
    ref_frame=((0, 1), (0, 1)),
    start_idx=250
)
HTML(ani.to_jshtml())

In [None]:
# Animate entry 6 of the index-250 physics rollouts against its ground truth
# and display it inline.
ani = animate(
    rollout=rollout_physics_250[6],
    ground_truth=GT_physics_250[6],
    ref_frame=((0, 1), (0, 1)),
    start_idx=250
)
HTML(ani.to_jshtml())

In [None]:
# Animate entry 8 of the index-250 physics rollouts against its ground truth
# and display it inline.
ani = animate(
    rollout=rollout_physics_250[8],
    ground_truth=GT_physics_250[8],
    ref_frame=((0, 1), (0, 1)),
    start_idx=250
)
HTML(ani.to_jshtml())