In [None]:
from omegaconf import OmegaConf
from hydra.core.global_hydra import GlobalHydra
from hydra.utils import instantiate
from src.models.autoencoder import AutoencoderLitModule
from src.models.physics import PhysicsLitModule
from src.utils import animate

import os
import sys
from IPython.display import HTML
from tqdm import tqdm
import torch
from torch_geometric.data import Data
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.animation as animation
matplotlib.rcParams['animation.embed_limit'] = 400  # limit is in MB: allow up to 400 MB of animation data embedded via ani.to_jshtml()

import numpy as np
import einops
from functools import partial
from src.datasets.particle_datamodule import ParticleDataModule
from src.callbacks.metrics import MetricsCallback

# Clear the global Hydra singleton so re-running the notebook in the same
# kernel starts from a clean Hydra state.
GlobalHydra.instance().clear()

# Expose the notebook's working directory as PROJECT_ROOT.
# NOTE(review): presumably consumed by env interpolations in the Hydra config — verify.
os.environ.update(PROJECT_ROOT=os.path.abspath(os.curdir))

In [None]:
def plot_metric(mse, title: str, log: bool = False, ylabel: str = 'Metric'):
    """
    Plot a per-timestep metric curve on its own figure.

    Parameters
    ----------
    mse : array-like
        1-D sequence of metric values, one per timestep (despite the name,
        any per-timestep metric can be passed).
    title : str
        Prefix of the plot title, rendered as "<title> Metric Over Time".
    log : bool, optional
        If True, use a logarithmic y-axis (default False).
    ylabel : str, optional
        Y-axis label and legend entry for the curve (default 'Metric').
    """
    # Create a dedicated figure. Do not call plt.clf() first: it clears (or
    # implicitly creates) an unrelated "current" figure, which the inline
    # backend then renders as an extra empty figure.
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(mse, label=ylabel)
    ax.set_xlabel('Timestep')
    ax.set_ylabel(ylabel)
    ax.set_title(f"{title} Metric Over Time")
    ax.legend()
    if log:
        ax.set_yscale('log')
    fig.tight_layout()
    plt.show()

def plot_two_series(y1: np.ndarray,
                    y2: np.ndarray,
                    labels=('Series 1', 'Series 2'),
                    xlabel: str = 'Index',
                    ylabel: str = 'Value',
                    title: str = 'Two Series on One Plot',
                    log: bool = False) -> None:
    """
    Plots two same-length sequences/arrays on the same axes.

    Both series share the x-axis ``0 .. len(y1) - 1``.

    Parameters
    ----------
    y1 : np.ndarray
        First data series (plotted as a solid line).
    y2 : np.ndarray
        Second data series (plotted as a dashed line).
    labels : tuple of str, optional
        Legend labels for the two series (default ('Series 1', 'Series 2')).
    xlabel : str, optional
        Label for the x-axis (default 'Index').
    ylabel : str, optional
        Label for the y-axis (default 'Value').
    title : str, optional
        Title of the plot (default 'Two Series on One Plot').
    log : bool, optional
        If True, use a logarithmic y-axis (default False).

    Raises
    ------
    ValueError
        If y1 and y2 are not the same length.
    """
    if len(y1) != len(y2):
        raise ValueError(f"Input arrays must have the same length; got {len(y1)} and {len(y2)}")

    x = np.arange(len(y1))  # common x-axis for both series
    fig, ax = plt.subplots()
    ax.plot(x, y1, label=labels[0])
    # Dashed style keeps the second series distinguishable where curves overlap.
    ax.plot(x, y2, linestyle='--', label=labels[1])
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    ax.grid(True)
    if log:
        ax.set_yscale('log')
    fig.tight_layout()
    plt.show()

def compute_mse(preds, targets):
    """
    Per-timestep mean squared error between a predicted and a target trajectory.

    Every dimension after the first (time) axis is flattened before averaging.
    For ``(t, n, c)`` inputs this is the mean over all ``n * c`` entries of each
    timestep, i.e. the same as rearranging ``'t n c -> t (n c)'``. Other trailing
    layouts such as ``(t, d)`` are accepted as well.

    Parameters
    ----------
    preds, targets : torch.Tensor
        Tensors of identical shape, time axis first, at least 2-D.

    Returns
    -------
    np.ndarray
        1-D array of length ``t`` holding the MSE of each timestep.

    Raises
    ------
    ValueError
        If the shapes differ (instead of silently broadcasting) or the
        inputs are less than 2-D.
    """
    if preds.shape != targets.shape:
        raise ValueError(
            f"preds and targets must have the same shape; "
            f"got {tuple(preds.shape)} and {tuple(targets.shape)}"
        )
    if preds.dim() < 2:
        raise ValueError(f"expected at least 2-D (time-first) tensors; got shape {tuple(preds.shape)}")
    # Row-major flatten of all non-time dims, one row per timestep.
    sq_err = (preds.flatten(start_dim=1) - targets.flatten(start_dim=1)) ** 2
    # detach(): model outputs may still require grad, and .numpy() rejects those.
    return torch.mean(sq_err, dim=1).detach().cpu().numpy()

