In [None]:
from omegaconf import OmegaConf
from hydra.core.global_hydra import GlobalHydra
from hydra.utils import instantiate
from src.models.autoencoder import AutoencoderLitModule
from src.models.physics import PhysicsLitModule
from src.utils import animate

import os
from IPython.display import HTML
from tqdm import tqdm
import torch
from torch_geometric.data import Data
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.animation as animation
matplotlib.rcParams['animation.embed_limit'] = 400  # allow embedded HTML animations (e.g. to_jshtml) up to 400 MB

import numpy as np
import einops
from functools import partial
from src.datasets.particle_datamodule import ParticleDataModule

# Reset Hydra's global state (e.g. left over from a previous run of this notebook).
GlobalHydra.instance().clear()

# Publish the notebook's working directory as PROJECT_ROOT
# (presumably referenced by the run configs — confirm).
os.environ.update(PROJECT_ROOT=os.path.abspath(os.curdir))

In [None]:
def plot_metric(mse, title: str, log: bool = False, ylabel: str = 'Metric'):
    """
    Plot a single per-timestep metric curve in its own figure.

    Parameters
    ----------
    mse : array-like
        1-D sequence of metric values, one per timestep.
    title : str
        Prefix of the plot title ("<title> Metric Over Time").
    log : bool, optional
        Use a logarithmic y-axis (default False). Non-positive values cannot
        be shown on a log axis.
    ylabel : str, optional
        Y-axis label (default 'Metric').
    """
    # No plt.clf() here: it acts on whatever figure is current. It either wipes
    # a figure the caller is building or, when none is open, creates an extra
    # empty figure.
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(mse, label='Metric')
    ax.set_xlabel('Timestep')
    ax.set_ylabel(ylabel)
    ax.set_title(f"{title} Metric Over Time")
    ax.legend()
    if log:
        ax.set_yscale('log')
    fig.tight_layout()
    plt.show()

def plot_two_series(y1: np.ndarray,
                    y2: np.ndarray,
                    labels=('Series 1', 'Series 2'),
                    xlabel: str = 'Index',
                    ylabel: str = 'Value',
                    title: str = 'Two Series on One Plot',
                    log: bool = False) -> None:
    """
    Plots two same-length sequences/arrays on the same axes.

    Parameters
    ----------
    y1 : np.ndarray
        First data series (plotted as a solid line).
    y2 : np.ndarray
        Second data series (plotted as a dashed line).
    labels : tuple of str, optional
        Labels for the two series (default ('Series 1', 'Series 2')).
    xlabel : str, optional
        Label for the x-axis (default 'Index').
    ylabel : str, optional
        Label for the y-axis (default 'Value').
    title : str, optional
        Title of the plot (default 'Two Series on One Plot').
    log : bool, optional
        Use a logarithmic y-axis (default False).

    Raises
    ------
    ValueError
        If y1 and y2 are not the same length.
    """
    # Validate before creating a figure so a failed call leaves no empty figure behind.
    if len(y1) != len(y2):
        raise ValueError(f"Input arrays must have the same length; got {len(y1)} and {len(y2)}")

    x = np.arange(len(y1))  # shared x-axis: sample index
    fig, ax = plt.subplots()
    ax.plot(x, y1, label=labels[0])
    ax.plot(x, y2, linestyle='--', label=labels[1])  # dashed, as documented
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    ax.grid(True)
    if log:
        ax.set_yscale('log')
    fig.tight_layout()
    plt.show()

def compute_mse(preds, targets):
    """
    Per-timestep mean squared error between two tensors.

    Axis 0 is treated as time; every remaining axis is flattened and averaged.
    Inputs of shape ``(t, n, c)`` work, as does any rank >= 2
    (e.g. ``(t, n)``).

    Parameters
    ----------
    preds, targets : torch.Tensor
        Tensors with time on axis 0 and broadcast-compatible shapes.

    Returns
    -------
    numpy.ndarray
        Shape ``(t,)``: the MSE of each timestep.
    """
    preds_flat = torch.flatten(preds, start_dim=1)
    targets_flat = torch.flatten(targets, start_dim=1)
    per_step = torch.mean((preds_flat - targets_flat) ** 2, dim=1)
    # detach(): model outputs may require grad, and .numpy() refuses such tensors.
    return per_step.detach().cpu().numpy()

# Overfit single trajectory
## AE model

In [None]:
cfg = OmegaConf.load("logs/train/runs/2025-06-18_21-31-26/.hydra/config.yaml")
model = instantiate(cfg.model)
net = instantiate(cfg.model.model)
loss_function = instantiate(cfg.model.loss_function)
model = AutoencoderLitModule.load_from_checkpoint(
    checkpoint_path="logs/train/runs/2025-06-18_21-31-26/waterdrop/lxx39u5h/checkpoints/epoch=388-step=6224.ckpt",
    model=net,
    loss_function=loss_function
)
model.eval()
model.to("cuda")
dataset = instantiate(cfg.data)
dataset.setup(stage="autoencoder")
dataset.shuffle = False
dataset.batch_size = 1
dataset.num_workers = 0
dataset.pin_memory = False
dataset.persistent_workers = False
dataset.train_dataset.rollout = True

In [None]:
rollout_AE, GT_AE, GT_vel_fields_normalized_AE, MSE_fields_normalized_AE = model.GT_encode_decode(
    particle_dm=dataset,
    idx=[0],
    query_gt_pos=True,
    split="train"
)

## Physics model

In [None]:
# Physics model for the single-trajectory experiment: load its Hydra config.
cfg          = OmegaConf.load(
    "logs/train/runs/2025-06-25_14-53-46/.hydra/config.yaml"
)

# factories: `_partial_=True` makes instantiate return partially applied
# constructors, which are handed to load_from_checkpoint below
latent_model_factory = instantiate(cfg.model.latent_model, _partial_=True)
model_factory        = instantiate(cfg.model,            _partial_=True)
loss_function        = instantiate(cfg.model.loss_function)

lit_model = PhysicsLitModule.load_from_checkpoint(
    "logs/train/runs/2025-06-25_14-53-46/waterdrop_physics/fzw60sdj/checkpoints/epoch=151-step=2432.ckpt",
    latent_model = latent_model_factory,
    model        = model_factory,
    loss_function= loss_function,
    strict       = True,
    map_location = "cuda"
)
lit_model.eval()
# NOTE: rebinds `dataset`, replacing the autoencoder datamodule from the AE cell.
dataset = instantiate(cfg.data)
dataset.setup(stage="physics")

dataset.val_dataset.rollout = True

In [None]:
# Rollout with ground-truth field input (use_gt_field=True) on the val split.
# NOTE(review): 5 return values and no `traj_idx`; the later sections use a
# different signature — this cell presumably targets an older API.
rollout_physics, GT_physics, GT_vel_fields_normalized_physics, MSE_fields_normalized_physics, MSE_Field_latent_physics = lit_model.rollout_physics(
    particle_dm=dataset,
    idx=[0],
    query_gt_pos=True,
    split="val",
    use_gt_field=True
)

## Plotting

In [None]:
# Shapes of both MSE series; the comparison plots below need equal lengths.
print(MSE_fields_normalized_AE.shape)
print(MSE_fields_normalized_physics.shape)

In [None]:
# The AE series is trimmed with [1:-1] so its length matches the physics series
# (plot_two_series raises on a length mismatch).
# NOTE(review): confirm the step alignment against the printed shapes above.
plot_two_series(
    MSE_fields_normalized_AE[1:-1],
    MSE_fields_normalized_physics,
    xlabel='Timestep',
    ylabel='MSE',
    title='MSE Comparison Between AE and Physics Models',
    log=True
)

In [None]:
# Physics-minus-AE MSE, floored at 1e-10 so non-positive differences remain
# drawable on the log axis.
plot_metric(
    np.clip(MSE_fields_normalized_physics - MSE_fields_normalized_AE[1:-1], 1e-10, None),
    title='MSE portion of approximator error',
    log=True
)

In [None]:
# Same difference on a linear axis, where negative values stay visible.
plot_metric(
    MSE_fields_normalized_physics - MSE_fields_normalized_AE[1:-1],
    title='Approximator portion of error',
    log=False
)

# All trajectories test set
## AE model

In [None]:
cfg = OmegaConf.load("logs/train/runs/2025-09-02_16-10-21/.hydra/config.yaml")
model = instantiate(cfg.model)
net = instantiate(cfg.model.model)
loss_function = instantiate(cfg.model.loss_function)
model = AutoencoderLitModule.load_from_checkpoint(
    checkpoint_path="logs/train/runs/2025-09-02_16-10-21/waterdrop/pu1byyxz/checkpoints/epoch=26-step=404595.ckpt",
    model=net,
    loss_function=loss_function
)
model.eval()
model.to("cuda")
dataset_AE = instantiate(cfg.data)
dataset_AE.setup(stage="autoencoder")
dataset_AE.shuffle = False
dataset_AE.batch_size = 1
dataset_AE.num_workers = 0
dataset_AE.pin_memory = False
dataset_AE.persistent_workers = False
dataset_AE.train_dataset.rollout = True
dataset_AE.val_dataset.rollout = True

In [None]:
rollout_AE, GT_AE, GT_vel_fields_normalized_AE, MSE_fields_normalized_AE, latents_AE = model.GT_encode_decode(
    particle_dm=dataset_AE,
    traj_idx=[0],
    idx=[[0]],
    query_gt_pos=True,
    split="test"
)

## Physics model

In [None]:
# Physics model evaluated on the test split: load its Hydra config.
cfg          = OmegaConf.load(
    "logs/train/runs/2025-09-04_07-13-04/.hydra/config.yaml"
)

# factories: `_partial_=True` makes instantiate return partially applied
# constructors, which are handed to load_from_checkpoint below
latent_model_factory = instantiate(cfg.model.latent_model, _partial_=True)
model_factory        = instantiate(cfg.model,            _partial_=True)
loss_function        = instantiate(cfg.model.loss_function)

lit_model = PhysicsLitModule.load_from_checkpoint(
    "logs/train/runs/2025-09-04_07-13-04/waterdrop_physics/w55pl0j7/checkpoints/epoch=15-step=239520.ckpt",
    latent_model = latent_model_factory,
    model        = model_factory,
    loss_function= loss_function,
    strict       = True,
    map_location = "cuda"
)
lit_model.eval()
# Separate datamodule for the physics stage (the AE one is `dataset_AE`).
dataset_PH = instantiate(cfg.data)
dataset_PH.setup(stage="physics")

In [None]:
# Rollout of test trajectory 0 from start index 0 with ground-truth field
# input (use_gt_field=True; labelled "one-step oracle" in the plots below).
rollout_physics, fields_PH, GT_physics, GT_vel_fields_normalized_physics, MSE_fields_normalized_physics, MSE_Field_latent_physics, latents_PH = lit_model.rollout_physics(
    particle_dm=dataset_PH,
    traj_idx=[0],
    idx=[[0]],
    query_gt_pos=True,
    split="test",
    use_gt_field=True
)

## Plotting

In [None]:
print(MSE_fields_normalized_AE.shape)
print(MSE_fields_normalized_physics.shape)

In [None]:
plot_two_series(
    MSE_fields_normalized_AE[1:],
    MSE_fields_normalized_physics,
    xlabel='Timestep',
    ylabel='MSE',
    title='MSE Comparison Between AE and Physics Models',
    log=True
)

In [None]:
plot_metric(
    MSE_fields_normalized_physics - MSE_fields_normalized_AE[1:],
    title='Approximator portion of error',
    log=True
)

In [None]:
plot_metric(
    MSE_fields_normalized_physics - MSE_fields_normalized_AE[1:],
    title='Approximator portion of error',
    log=False
)

## A look at multiple test trajectories

In [None]:
TRAJ_IDX = list(range(10))
IDX = [[0]]*10
rollout_AE, GT_AE, GT_vel_fields_normalized_AE, MSE_fields_normalized_AE, latents_AE = model.GT_encode_decode(
    particle_dm=dataset_AE,
    traj_idx=TRAJ_IDX,
    idx=IDX,
    query_gt_pos=True,
    split="test"
)
rollout_physics, fields_PH, GT_physics, GT_vel_fields_normalized_physics, MSE_fields_normalized_physics, MSE_Field_latent_physics, latents_PH = lit_model.rollout_physics(
    particle_dm=dataset_PH,
    traj_idx=TRAJ_IDX,
    idx=IDX,
    query_gt_pos=True,
    split="test",
    use_gt_field=True
)

In [None]:
for (MSE_field_normalized_AE, MSE_field_normalized_physics) in zip(MSE_fields_normalized_AE, MSE_fields_normalized_physics):
    plot_two_series(
        MSE_field_normalized_AE[1:],
        MSE_field_normalized_physics,
        labels=('Reconstruction MSE', 'One-step oracle MSE'),
        xlabel='Timestep',
        ylabel='MSE',
        title='MSE Comparison Between AE and Physics Models',
        log=True
    )
    plt.clf()

In [None]:
for (MSE_field_normalized_AE, MSE_field_normalized_physics) in zip(MSE_fields_normalized_AE, MSE_fields_normalized_physics):
    plot_metric(
        MSE_field_normalized_physics - MSE_field_normalized_AE[1:],
        title='Approximator portion of error',
        log=False
    )
    plt.clf()

## AE vs. AE-int

In [None]:
rollout_AE_int, GT_AE, GT_vel_fields_normalized_AE, MSE_fields_normalized_AE_int, latents_AE_int = model.GT_encode_decode(
    particle_dm=dataset_AE,
    traj_idx=TRAJ_IDX,
    idx=IDX,
    query_gt_pos=False,
    split="test"
)

In [None]:
print(MSE_fields_normalized_AE_int[0].shape)
print(MSE_fields_normalized_AE[0].shape)

In [None]:
plot_two_series(
        MSE_fields_normalized_AE_int[0],
        MSE_fields_normalized_AE[0],
        labels=('AE-int', 'AE'),
        xlabel='Timestep',
        ylabel='MSE on particle positions',
        title='Particle position error between AE and AE-int',
        log=True
    )

In [None]:
plot_two_series(
        MSE_fields_normalized_AE_int[0][:200],
        MSE_fields_normalized_AE[0][:200],
        labels=('AE-int', 'AE'),
        xlabel='Timestep',
        ylabel='MSE on particle positions',
        title='Particle position error between AE and AE-int',
        log=True
    )

In [None]:
plot_two_series(
        MSE_fields_normalized_AE_int[0][:100],
        MSE_fields_normalized_AE[0][:100],
        labels=('AE-int', 'AE'),
        xlabel='Timestep',
        ylabel='MSE on particle positions',
        title='Particle position error between AE and AE-int',
        log=True
    )

In [None]:
ani = animate(
    rollout=rollout_AE_int[0],
    ground_truth=GT_AE[0],
)

In [None]:
ani = ani.to_jshtml()

In [None]:
HTML(ani)

# Aggregated statistics

In [None]:
TRAJ_IDX = list(range(dataset_AE.get_dataset(split="test").n_traj)) # list of all test trajectories
IDX = [[0] for _ in range(len(TRAJ_IDX))] # starting indeces (we want whole trajectories)

In [None]:
rollout_AE, GT_AE, GT_vel_fields_normalized_AE, MSE_fields_normalized_AE, latents_AE = model.GT_encode_decode(
    particle_dm=dataset_AE,
    traj_idx=TRAJ_IDX,
    idx=IDX,
    query_gt_pos=True,
    split="test"
)
rollout_physics, fields_PH, GT_physics, GT_vel_fields_normalized_physics, MSE_fields_normalized_physics, MSE_Field_latent_physics, propagated_latents = lit_model.rollout_physics(
    particle_dm=dataset_PH,
    traj_idx=TRAJ_IDX,
    idx=IDX,
    query_gt_pos=True,
    split="test",
    use_gt_field=True
)