# n_fields = 4

In [None]:
# Hydra run directory holding both the resolved config and the AE checkpoint.
RUN_DIR = "logs/train/runs/2025-08-11_21-39-53"
CKPT_PATH = f"{RUN_DIR}/waterdrop/kku6jcey/checkpoints/epoch=28-step=433695.ckpt"
# Use the GPU when present, otherwise fall back to CPU instead of crashing.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

cfg = OmegaConf.load(f"{RUN_DIR}/.hydra/config.yaml")
# Only the inner network and the loss are needed to rebuild the LightningModule
# from the checkpoint. Instantiating the whole cfg.model first was wasted work,
# because its result was immediately overwritten.
net = instantiate(cfg.model.model)
loss_function = instantiate(cfg.model.loss_function)
model = AutoencoderLitModule.load_from_checkpoint(
    checkpoint_path=CKPT_PATH,
    # Remap tensors saved on GPU so loading also works on CPU-only machines.
    map_location=DEVICE,
    model=net,
    loss_function=loss_function,
)
model.eval()
model.to(DEVICE)

# Datamodule built from the same run config, prepared for the "autoencoder" stage.
dataset_AE = instantiate(cfg.data)
dataset_AE.setup(stage="autoencoder")
# Ordered, single-process, batch-size-1 loading for evaluation.
# NOTE(review): these are set after setup(); assumes the datamodule reads them
# lazily when it builds its dataloaders — confirm in ParticleDataModule.
dataset_AE.shuffle = False
dataset_AE.batch_size = 1
dataset_AE.num_workers = 0
dataset_AE.pin_memory = False
dataset_AE.persistent_workers = False
# Presumably switches the datasets to whole-trajectory (rollout) samples — verify.
dataset_AE.train_dataset.rollout = True
dataset_AE.val_dataset.rollout = True

In [None]:
# Trajectory indices passed as `traj_idx` (with split="test") to GT_encode_decode.
TRAJ_IDX = list(range(3))
# One index list per trajectory, parallel to TRAJ_IDX (each entry is its own list).
# NOTE(review): exact meaning of these indices is defined by GT_encode_decode — confirm.
IDX_PH = [[0] for _ in TRAJ_IDX]

In [None]:
# Encode/decode the selected test-split trajectories with the autoencoder and
# unpack the five returned results.
(
    rollout_AE,
    GT_AE,
    GT_vel_fields_normalized_AE,
    MSE_fields_normalized_AE,
    MSE_Field_latent_AE,
) = model.GT_encode_decode(
    split="test",
    particle_dm=dataset_AE,
    traj_idx=TRAJ_IDX,
    idx=IDX_PH,
    query_gt_pos=False,
)

In [None]:
# Animate the reconstruction against ground truth for entry 0 of the results
# (presumably TRAJ_IDX[0]), starting at frame 0 and stepping one timestep at a time.
# NOTE(review): ref_frame ((0, 1), (0, 1)) looks like x/y axis limits — verify in src.utils.animate.
ani = animate(
    ground_truth=GT_AE[0],
    rollout=rollout_AE[0],
    start_idx=0,
    n_skip_ahead_timesteps=1,
    ref_frame=((0, 1),) * 2,
)

In [None]:
# Render the animation as an embedded JavaScript/HTML player (cell output).
HTML(data=ani.to_jshtml())

In [None]:
# Same reconstruction-vs-ground-truth animation for entry 1 of the results
# (presumably TRAJ_IDX[1]), rendered directly as this cell's output.
ani = animate(
    ground_truth=GT_AE[1],
    rollout=rollout_AE[1],
    start_idx=0,
    n_skip_ahead_timesteps=1,
    ref_frame=((0, 1),) * 2,
)
HTML(data=ani.to_jshtml())