In [None]:
# Create diff
Diff_MSE = [MSE_field_normalized_physics - MSE_field_normalized_AE[1:] for (MSE_field_normalized_AE, MSE_field_normalized_physics) in zip(MSE_fields_normalized_AE, MSE_fields_normalized_physics)]

In [None]:
# save the lists of tensors
torch.save(MSE_fields_normalized_AE, "D:/Projects/Master/UPT/data/GT_rollout_data_2_1_1_prac/MSE_fields_normalized_AE.pt")
torch.save(MSE_fields_normalized_physics, "D:/Projects/Master/UPT/data/GT_rollout_data_2_1_1_prac/MSE_fields_normalized_physics.pt")
torch.save(MSE_Field_latent_physics, "D:/Projects/Master/UPT/data/GT_rollout_data_2_1_1_prac/MSE_Field_latent_physics.pt")
torch.save(Diff_MSE, "D:/Projects/Master/UPT/data/GT_rollout_data_2_1_1_prac/Diff_MSE.pt")
torch.save(latents_AE, "D:/Projects/Master/UPT/data/GT_rollout_data_2_1_1_prac/latents_AE.pt")
torch.save(propagated_latents, "D:/Projects/Master/UPT/data/GT_rollout_data_2_1_1_prac/propagated_latents.pt")

In [None]:
# load the lists of tensors
MSE_fields_normalized_AE = torch.load("MSE_fields_normalized_AE.pt")
MSE_Field_latent_physics = torch.load("MSE_Field_latent_physics.pt")
MSE_fields_normalized_physics = torch.load("MSE_fields_normalized_physics.pt")
Diff_MSE = torch.load("Diff_MSE.pt")
latents_AE = torch.load("latents_AE.pt")
propagated_latents = torch.load("propagated_latents.pt")

In [None]:
plot_two_series(
        MSE_fields_normalized_AE[0][1:],
        MSE_fields_normalized_physics[0],
        labels=('Reconstruction MSE', 'One-step oracle MSE'),
        xlabel='Timestep',
        ylabel='MSE',
        title='MSE Comparison Between AE and Physics Models',
        log=True
    )

In [None]:
def rollout_heatmap(rollout_errors, *, drop_first=False, cmap="viridis",
                    vmin=None, vmax=None, norm=None,
                    ax=None, cbar=True, cbar_kw=None, title=None):
    """
    Visualise rollout errors as a heat-map (rows = rollouts, columns = timesteps).

    Parameters
    ----------
    rollout_errors : list[Tensor | array_like] or stacked Tensor / ndarray
        One element per rollout, all of the same shape. Each element is 1-D
        (length = #timesteps) or carries trailing singleton dimensions
        (e.g. ``(T, 1)``), which are squeezed away.
    drop_first : bool, optional
        If True, remove the very first error of every rollout before plotting
        (useful when step 0 is the auto-encoder reconstruction loss, etc.).
        The timestep axis then starts at 1.
    cmap : str or Colormap, optional
        Matplotlib colormap to use (default "viridis").
    vmin, vmax : float, optional
        Colour-scale limits forwarded to ``imshow``; do not combine with ``norm``.
    norm : matplotlib.colors.Normalize, optional
        Colour normalisation forwarded to ``imshow`` (e.g. a log norm).
    ax : matplotlib.axes.Axes, optional
        Draw on this axes if provided, else a new figure is created.
    cbar : bool, optional
        Attach a color-bar showing the error scale.
    cbar_kw : dict or None
        Extra kwargs forwarded to ``Figure.colorbar``.
    title : str or None
        Optional plot title.

    Returns
    -------
    matplotlib.axes.Axes
        The axes the heat-map was drawn on.

    Raises
    ------
    ValueError
        If the input is empty, or if a timestep holds more than one value
        (such data cannot be drawn as a 2-D image).
    """
    # len() instead of truthiness: `not tensor` is ambiguous for stacked tensors.
    if len(rollout_errors) == 0:
        raise ValueError("Input list is empty.")

    # Accept tensors (possibly on GPU / requiring grad) as well as array-likes.
    arr = np.stack([
        err.detach().cpu().numpy() if torch.is_tensor(err) else np.asarray(err)
        for err in rollout_errors
    ])
    if arr.ndim < 2:
        raise ValueError("Each rollout must provide at least one error per timestep (1-D).")

    # First axis = rollout, second = timestep; extra axes must be singletons.
    n_rollouts, n_timesteps = arr.shape[:2]
    arr = arr.reshape(n_rollouts, n_timesteps, -1)
    if arr.shape[-1] != 1:
        raise ValueError(
            f"Expected one error value per timestep, got {arr.shape[-1]}; "
            "reduce the extra dimensions before plotting."
        )
    arr = arr[..., 0]

    # Optionally drop leading timestep
    first_step = 0
    if drop_first:
        arr = arr[:, 1:]
        first_step = 1
    n_timesteps = arr.shape[1]

    # --- plotting ---
    if ax is None:
        # Grow with the data, but cap the size so long rollouts (hundreds of
        # timesteps) do not produce enormous figures.
        _, ax = plt.subplots(figsize=(min(0.4 * n_timesteps + 2, 20),
                                      min(0.25 * n_rollouts + 2, 12)))
    fig = ax.figure

    # `extent` centres pixel (r, t) on the true rollout/timestep index, so the
    # automatic tick labels stay correct after drop_first.
    im = ax.imshow(arr,
                   aspect="auto", origin="upper",
                   cmap=cmap,
                   interpolation="nearest",
                   vmin=vmin, vmax=vmax, norm=norm,
                   extent=(first_step - 0.5, first_step + n_timesteps - 0.5,
                           n_rollouts - 0.5, -0.5))

    # Integer ticks chosen by matplotlib instead of one tick per column/row,
    # which is unreadable for long rollouts.
    from matplotlib.ticker import MaxNLocator
    ax.xaxis.set_major_locator(MaxNLocator(integer=True))
    ax.yaxis.set_major_locator(MaxNLocator(integer=True))

    ax.set_xlabel("Timestep")
    ax.set_ylabel("Rollout index")
    ax.set_title(title or "Rollout error heat-map")

    if cbar:
        # Use ax's own figure, not pyplot's "current" figure.
        fig.colorbar(im, ax=ax, **(cbar_kw or {})).set_label("Error (MSE)")

    fig.tight_layout()
    return ax

In [None]:
# The three error collections share one colour scale so the heat-maps are comparable.
datasets = [MSE_fields_normalized_AE, MSE_fields_normalized_physics, Diff_MSE]

# Global min / max over all three collections (each a list of tensors or arrays).
# Diff_MSE is physics MSE minus AE MSE, so global_min may be negative.
global_min = min(float(torch.min(torch.stack(ds)) if torch.is_tensor(ds[0]) else
                       np.min(np.stack(ds)))
                 for ds in datasets)
global_max = max(float(torch.max(torch.stack(ds)) if torch.is_tensor(ds[0]) else
                       np.max(np.stack(ds)))
                 for ds in datasets)

In [None]:
rollout_heatmap(MSE_fields_normalized_AE, cmap="viridis", vmin=global_min, vmax=global_max)
plt.show()

In [None]:
rollout_heatmap(MSE_fields_normalized_physics, cmap="viridis", vmin=global_min, vmax=global_max)
plt.show()

In [None]:
rollout_heatmap(Diff_MSE, cmap="viridis", vmin=global_min, vmax=global_max)
plt.show()

## Latents

In [None]:
len(latents_AE)

In [None]:
latents_AE_diffs = []
for traj_latent in latents_AE:
    diff_latents = (traj_latent[1:].squeeze() - traj_latent[:-1].squeeze())**2
    latents_AE_diffs.append(diff_latents.mean(axis=(1, 2)))

In [None]:
latents_AE_diffs[0].shape

In [None]:
from scipy.stats import pearsonr

In [None]:
x = torch.stack(MSE_Field_latent_physics)
y = torch.stack(latents_AE_diffs)
print(x.shape, y.shape)

In [None]:
x = einops.rearrange(x, 't m -> (t m)').numpy(force=True)
y = einops.rearrange(y, 't m -> (t m)').numpy(force=True)

In [None]:
r, p = pearsonr(x, y)

In [None]:
coeffs = np.polyfit(x, y, deg=1)
slope, intercept = coeffs

# regression line values
x_line = np.linspace(x.min(), x.max(), 100)
y_line = slope * x_line + intercept

In [None]:
plt.figure(figsize=(8, 5), dpi=150)
plt.scatter(x, y, alpha=0.7)
plt.xlabel("One-step oracle error")
plt.ylabel("Squared difference of latents")
plt.tight_layout()
plt.show()

In [None]:
print(r, p)

### Plotvorschlag von Andreas (Steigung 1 perfekt)

In [None]:
# Squared error between the physics model's propagated latents and the AE
# latents shifted by one step.
# NOTE(review): despite the name, this is (z_hat_{i+1} - z_{i+1})^2, assuming
# propagated_latents[k] predicts step k+1 — confirm.
z_i_plus_1_minus_z_i = (torch.stack(propagated_latents).cpu() - torch.stack(latents_AE)[:, 1:].cpu())**2

In [None]:
# Average over axes 2-4 -> one value per (trajectory, step); assumes the stacked
# latents are 5-D.
z_i_plus_1_minus_z_i_mse = z_i_plus_1_minus_z_i.mean(axis=(2, 3, 4))
z_i_plus_1_minus_z_i_mse.shape

In [None]:
x = einops.rearrange(z_i_plus_1_minus_z_i_mse, 't m -> (t m)').numpy(force=True)
y = einops.rearrange(latents_AE_diffs, 't m -> (t m)').numpy(force=True)

In [None]:
r, p = pearsonr(x, y)

In [None]:
coeffs = np.polyfit(x, y, deg=1)
slope, intercept = coeffs

# regression line values
x_line = np.linspace(x.min(), x.max(), 100)
y_line = slope * x_line + intercept

In [None]:
# NOTE(review): the title's second term reads (z_hat_{i+1} - z_i)^2, but x is
# computed against the shifted AE latents (z_{i+1}) above — confirm the notation.
plt.figure(figsize=(8, 5), dpi=150)
plt.scatter(x, y, alpha=0.7)
plt.plot(x_line, y_line, linestyle="--", label=f"least squares fit: y = {slope:.3f}x + {intercept:.3f}", color='red')
plt.title(r"$(\boldsymbol{z}_{i+1}-\boldsymbol{z}_{i})^2$ vs. $(\hat{\boldsymbol{z}}_{i+1}-\boldsymbol{z}_i)^2$" + f", pearson r = {r:.3f} (p={p:.2g})")
plt.xlabel("Latent MSE")
plt.ylabel("Mean squared difference of latents")
plt.legend()
plt.tight_layout()
plt.show()

Aber ist dieser Plot nicht wenig aussagekräftig? Der Sinn der Korrelation sollte ja sein, einen Zusammenhang zwischen der Latent-Dynamik und der Approximator-Performance zu finden. Was soll dieser Plot also aussagen?

In [None]:
GT_vel_fields_normalized_AE[0].shape

In [None]:
GT_vel_fields_normalized_AE_cleaned = [elem[:, :, 1].abs().mean(1) for elem in GT_vel_fields_normalized_AE]

In [None]:
GT_vel_fields_normalized_AE_cleaned = torch.stack(GT_vel_fields_normalized_AE_cleaned)

In [None]:
GT_vel_fields_normalized_AE_cleaned.shape

In [None]:
GT_vel_fields_normalized_AE_mean_x = (GT_vel_fields_normalized_AE_cleaned[..., 0])
print(GT_vel_fields_normalized_AE_mean_x.shape)

In [None]:
GT_vel_fields_normalized_AE_mean_y = (GT_vel_fields_normalized_AE_cleaned[..., 1])
print(GT_vel_fields_normalized_AE_mean_y.shape)

In [None]:
MSE_field_normalized_AE_stacked = torch.stack(MSE_fields_normalized_AE, dim=0)
print(MSE_field_normalized_AE_stacked.shape)

In [None]:
MSE_field_normalized_AE_stacked_numpy = MSE_field_normalized_AE_stacked.cpu().numpy()
print(MSE_field_normalized_AE_stacked_numpy.shape)

In [None]:

GT_vel_fields_normalized_AE_mean_x_numpy = GT_vel_fields_normalized_AE_mean_x.cpu().numpy()
GT_vel_fields_normalized_AE_mean_y_numpy = GT_vel_fields_normalized_AE_mean_y.cpu().numpy()

In [None]:
x = MSE_field_normalized_AE_stacked_numpy.flatten()
y = GT_vel_fields_normalized_AE_mean_x_numpy.flatten()

In [None]:
r, p = pearsonr(x, y)

In [None]:
coeffs = np.polyfit(x, y, deg=1)
slope, intercept = coeffs

# regression line values
x_line = np.linspace(x.min(), x.max(), 100)
y_line = slope * x_line + intercept

In [None]:
plt.figure(figsize=(8, 5), dpi=150)
plt.scatter(x, y, alpha=0.7)
plt.plot(x_line, y_line, linestyle="--", label=f"least squares fit: y = {slope:.3f}x + {intercept:.3f}", color='red')
plt.title(f"Mean field strength in x-direction vs. field MSE\nPearson r = {r:.3f} (p={p:.2g})")
plt.xlabel("Field MSE")
plt.ylabel("Mean field strength in x-direction")
plt.legend()
plt.tight_layout()
plt.show()

In [None]:
x = MSE_field_normalized_AE_stacked_numpy.flatten()
y = GT_vel_fields_normalized_AE_mean_y_numpy.flatten()

In [None]:
r, p = pearsonr(x, y)

In [None]:
coeffs = np.polyfit(x, y, deg=1)
slope, intercept = coeffs

# regression line values
x_line = np.linspace(x.min(), x.max(), 100)
y_line = slope * x_line + intercept

In [None]:
plt.figure(figsize=(8, 5), dpi=150)
plt.scatter(x, y, alpha=0.7)
plt.plot(x_line, y_line, linestyle="--", label=f"least squares fit: y = {slope:.3f}x + {intercept:.3f}", color='red')
plt.title(f"Mean field strength in y-direction vs. field MSE\nPearson r = {r:.3f} (p={p:.2g})")
plt.xlabel("Field MSE")
plt.ylabel("Mean field strength in y-direction")
plt.legend()
plt.tight_layout()
plt.show()

In [None]:
MSE_field_normalized_AE_stacked_numpy.shape

In [None]:
MSE_fields_normalized_AE_mean = MSE_field_normalized_AE_stacked_numpy.mean(axis=0)
MSE_fields_normalized_AE_std = MSE_field_normalized_AE_stacked_numpy.std(axis=0)

In [None]:
plt.figure(figsize=(10, 5))
plt.bar(range(len(MSE_fields_normalized_AE_mean)), MSE_fields_normalized_AE_mean)
plt.xlabel('Timestep')
plt.ylabel('Mean MSE')
plt.title('Mean MSE per Timestep')

In [None]:
plt.figure(figsize=(10, 5))
plt.bar(range(len(MSE_fields_normalized_AE_std)), MSE_fields_normalized_AE_std)
plt.xlabel('Timestep')
plt.ylabel('MSE Variance')
plt.title('MSE Variance per Timestep')

In [None]:
MSE_fields_normalized_AE_mean_var_ratio = MSE_fields_normalized_AE_mean/MSE_fields_normalized_AE_std

In [None]:
plt.figure(figsize=(10, 5))
plt.bar(range(len(MSE_fields_normalized_AE_mean_var_ratio)), MSE_fields_normalized_AE_mean_var_ratio)
plt.xlabel('Timestep')
plt.ylabel('MSE Variance')
plt.title('MSE Variance per Timestep')

In [None]:
def plot_mean_with_variance(mean, var, x=None, xlabel="Index", ylabel="Value", lower_bound=1e-14, log=True):
    """
    Plot a mean curve with a shaded ``mean ± var`` envelope.

    Despite its name, ``var`` is used directly as the half-width of the
    envelope. Every call in this notebook passes a standard deviation, which is
    why the legend reads "Mean ± Std".

    Parameters
    ----------
    mean : array-like
        Sequence of mean values.
    var : array-like or scalar
        Half-width of the envelope (in practice a standard deviation); must
        broadcast to the shape of ``mean``.
    x : array-like or None, optional
        X-coordinates. If None, uses 0…N-1.
    xlabel, ylabel : str, optional
        Axis labels.
    lower_bound : float or None, optional
        Floor for the lower envelope (default 1e-14) so it stays positive on a
        log axis. Pass None to disable the floor, e.g. for signed quantities
        such as MSE differences on a linear axis.
    log : bool, optional
        Use a logarithmic y-axis (default True).

    Raises
    ------
    ValueError
        If ``var`` cannot be broadcast to the shape of ``mean``.
    """
    mean = np.asarray(mean)
    # Validate the spread before a figure is created.
    spread = np.broadcast_to(np.asarray(var), mean.shape)

    if x is None:
        x = np.arange(len(mean))

    lower = mean - spread
    if lower_bound is not None:
        lower = np.maximum(lower, lower_bound)
    upper = mean + spread

    fig, ax = plt.subplots(figsize=(10, 4), dpi=300)
    ax.plot(x, mean, label="Mean", color='blue', linewidth=0.5)
    ax.fill_between(x, lower, upper, alpha=0.75, label="Mean ± Std")
    ax.plot(x, lower, color='blue', linestyle='--', linewidth=0.2)
    ax.plot(x, upper, color='blue', linestyle='--', linewidth=0.2)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.xaxis.grid(True)
    ax.legend()  # the labels above were previously never displayed
    if log:
        ax.set_yscale("log")
    plt.show()

In [None]:
MSE_Field_latent_physics_stacked = torch.stack(MSE_Field_latent_physics, dim=0)
MSE_fields_normalized_physics = torch.stack(MSE_fields_normalized_physics, dim=0)

In [None]:
MSE_Field_latent_physics_stacked_mean = MSE_Field_latent_physics_stacked.mean(axis=0).cpu().numpy()
MSE_Field_latent_physics_stacked_std = MSE_Field_latent_physics_stacked.std(axis=0).cpu().numpy()
MSE_fields_normalized_physics_mean = MSE_fields_normalized_physics.mean(axis=0).cpu().numpy()
MSE_fields_normalized_physics_std = MSE_fields_normalized_physics.std(axis=0).cpu().numpy()

In [None]:
plot_mean_with_variance(
    MSE_Field_latent_physics_stacked_mean,
    MSE_Field_latent_physics_stacked_std,
    xlabel="Rollout Timestep",
    ylabel="One-step oracle latent error",
    lower_bound=1e-3, # got a weird numeric outlier around timestep 170
)

In [None]:
plot_mean_with_variance(
    MSE_Field_latent_physics_stacked_mean,
    MSE_Field_latent_physics_stacked_std,
    xlabel="Rollout Timestep",
    ylabel="One-step oracle latent error",
)

In [None]:
MSE_fields_normalized_AE = torch.stack(MSE_fields_normalized_AE, dim=0)

In [None]:
MSE_fields_normalized_AE.shape

In [None]:
MSE_fields_normalized_AE_mean = MSE_fields_normalized_AE.mean(axis=0).cpu().numpy()
MSE_fields_normalized_AE_std = MSE_fields_normalized_AE.std(axis=0).cpu().numpy()

In [None]:
plot_mean_with_variance( # after retraining... looks almost identical
    MSE_fields_normalized_AE_mean,
    MSE_fields_normalized_AE_std,
    xlabel="Rollout Timestep",
    ylabel="Autoencoding error",
    lower_bound=1e-6,
    log=False
)

In [None]:
plot_mean_with_variance(
    MSE_fields_normalized_AE_mean,
    MSE_fields_normalized_AE_std,
    xlabel="Rollout Timestep",
    ylabel="Autoencoding error",
    lower_bound=1e-6,
    log=False
)

In [None]:
plot_mean_with_variance(
    MSE_fields_normalized_AE_mean,
    MSE_fields_normalized_AE_std,
    xlabel="Rollout Timestep",
    ylabel="Autoencoding error",
)

In [None]:
plot_mean_with_variance(
    MSE_fields_normalized_physics_mean,
    MSE_fields_normalized_physics_std,
    xlabel="Rollout timestep",
    ylabel="One-step oracle error",
    lower_bound=1e-6,
    log=False
)
plot_mean_with_variance(
    MSE_fields_normalized_AE_mean,
    MSE_fields_normalized_AE_std,
    xlabel="Rollout timestep",
    ylabel="Autoencoding error",
    lower_bound=1e-6,
    log=False
)

In [None]:
plot_mean_with_variance(
    MSE_fields_normalized_physics_mean,
    MSE_fields_normalized_physics_std,
    xlabel="Rollout timestep",
    ylabel="One-step oracle error",
)
plot_mean_with_variance(
    MSE_fields_normalized_AE_mean,
    MSE_fields_normalized_AE_std,
    xlabel="Rollout timestep",
    ylabel="Autoencoding error",
)

In [None]:
MSE_fields_normalized_physics.shape

In [None]:
diff = MSE_fields_normalized_physics - MSE_fields_normalized_AE[:, 1:]

In [None]:
diff_mean = diff.mean(axis=0).cpu().numpy()
diff_std = diff.std(axis=0).cpu().numpy()

In [None]:
plot_mean_with_variance(
    diff_mean,
    diff_std,
    xlabel="Rollout timestep",
    ylabel="Approximator error",
)

## Approximation quality

In [None]:
# Per-trajectory approximator error: physics MSE minus AE MSE with its first
# step dropped.
# NOTE: the loop variables MSE_field_normalized_AE / MSE_field_normalized_physics
# remain bound at notebook scope after this cell.
approximator_errors = []
for (MSE_field_normalized_AE, MSE_field_normalized_physics) in zip(MSE_fields_normalized_AE, MSE_fields_normalized_physics):
    approximator_errors.append(MSE_field_normalized_physics - MSE_field_normalized_AE[1:])

In [None]:
# NOTE: rebinds the list to a stacked tensor; re-running this cell on its own
# would fail, because torch.stack expects a sequence of tensors.
approximator_errors = torch.stack(approximator_errors, dim=0)

In [None]:
# Single scalar: mean over all trajectories and timesteps.
approximator_errors.mean()

# Rollout analysis

In [None]:
TRAJ_IDX = list(range(dataset_AE.get_dataset(split="test").n_traj)) # list of all test trajectories
START_IDXS_PER_TRAJ = 100
FIXED_LENGTH = 80 # length for each rollout
n_per_traj = dataset_PH.get_dataset(split="test").n_per_traj

In [None]:
IDX = []
for i in range(len(TRAJ_IDX)):
    traj_idx_list = []
    for j in range(START_IDXS_PER_TRAJ):
        # sample from 0 to dataset_PH.get_dataset(split="test").n_per_traj - FIXED_LENGTH
        idx = np.random.randint(0, n_per_traj - FIXED_LENGTH)
        traj_idx_list.append(idx)
    IDX.append(traj_idx_list)

In [None]:
rollout_physics, GT_physics, GT_vel_fields_normalized_physics, MSE_fields_normalized_physics, MSE_Field_latent_physics, propagated_latents = lit_model.rollout_physics(
    particle_dm=dataset_PH,
    traj_idx=TRAJ_IDX,
    idx=IDX,
    query_gt_pos=False,
    split="test",
    use_gt_field=False,
    fixed_length=FIXED_LENGTH
)

In [None]:
torch.save(MSE_fields_normalized_physics, "D:/Projects/Master/UPT/data/rollout_data/MSE_fields_normalized_physics.pt")
torch.save(MSE_Field_latent_physics, "D:/Projects/Master/UPT/data/rollout_data/MSE_Field_latent_physics.pt")
torch.save(propagated_latents, "D:/Projects/Master/UPT/data/rollout_data/propagated_latents.pt")
torch.save(rollout_physics, "D:/Projects/Master/UPT/data/rollout_data/rollout_physics.pt")

In [None]:
rollout_AE, GT_AE, GT_vel_fields_normalized_AE, MSE_fields_normalized_AE, latents_AE = model.GT_encode_decode(
    particle_dm=dataset_AE,
    traj_idx=TRAJ_IDX,
    idx=IDX,
    query_gt_pos=True,
    split="test",
    fixed_length=FIXED_LENGTH
)

In [None]:
torch.save(latents_AE, "D:/Projects/Master/UPT/data/rollout_data/latents_AE.pt")

## Rollout 10 timesteps

In [None]:
TRAJ_IDX = list(range(dataset_AE.get_dataset(split="test").n_traj)) # list of all test trajectories
FIXED_LENGTH = 10 # length for each rollout
n_per_traj_PH = dataset_PH.get_dataset(split="test").n_per_traj
n_per_traj_AE = dataset_AE.get_dataset(split="test").n_per_traj
IDX_PH = [list(range(n_per_traj_PH - FIXED_LENGTH)) for _ in range(len(TRAJ_IDX))]  # starting indeces (all start idxes which allow length of FIXED_LENGTH)
IDX_AE = [list(range(n_per_traj_AE - FIXED_LENGTH)) for _ in range(len(TRAJ_IDX))]  # starting indeces (all start idxes which allow length of FIXED_LENGTH)

In [None]:
rollout_physics, fields, GT_physics, GT_vel_fields_normalized_physics, MSE_fields_normalized_physics, MSE_Field_latent_physics, propagated_latents = lit_model.rollout_physics(
    particle_dm=dataset_PH,
    traj_idx=TRAJ_IDX,
    idx=IDX_PH,
    query_gt_pos=False,
    split="test",
    use_gt_field=False,
    fixed_length=FIXED_LENGTH
)

In [None]:
torch.save(MSE_field_normalized_physics, "D:/Projects/Master/UPT/data/10_step_rollout_data_2_1_1_prac/MSE_field_normalized_physics.pt")

In [None]:
# MSE_fields_normalized_physics = torch.load("D:/Projects/Master/UPT/data/rollout_data_every_start_10_timesteps/MSE_fields_normalized_physics.pt")

In [None]:
MSE_field_normalized_physics = torch.stack(MSE_fields_normalized_physics, dim=0)

In [None]:
MSE_field_normalized_physics = einops.rearrange(MSE_field_normalized_physics, '(n T) t -> n t T', n=len(TRAJ_IDX))

In [None]:
MSE_field_normalized_physics.shape

In [None]:
MSE_field_normalized_physics_mean = MSE_field_normalized_physics.mean(axis=0).cpu().numpy()
MSE_field_normalized_physics_std = MSE_field_normalized_physics.std(axis=0).cpu().numpy()

In [None]:
plot_mean_with_variance(
    MSE_field_normalized_physics_mean[0],
    MSE_field_normalized_physics_std[0],
    xlabel="Rollout Timestep",
    ylabel="0th rollout step error",
    log=False
)

In [None]:
plot_mean_with_variance(
    MSE_field_normalized_physics_mean[2],
    MSE_field_normalized_physics_std[2],
    xlabel="Rollout Timestep",
    ylabel="2nd rollout step error",
    log=False
)

In [None]:
plot_mean_with_variance(
    MSE_field_normalized_physics_mean[4],
    MSE_field_normalized_physics_std[4],
    xlabel="Rollout Timestep",
    ylabel="5th rollout step error",
    log=False
)

## Rollout 70 timesteps (FIXED_LENGTH = 70), every starting position

In [None]:
TRAJ_IDX = list(range(dataset_AE.get_dataset(split="test").n_traj)) # list of all test trajectories
FIXED_LENGTH = 70 # length for each rollout
n_per_traj_PH = dataset_PH.get_dataset(split="test").n_per_traj
n_per_traj_AE = dataset_AE.get_dataset(split="test").n_per_traj
# Start indices are computed separately for the physics and AE datamodules.
IDX_PH = [list(range(n_per_traj_PH - FIXED_LENGTH)) for _ in range(len(TRAJ_IDX))]  # starting indices (all start idxes which allow length of FIXED_LENGTH)
IDX_AE = [list(range(n_per_traj_AE - FIXED_LENGTH)) for _ in range(len(TRAJ_IDX))]  # starting indices (all start idxes which allow length of FIXED_LENGTH)

In [None]:
# Free rollout (use_gt_field=False) of FIXED_LENGTH steps from every start index.
# NOTE(review): unpacks 6 values, while the 10-step section unpacks 7 from the
# same method (including `fields`); one of the two is presumably stale — confirm.
rollout_physics, GT_physics, GT_vel_fields_normalized_physics, MSE_fields_normalized_physics, MSE_Field_latent_physics, propagated_latents = lit_model.rollout_physics(
    particle_dm=dataset_PH,
    traj_idx=TRAJ_IDX,
    idx=IDX_PH,
    query_gt_pos=False,
    split="test",
    use_gt_field=False,
    fixed_length=FIXED_LENGTH
)

In [None]:
# torch.save(GT_physics, "D:/Projects/Master/UPT/data/rollout_data_every_start_100_timesteps/GT_physics.pt")
# NOTE: replaces the rollout_physics computed above with the cached copy.
rollout_physics = torch.load("D:/Projects/Master/UPT/data/rollout_data_every_start_100_timesteps/rollout_physics.pt")

In [None]:
# NOTE: replaces the GT_physics computed above with the cached copy.
GT_physics = torch.load("D:/Projects/Master/UPT/data/rollout_data_every_start_100_timesteps/GT_physics.pt")

In [None]:
# NOTE(review): if the load cell above ran, this re-saves the loaded rollout_physics.
torch.save(MSE_fields_normalized_physics, "D:/Projects/Master/UPT/data/rollout_data_every_start_100_timesteps/MSE_fields_normalized_physics.pt")
torch.save(MSE_Field_latent_physics, "D:/Projects/Master/UPT/data/rollout_data_every_start_100_timesteps/MSE_Field_latent_physics.pt")
# torch.save(propagated_latents, "D:/Projects/Master/UPT/data/rollout_data_every_start_10_timesteps/propagated_latents.pt")
torch.save(rollout_physics, "D:/Projects/Master/UPT/data/rollout_data_every_start_100_timesteps/rollout_physics.pt")

In [None]:
# NOTE: replaces the MSE_Field_latent_physics computed above with the cached copy.
MSE_Field_latent_physics = torch.load("D:/Projects/Master/UPT/data/rollout_data_every_start_100_timesteps/MSE_Field_latent_physics.pt")

In [None]:
rollout_AE, GT_AE, GT_vel_fields_normalized_AE, MSE_fields_normalized_AE, latents_AE = model.GT_encode_decode(
    particle_dm=dataset_AE,
    traj_idx=TRAJ_IDX,
    idx=IDX_AE,
    query_gt_pos=False,
    split="test",
    fixed_length=FIXED_LENGTH
)

In [None]:
# torch.save(latents_AE, "D:/Projects/Master/UPT/data/rollout_data_every_start_100_timesteps/latents_AE.pt")
# torch.save(MSE_fields_normalized_AE, "D:/Projects/Master/UPT/data/rollout_data_every_start_100_timesteps/MSE_fields_normalized_AE.pt")
# torch.save(rollout_AE, "D:/Projects/Master/UPT/data/rollout_data_every_start_100_timesteps/rollout_AE_int_70_steps.pt")
# torch.save(GT_AE, "D:/Projects/Master/UPT/data/rollout_data_every_start_100_timesteps/GT_AE_70_steps.pt")

In [None]:
# NOTE: replaces the rollout_AE / GT_AE computed above with cached 70-step copies.
rollout_AE = torch.load("D:/Projects/Master/UPT/data/rollout_data_every_start_100_timesteps/rollout_AE_int_70_steps.pt")
GT_AE = torch.load("D:/Projects/Master/UPT/data/rollout_data_every_start_100_timesteps/GT_AE_70_steps.pt")

In [None]:
# NOTE(review): these load from the `..._10_timesteps` directory although this
# section uses FIXED_LENGTH = 70 — confirm this is intended.
propagated_latents = torch.load("D:/Projects/Master/UPT/data/rollout_data_every_start_10_timesteps/propagated_latents.pt")
latents_AE = torch.load("D:/Projects/Master/UPT/data/rollout_data_every_start_10_timesteps/latents_AE.pt")

### Mean and std per rollout step

In [None]:
import numpy as np
import matplotlib.pyplot as plt

def plot_means_variances(
    means: np.ndarray,
    stds: np.ndarray,
    xlabel: str = "Rollout Timestep",
    ylabel: str = "",
    ax: plt.Axes | None = None,
    *,
    color: str = "red",
    upper_bond: float | None = None,
    lower_bound: float | None = None,
    line_kw: dict | None = None,
    fill_kw: dict | None = None,
    lower_value: float | None = None,
):
    """
    Plot a sequence of per-step means with a ±1 standard-deviation envelope.

    Parameters
    ----------
    means : (n,) array_like
        Mean value per step; plotted against x = 1..n.
    stds : (n,) array_like
        Standard deviation per step (same length as ``means``).
    xlabel : str, default "Rollout Timestep"
        X-axis label.
    ylabel : str, default ""
        Y-axis label.
    ax : matplotlib.axes.Axes, optional
        Axis to draw on. If None, a new figure+axis is created.
    color : str, default 'red'
        Base color for the mean line, envelope edges and fill.
    upper_bond : float, optional
        Clip the upper envelope (mean + std) to at most this value
        (the parameter name is kept for backward compatibility; read "upper bound").
    lower_bound : float, optional
        Clip the lower envelope (mean - std) to at least this value.
    line_kw : dict, optional
        Extra kwargs for the mean line. They override the defaults
        (color=color, lw=2.5, zorder=3, label="mean").
    fill_kw : dict, optional
        Extra kwargs for ``fill_between``. They override the defaults
        (color=color, alpha=0.18, linewidth=0, zorder=1).
    lower_value : float, optional
        If given, fix the y-axis bottom to this value. The top becomes the larger
        of the current top and the largest finite envelope value.

    Returns
    -------
    matplotlib.axes.Axes
        The axis containing the plot.
    """
    means = np.asarray(means)
    stds = np.asarray(stds)
    x = np.arange(1, means.shape[0] + 1)

    # Defaults come first so caller-supplied kwargs (including "color") win
    # instead of being silently overwritten or raising a duplicate-keyword error.
    line_kw = {"color": color, "lw": 2.5, "zorder": 3, "label": "mean", **(line_kw or {})}
    fill_kw = {"color": color, "alpha": 0.18, "linewidth": 0, "zorder": 1, **(fill_kw or {})}

    if ax is None:
        _, ax = plt.subplots(figsize=(6, 4), dpi=200)

    # central mean line
    ax.plot(x, means, **line_kw)

    # ±1 SD envelope, optionally clipped to the metric's valid range
    upper = means + stds
    lower = means - stds
    if lower_bound is not None:
        lower = np.clip(lower, a_min=lower_bound, a_max=None)
    if upper_bond is not None:
        upper = np.clip(upper, a_min=None, a_max=upper_bond)

    ax.plot(x, upper, color=color, alpha=0.55, lw=1.2, zorder=2)
    ax.plot(x, lower, color=color, alpha=0.55, lw=1.2, zorder=2)
    ax.fill_between(x, lower, upper, **fill_kw)

    # cosmetics
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    ax.margins(x=0.02)
    ax.grid(True, which="both", ls=":", lw=0.5, zorder=0)

    if lower_value is not None:
        # Keep the current top unless the envelope exceeds it. Ignore NaN/inf and
        # empty input so set_ylim never receives a non-finite limit.
        _, ymax = ax.get_ylim()
        finite_upper = upper[np.isfinite(upper)]
        top = max(ymax, float(finite_upper.max())) if finite_upper.size else ymax
        ax.set_ylim(bottom=lower_value, top=top)
    return ax


In [None]:
# Stack the loaded list of tensors along a new leading dimension.
# Guarded so re-running the cell is a no-op instead of an error
# (torch.stack rejects a bare Tensor).
if isinstance(propagated_latents, (list, tuple)):
    propagated_latents = torch.stack(propagated_latents, dim=0)

In [None]:
# Stack the loaded list of tensors along a new leading dimension.
# Guarded so re-running the cell is a no-op instead of an error.
if isinstance(latents_AE, (list, tuple)):
    latents_AE = torch.stack(latents_AE, dim=0)

In [None]:
# Inspect the stacked tensor's dimensions.
propagated_latents.size()

In [None]:
# Inspect the stacked tensor's dimensions.
latents_AE.size()

In [None]:
# Stack the per-trajectory MSE tensors (loaded as a list) along a new leading dimension.
# Guarded so re-running the cell is a no-op instead of an error.
if isinstance(MSE_Field_latent_physics, (list, tuple)):
    MSE_Field_latent_physics = torch.stack(MSE_Field_latent_physics, dim=0)

In [None]:
# Inspect the stacked tensor's dimensions.
MSE_Field_latent_physics.size()

In [None]:
# Mean and std over the leading (stacked) dimension, converted to NumPy for plotting.
# torch's std uses the Bessel-corrected (N-1) estimator by default.
MSE_Field_latent_means = MSE_Field_latent_physics.mean(dim=0).cpu().numpy()
MSE_Field_latent_stds = MSE_Field_latent_physics.std(dim=0).cpu().numpy()

In [None]:
# Latent-field MSE of the physics rollouts: mean ± 1 std per rollout step.
# xlabel/ylabel are required by plot_means_variances as defined above; omitting them raised TypeError.
# MSE cannot be negative, so the lower envelope is clipped at 0.
plot_means_variances(
    MSE_Field_latent_means,
    MSE_Field_latent_stds,
    xlabel="Rollout Timestep",
    ylabel="Latent field MSE",
    lower_bound=0.0,
)
plt.show()

In [None]:
# Same as the previous plot, restricted to the first 10 rollout steps.
# xlabel/ylabel are required by plot_means_variances as defined above; omitting them raised TypeError.
plot_means_variances(
    MSE_Field_latent_means[:10],
    MSE_Field_latent_stds[:10],
    xlabel="Rollout Timestep",
    ylabel="Latent field MSE",
    lower_bound=0.0,
)
plt.show()

### Correlation and IoU per timestep

In [None]:
from src.utils.metric import calc_mean_iou, mean_iou, calc_correlation

In [None]:
# The ground-truth and rollout collections should have the same number of entries.
for _seq in (GT_physics, rollout_physics):
    print(len(_seq))

In [None]:
# One correlation tensor per (ground truth, rollout) pair.
# strict=True: a length mismatch between the two collections now raises
# instead of being silently truncated by zip.
# transpose(0, 1) swaps the first two axes before calc_correlation (layout it expects — TODO confirm).
correlations = []
for (GT, rollout) in tqdm(zip(GT_physics, rollout_physics, strict=True), total=len(GT_physics)):
    correlations.append(calc_correlation(rollout.transpose(0, 1), GT.transpose(0, 1)))

In [None]:
# Quick look at the first entry: its size, then its element at index 99
# (raises IndexError if that entry has fewer than 100 elements along dim 0).
print(correlations[0].size())
print(correlations[0][99])

In [None]:
# Stack the list of correlation tensors along a new leading dimension.
# Guarded so re-running the cell is a no-op instead of an error.
if isinstance(correlations, (list, tuple)):
    correlations = torch.stack(correlations, dim=0)

In [None]:
# Inspect the stacked tensor's dimensions.
correlations.size()

In [None]:
# Mean and (Bessel-corrected, torch default) std over the leading stacked dimension, as NumPy.
correlations_mean = correlations.mean(dim=0).cpu().numpy()
correlation_std = correlations.std(dim=0).cpu().numpy()

In [None]:
# Physics rollouts: correlation with ground truth, mean ± 1 std per rollout step.
plot_means_variances(
    correlations_mean,
    correlation_std,
    upper_bond=1.0,  # clip envelope at 1 (assumes calc_correlation is bounded by 1 — TODO confirm)
    xlabel="Rollout Timestep",
    ylabel="Correlation between GT and rollout",
)
plt.show()

In [None]:
# Number of physics rollouts (displayed as the cell output).
len(rollout_physics)

In [None]:
# Dimensions of the first rollout tensor.
rollout_physics[0].size()

In [None]:
# Per-timestep IoU between rollout and ground-truth positions; one stacked tensor per trajectory.
# n_compartments=32 and the bounding box are forwarded to mean_iou
# (presumably a 32x32 grid over (xmin, xmax, ymin, ymax) — TODO confirm in src.utils.metric).
# strict=True: a length mismatch between the outer collections now raises instead of truncating.
ious = []
for (GT, rollout) in tqdm(zip(GT_physics, rollout_physics, strict=True), total=len(GT_physics)):
    iou_traj = []
    # NOTE(review): this inner zip still truncates to the shorter of GT / rollout.
    for (GT_pos, rollout_pos) in zip(GT, rollout):
        # unsqueeze(0) adds a leading size-1 dimension (presumably a batch axis for mean_iou)
        iou_traj.append(mean_iou(rollout_pos.unsqueeze(0), GT_pos.unsqueeze(0), n_compartments=32, bounding_box=(-0.9, 1.9, -0.9, 1.9)))
    ious.append(torch.stack(iou_traj))

In [None]:
# Persist the per-trajectory IoU tensors (computed with n_compartments=32).
torch.save(
    ious,
    "D:/Projects/Master/UPT/data/rollout_data_every_start_100_timesteps/ious_32_big_area.pt",
)
# NOTE(review): the disabled reload below points at ious.pt, not the ious_32_big_area.pt saved above.
# ious = torch.load("D:/Projects/Master/UPT/data/rollout_data_every_start_100_timesteps/ious.pt")

In [None]:
# Number of per-trajectory IoU tensors collected (displayed as the cell output).
len(ious)

In [None]:
# Stack the list of per-trajectory IoU tensors along a new leading dimension.
# Guarded so re-running the cell is a no-op instead of an error.
if isinstance(ious, (list, tuple)):
    ious = torch.stack(ious, dim=0)

In [None]:
# Inspect the stacked tensor's dimensions.
ious.size()

In [None]:
# Mean and (Bessel-corrected, torch default) std over the leading stacked dimension, as NumPy.
ious_mean = ious.mean(dim=0).cpu().numpy()
ious_std = ious.std(dim=0).cpu().numpy()

In [None]:
# Sanity check: IoU of the first ground-truth frame with itself (a standard IoU would give 1).
# NOTE(review): uses n_compartments=400, unlike the 32 / 50 used for the rollout IoUs.
mean_iou(
    GT_physics[0][0].unsqueeze(0),
    GT_physics[0][0].unsqueeze(0),
    bounding_box=(-0.9, 1.9, -0.9, 1.9),
    n_compartments=400,
)

In [None]:
# Physics rollouts: IoU with ground truth, mean ± 1 std per rollout step.
# IoU lies in [0, 1], so the envelope is clipped on both sides.
plot_means_variances(
    ious_mean,
    ious_std,
    lower_bound=0.0,
    upper_bond=1.0,
    xlabel="Rollout Timestep",
    ylabel="IoU",
)
plt.show()

### Ground truth latent integration comparison

In [None]:
# One correlation tensor per (ground truth, autoencoder rollout) pair.
# strict=True: a length mismatch between GT_AE and rollout_AE now raises
# instead of being silently truncated by zip.
correlations = []
for (GT, rollout) in tqdm(zip(GT_AE, rollout_AE, strict=True), total=len(GT_AE)):
    correlations.append(calc_correlation(rollout.transpose(0, 1), GT.transpose(0, 1)))

In [None]:
# Stack into one tensor and convert to a NumPy array (force=True also detaches / moves to CPU).
# Guarded so re-running the cell is a no-op: torch.stack cannot consume the ndarray it produced.
if not isinstance(correlations, np.ndarray):
    correlations = torch.stack(correlations, dim=0).numpy(force=True)

In [None]:
# Persist the AE correlation array.
# NOTE(review): `correlations` is a NumPy array here; torch.load with weights_only=True
# (the default in recent torch) may refuse to unpickle it — consider np.save, or weights_only=False on load.
torch.save(
    correlations,
    "D:/Projects/Master/UPT/data/rollout_data_every_start_100_timesteps/correlations_AE_70_steps.pt",
)

In [None]:
# GT-latent (autoencoder) rollouts: correlation with ground truth, mean ± 1 std per step.
# ddof=1 matches the Bessel-corrected torch std used for the physics plots, so the envelopes are comparable.
# plt.show() (as in the other plotting cells) renders the figure without echoing the Axes repr.
plot_means_variances(
    correlations.mean(axis=0),
    correlations.std(axis=0, ddof=1),
    xlabel="Rollout Timestep",
    ylabel="Correlation between GT and rollout",
    upper_bond=1.0,
    lower_value=0.84,  # fixed y-axis floor for this figure
)
plt.show()

In [None]:
# Per-timestep IoU between autoencoder rollout and ground-truth positions; one tensor per trajectory.
# strict=True: a length mismatch between GT_AE and rollout_AE now raises instead of truncating.
# NOTE(review): n_compartments=50 here vs 32 for the physics IoUs — the two curves use different grids.
ious = []
for (GT, rollout) in tqdm(zip(GT_AE, rollout_AE, strict=True), total=len(rollout_AE)):
    iou_traj = []
    # NOTE(review): this inner zip still truncates to the shorter of GT / rollout.
    for (GT_pos, rollout_pos) in zip(GT, rollout):
        # unsqueeze(0) adds a leading size-1 dimension (presumably a batch axis for mean_iou)
        iou_traj.append(mean_iou(rollout_pos.unsqueeze(0), GT_pos.unsqueeze(0), n_compartments=50, bounding_box=(-0.9, 1.9, -0.9, 1.9)))
    ious.append(torch.stack(iou_traj))

In [None]:
# Stack into one tensor and convert to a NumPy array (force=True also detaches / moves to CPU).
# Guarded so re-running the cell is a no-op: torch.stack cannot consume the ndarray it produced.
if not isinstance(ious, np.ndarray):
    ious = torch.stack(ious, dim=0).numpy(force=True)

In [None]:
# Persist the AE IoU array.
# NOTE(review): `ious` is a NumPy array here; torch.load with weights_only=True
# (the default in recent torch) may refuse to unpickle it — consider np.save, or weights_only=False on load.
torch.save(
    ious,
    "D:/Projects/Master/UPT/data/rollout_data_every_start_100_timesteps/ious_AE_70_steps.pt",
)

In [None]:
# GT-latent (autoencoder) rollouts: IoU with ground truth, mean ± 1 std per step.
# ddof=1 matches the Bessel-corrected torch std used for the physics plots, so the envelopes are comparable.
# plt.show() (as in the other plotting cells) renders the figure without echoing the Axes repr.
plot_means_variances(
    ious.mean(axis=0),
    ious.std(axis=0, ddof=1),
    xlabel="Rollout Timestep",
    ylabel="IoU",
    upper_bond=1.0,
    lower_value=0.7,  # fixed y-axis floor for this figure
)
plt.show()