# MNIST Latent Transport: Kernel EDMD + LAWGD

## Objective
1. Load high-dimensional MNIST images and map them into a low-dimensional latent space
2. Use PCA/SVD-style dimensionality reduction to obtain latent coordinates
3. Define a latent potential in the reduced space and learn the Langevin generator via Kernel EDMD / SDMD
4. Apply LAWGD to transport off-manifold particles so that they match the latent MNIST distribution

## Motivation
- Directly running LAWGD on 28×28 images is infeasible, so we compress the data
- The latent space keeps the downstream DMPS/SDMD workflows unchanged (still 2D)
- Decoder (inverse PCA) lets us lift transported particles back to pixel space for inspection

In [1]:
# Import required libraries
import time
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
from scipy.linalg import svd, eig
from scipy.stats import gaussian_kde
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from torchvision import datasets

# Set random seed for reproducibility
np.random.seed(42)
torch.manual_seed(42)

print("Libraries imported successfully!")

Libraries imported successfully!


## MNIST Latent Representation Pipeline

We treat MNIST images as points in \(\mathbb{R}^{784}\) and learn a low-dimensional chart:
- Normalize pixels to \([0, 1]\)
- Standardize flattened vectors, then run PCA (2D latent space by default)
- Keep the encoder/decoder so that we can move between latent and pixel spaces
- Define a quadratic latent potential using the empirical covariance (acts like an isotropic energy landscape around the data manifold)

In [2]:
class MNISTLatentPipeline:
    """Build a PCA-based encoder/decoder for MNIST images."""

    def __init__(self, data_root: str = "data/MNIST", latent_dim: int = 2):
        self.data_root = Path(data_root)
        self.latent_dim = latent_dim
        self.scaler: StandardScaler | None = None
        self.pca: PCA | None = None

    def load_data(self, n_samples: int | None = 60000, train: bool = True):
        dataset = datasets.MNIST(root=str(self.data_root), train=train, download=True)
        images = dataset.data.numpy().astype(np.float32) / 255.0
        labels = dataset.targets.numpy()

        if n_samples is not None and n_samples < len(images):
            idx = np.random.choice(len(images), size=n_samples, replace=False)
            images = images[idx]
            labels = labels[idx]
        return images, labels

    def fit(self, images: np.ndarray):
        flat = images.reshape(images.shape[0], -1)
        self.scaler = StandardScaler(with_mean=True, with_std=True)
        scaled = self.scaler.fit_transform(flat)
        self.pca = PCA(n_components=self.latent_dim, random_state=42)
        latent = self.pca.fit_transform(scaled)
        return latent

    def transform(self, images: np.ndarray):
        if self.scaler is None or self.pca is None:
            raise RuntimeError("Call fit() before transform().")
        flat = images.reshape(images.shape[0], -1)
        scaled = self.scaler.transform(flat)
        return self.pca.transform(scaled)

    def inverse_transform(self, latent: np.ndarray):
        if self.scaler is None or self.pca is None:
            raise RuntimeError("Call fit() before inverse_transform().")
        scaled = self.pca.inverse_transform(latent)
        flat = self.scaler.inverse_transform(scaled)
        return np.clip(flat.reshape(-1, 28, 28), 0.0, 1.0)


class LatentQuadraticPotential:
    """Mahalanobis energy in latent space (acts like a Gaussian well)."""

    def __init__(self, covariance: np.ndarray, temperature: float = 1.0):
        reg = 1e-6 * np.eye(covariance.shape[0])
        self.precision = np.linalg.inv(covariance + reg)
        self.temperature = temperature

    def _ensure_2d(self, X: np.ndarray):
        return X if X.ndim == 2 else X.reshape(1, -1)

    def V(self, X: np.ndarray):
        X2d = self._ensure_2d(X)
        energy = 0.5 * np.sum((X2d @ self.precision) * X2d, axis=1)
        return energy / self.temperature

    def grad_V(self, X: np.ndarray):
        X2d = self._ensure_2d(X)
        grad = (X2d @ self.precision) / self.temperature
        return grad

## Fit PCA Encoder and Inspect Latent Space

In [3]:
# Build latent representation
latent_dim = 32
mnist_pipeline = MNISTLatentPipeline(data_root="data/MNIST", latent_dim=latent_dim)
mnist_images, mnist_labels = mnist_pipeline.load_data(n_samples=50000, train=True)
latent_embeddings = mnist_pipeline.fit(mnist_images)

latent_cov = np.cov(latent_embeddings, rowvar=False)
potential = LatentQuadraticPotential(covariance=latent_cov, temperature=1.0)

print("=" * 60)
print("MNIST latent statistics:")
print("=" * 60)
print(f"Samples used: {latent_embeddings.shape[0]}")
print(f"Latent dimension: {latent_dim}")
print(f"Explained variance ratio: {mnist_pipeline.pca.explained_variance_ratio_}")
print(f"Latent mean (first 2 dims): {latent_embeddings.mean(axis=0)}")
print("=" * 60)

# Build density grid for visualization
x_min, y_min = latent_embeddings.min(axis=0) - 1.0
x_max, y_max = latent_embeddings.max(axis=0) + 1.0
latent_bounds = ((x_min, x_max), (y_min, y_max))
x_range = np.linspace(x_min, x_max, 200)
y_range = np.linspace(y_min, y_max, 200)
X_grid, Y_grid = np.meshgrid(x_range, y_range)

kde = gaussian_kde(latent_embeddings.T)
density_grid = kde(np.vstack([X_grid.ravel(), Y_grid.ravel()])).reshape(X_grid.shape)

fig_scatter, ax = plt.subplots(figsize=(7, 6))
ax.contourf(X_grid, Y_grid, density_grid, levels=20, cmap="viridis", alpha=0.7)
scatter = ax.scatter(
    latent_embeddings[:, 0], latent_embeddings[:, 1],
    c=mnist_labels[:latent_embeddings.shape[0]], cmap="tab20", s=5, alpha=0.4,
)
ax.set_xlabel("Latent dim 1")
ax.set_ylabel("Latent dim 2")
ax.set_title("MNIST latent distribution (PCA)")
ax.set_aspect("equal")
ax.grid(True, alpha=0.3)
plt.show()

# Reconstruct a few digits to verify the lift
recon_samples = mnist_pipeline.inverse_transform(latent_embeddings[:16])
fig_recon, axes = plt.subplots(4, 4, figsize=(6, 6))
for idx, ax in enumerate(axes.ravel()):
    if idx < recon_samples.shape[0]:
        ax.imshow(recon_samples[idx], cmap="gray")
    ax.axis("off")
fig_recon.suptitle("Reconstructed digits via inverse PCA", fontsize=14)
plt.tight_layout()
plt.show()

print("Decoder sanity check complete: reconstructed samples visualized.")

MNIST latent statistics:
Samples used: 50000
Latent dimension: 32
Explained variance ratio: [0.05705443 0.04146104 0.03798272 0.02925506 0.0254175  0.02220319
 0.01945319 0.01762282 0.01556257 0.01419107 0.01361732 0.01218615
 0.01130097 0.01106123 0.01047601 0.01010681 0.00945249 0.00934628
 0.00911511 0.00885225 0.00839101 0.0081518  0.00775032 0.00754129
 0.00728822 0.00703061 0.00698728 0.00675782 0.00640675 0.0062678
 0.00606485 0.00596459]
Latent mean (first 2 dims): [ 1.9847631e-07  1.0109925e-06  8.4353445e-07  4.4567108e-07
 -4.7385456e-07  1.4875793e-06  2.1265578e-06 -5.2680969e-07
 -6.1259271e-08 -4.9475074e-07  4.9034117e-07  1.6213376e-06
 -8.1505777e-07 -4.2264460e-07  2.1303331e-06  3.2160281e-07
 -1.5620851e-06 -6.6792489e-07  8.3309891e-07 -1.3632250e-06
  9.8544126e-07 -2.9323459e-07  2.2020226e-06 -3.1128405e-07
  1.4401471e-06  2.0357525e-06  6.3280106e-07  7.0935249e-07
  8.0370188e-07  4.5450329e-07  5.0237475e-07 -6.0159681e-07]


ValueError: too many values to unpack (expected 2)

## Validate Latent ↔ Pixel Consistency

In [None]:
# Verify reconstruction error statistics
n_eval = 2048
idx_eval = np.random.choice(mnist_images.shape[0], size=n_eval, replace=False)
images_eval = mnist_images[idx_eval]
latent_eval = mnist_pipeline.transform(images_eval)
recon_eval = mnist_pipeline.inverse_transform(latent_eval)

mse = np.mean((images_eval - recon_eval) ** 2)
max_err = np.max(np.abs(images_eval - recon_eval))

print("=" * 60)
print("Latent reconstruction diagnostics")
print("=" * 60)
print(f"Evaluation samples: {n_eval}")
print(f"Mean squared error: {mse:.6f}")
print(f"Max abs error: {max_err:.4f}")
print("=" * 60)

fig_compare, axes = plt.subplots(4, 4, figsize=(6, 6))
for i, ax in enumerate(axes.ravel()):
    original = images_eval[i]
    recon = recon_eval[i]
    ax.imshow(np.hstack([original, recon]), cmap="gray")
    ax.axis("off")
fig_compare.suptitle("Left: original | Right: reconstruction", fontsize=12)
plt.tight_layout()
plt.show()

## Generate Training Data: (X_tar, X_tar_next)

Sample latent codes directly from MNIST, then evolve them with Langevin dynamics in the latent space:
$$dZ = -\nabla V(Z)\,dt + \sqrt{2}\,dW$$

In [None]:
# ============================================================
# Step 1: Sample X_tar from latent MNIST distribution
# ============================================================

def sample_latent(latent_data, labels, n_samples=2500):
    idx = np.random.choice(latent_data.shape[0], size=n_samples, replace=False)
    return latent_data[idx], labels[idx], idx

n_samples = 2500
print("=" * 60)
print("Sampling X_tar from MNIST latent embeddings ...")
print("=" * 60)
X_tar, X_tar_labels, X_tar_idx = sample_latent(latent_embeddings, mnist_labels, n_samples)
print(f"X_tar shape: {X_tar.shape}")
print(f"Latent ranges: dim1 ∈ [{X_tar[:, 0].min():.2f}, {X_tar[:, 0].max():.2f}], dim2 ∈ [{X_tar[:, 1].min():.2f}, {X_tar[:, 1].max():.2f}]")
print("=" * 60)

# Define latent anchors (previously 'wells') using k-means clusters
n_latent_regions = 4
kmeans = KMeans(n_clusters=n_latent_regions, random_state=42, n_init=10)
kmeans.fit(X_tar)
well_centers = kmeans.cluster_centers_
center_dists = np.linalg.norm(X_tar[:, None, :] - well_centers[None, :, :], axis=2)
min_center_dist = center_dists.min(axis=1)
well_radius = np.percentile(min_center_dist, 40)

print("Latent anchors (derived via KMeans):")
for i, center in enumerate(well_centers):
    print(f"  Anchor {i+1}: ({center[0]:+.3f}, {center[1]:+.3f})")
print(f"Suggested anchor radius: {well_radius:.3f}")

In [None]:
# ============================================================
# Step 2: Evolve X_tar → X_tar_next using latent Langevin dynamics
# ============================================================

def langevin_step(Z, potential_model, dt, n_substeps=500):
    """Euler-Maruyama integration in latent space."""
    n = Z.shape[0]
    dt_sub = dt / n_substeps
    Z_curr = Z.copy()
    for _ in range(n_substeps):
        grad_V = potential_model.grad_V(Z_curr)
        drift = -grad_V
        noise = np.sqrt(2.0 * dt_sub) * np.random.randn(n, Z_curr.shape[1])
        Z_curr = Z_curr + drift * dt_sub + noise
    return Z_curr

# Evolution parameters
dt = 0.1
n_substeps = 500

print("=" * 60)
print("Evolving X_tar → X_tar_next via latent Langevin dynamics ...")
print("=" * 60)
print(f"SDE: dZ = -∇V(Z)dt + √(2)dW (latent space)")
print(f"Total time step: dt = {dt}")
print(f"Number of sub-steps: {n_substeps}")
print(f"Number of particles: {X_tar.shape[0]}")

X_tar_next = langevin_step(X_tar, potential, dt, n_substeps=n_substeps)

print(f"X_tar_next shape: {X_tar_next.shape}")
print(f"Latent ranges: dim1 ∈ [{X_tar_next[:, 0].min():.2f}, {X_tar_next[:, 0].max():.2f}], dim2 ∈ [{X_tar_next[:, 1].min():.2f}, {X_tar_next[:, 1].max():.2f}]")
print("=" * 60)

In [None]:
# ============================================================
# Visualize latent samples and their evolved counterparts
# ============================================================

fig, axes = plt.subplots(1, 2, figsize=(14, 6))

for ax, data, title, color in [
    (axes[0], X_tar, "X_tar (sampled latent codes)", "tab:blue"),
    (axes[1], X_tar_next, f"X_tar_next (after dt={dt})", "tab:orange"),
]:
    ax.contourf(X_grid, Y_grid, density_grid, levels=20, cmap="viridis", alpha=0.5)
    ax.scatter(data[:, 0], data[:, 1], c=color, s=15, alpha=0.6, edgecolors="k", linewidth=0.2)
    ax.scatter(well_centers[:, 0], well_centers[:, 1], color="red", s=100, marker="*", label="Latent anchors")
    ax.set_xlabel("Latent dim 1")
    ax.set_ylabel("Latent dim 2")
    ax.set_title(title, fontsize=14, fontweight="bold")
    ax.set_xlim(latent_bounds[0])
    ax.set_ylim(latent_bounds[1])
    ax.set_aspect("equal")
    ax.legend(loc="upper right")
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("=" * 60)
print("Data generation complete!")
print("=" * 60)
print(f"Training data pairs: {X_tar.shape[0]}")
print("=" * 60)

In [None]:
# # ============================================================
# # Polynomial EDMD: Extended Dynamic Mode Decomposition
# # ============================================================

# # Preconditions
# if 'X_tar' not in globals() or 'X_tar_next' not in globals():
#     raise RuntimeError('X_tar and X_tar_next must be computed before running EDMD (polynomial).')

# n = X_tar.shape[0]

# # Step 1: Build polynomial basis functions
# degree = 4

# def monomial_exponents_2d(deg: int):
#     exps = []
#     for total in range(deg + 1):
#         for i in range(total + 1):
#             exps.append((i, total - i))
#     return exps

# exps = monomial_exponents_2d(degree)
# m_dict = len(exps)

# # Optional feature scaling for numerical stability
# Z_all = np.vstack([X_tar, X_tar_next])
# scale = 1.0  # No scaling for this example
# X0 = X_tar / scale
# Y0 = X_tar_next / scale

# # Feature map Φ(X)
# def phi_poly(X: np.ndarray) -> np.ndarray:
#     N = X.shape[0]
#     Phi = np.empty((N, m_dict), dtype=float)
#     x = X[:, 0]
#     y = X[:, 1]
#     for k, (i, j) in enumerate(exps):
#         if i == 0 and j == 0:
#             Phi[:, k] = 1.0
#         elif i == 0:
#             Phi[:, k] = y ** j
#         elif j == 0:
#             Phi[:, k] = x ** i
#         else:
#             Phi[:, k] = (x ** i) * (y ** j)
#     return Phi

# Phi = phi_poly(X0)
# Phi_next = phi_poly(Y0)
# N = Phi.shape[0]

# # Step 2: Build EDMD matrices
# G_edmd = (Phi.T @ Phi) / N
# A_edmd = (Phi.T @ Phi_next) / N

# # Step 3: Compute Koopman operator
# reg = 1e-10
# I = np.eye(G_edmd.shape[0])
# K_edmd = np.linalg.solve(G_edmd + reg * I, A_edmd)

# # Step 4: Compute eigenvalues
# eigenvalues_edmd = np.linalg.eigvals(K_edmd)

# # Step 5: Construct generator eigenvalues and inverse weights
# # Extract real part of eigenvalues (ignore imaginary part)
# lambda_ns_edmd = eigenvalues_edmd.real

# # Construct generator eigenvalues: λ_gen = (λ_K - 1) / dt
# lambda_gen_edmd = (lambda_ns_edmd - 1.0) / dt

# # Build inverse generator weights (for LAWGD)
# tol_edmd = 1e-6
# lambda_ns_inv_edmd = np.zeros_like(lambda_ns_edmd)
# mask_edmd = lambda_ns_edmd >= tol_edmd
# lambda_ns_inv_edmd[mask_edmd] = 1.0 / (lambda_ns_edmd[mask_edmd] + 0.001)

# # Store results for LAWGD
# eigvals_K_edmd = lambda_ns_edmd.copy()
# lambda_gen_full_edmd = lambda_gen_edmd.copy()

# # ============================================================
# # Visualization: Eigenvalues on Unit Circle
# # ============================================================

# fig, ax = plt.subplots(figsize=(5, 5))

# # Plot unit circle
# theta = np.linspace(0, 2*np.pi, 100)
# ax.plot(np.cos(theta), np.sin(theta), 'k--', linewidth=1, label='Unit Circle')

# # Plot eigenvalues
# ax.scatter(eigenvalues_edmd.real, eigenvalues_edmd.imag, c='red', s=50, marker='o', 
#            edgecolors='black', linewidths=1, label='Eigenvalues', zorder=3)

# # Set equal aspect ratio and labels
# ax.set_aspect('equal')
# ax.grid(True, alpha=0.3)
# ax.axhline(y=0, color='k', linewidth=0.5)
# ax.axvline(x=0, color='k', linewidth=0.5)
# ax.set_xlabel('Real', fontsize=12)
# ax.set_ylabel('Imaginary', fontsize=12)
# ax.set_title('Polynomial EDMD: Koopman Eigenvalues on Unit Circle', fontsize=14)
# ax.legend(fontsize=10)

# plt.tight_layout()
# plt.show()

# print(f"Number of eigenvalues: {len(eigenvalues_edmd)}")
# print(f"Max magnitude: {np.max(np.abs(eigenvalues_edmd)):.4f}")
# print(f"Eigenvalues outside unit circle: {np.sum(np.abs(eigenvalues_edmd) > 1)}")

# # Sort eigenvalues by real part (descending order)
# sorted_real_edmd = np.sort(eigenvalues_edmd.real)[::-1]

# print("\n" + "="*50)
# print("Eigenvalues (Real part, sorted from large to small):")
# print("="*50)
# for i, real_part in enumerate(sorted_real_edmd):
#     print(f"{i+1:3d}. {real_part:+.6f}")

In [None]:
# from deeptime.data import quadruple_well
# import matplotlib.pyplot as plt

# h = 1e-3 # step size of the Euler-Maruyama integrator
# n_steps = 10000 # number of steps, the lag time is thus tau = nSteps*h = 10
# x0 = np.zeros((1, 2)) # inital condition
# n = 10000 # number of evaluations of the  discretized dynamical system with lag time tau

# f = quadruple_well(n_steps=n_steps)  # loading the model
# traj = f.trajectory(x0, n, seed=42)

# m = 2500 # number of training data points
# X = np.random.uniform(-2, 2, size=(2500, 2)) # training data
# # X = 4*np.random.rand(2, m)-2
# Y = f(X, seed=42, n_jobs=1) # training data mapped forward by the dynamical system

# from deeptime.kernels import GaussianKernel
# from deeptime.decomposition import KernelEDMD

# # ============================================================
# # Kernel Definition (GaussianKernel)
# # ============================================================
# # Deeptime's GaussianKernel(sigma) defines:
# #   k(x, y) = exp(-||x - y||² / (2 * sigma²))
# # 
# # This is equivalent to RBF kernel with bandwidth epsilon = sigma²
# # So sigma=1 means epsilon=1
# # ============================================================

# sigma = 1  # kernel bandwidth parameter
# kernel = GaussianKernel(sigma)

# # ============================================================
# # Compute ALL eigenvalues using full eigendecomposition
# # ============================================================
# # Instead of only computing n_eigs=6, we'll compute ALL eigenvalues
# # by manually constructing the Koopman operator matrix

# def rbf_kernel_matrix(X1, X2, sigma):
#     """Compute RBF kernel matrix: k(x,y) = exp(-||x-y||²/(2*sigma²))"""
#     # Compute squared distances
#     X1_sq = np.sum(X1**2, axis=1, keepdims=True)  # (n1, 1)
#     X2_sq = np.sum(X2**2, axis=1, keepdims=True)  # (n2, 1)
#     sq_dists = X1_sq + X2_sq.T - 2 * (X1 @ X2.T)  # (n1, n2)
    
#     # Apply RBF kernel
#     K = np.exp(-sq_dists / (2 * sigma**2))
#     return K

# K_XX = rbf_kernel_matrix(X, X, sigma)  # shape: (2500, 2500)
# K_XY = rbf_kernel_matrix(X, Y, sigma)  # shape: (2500, 2500)

# # Construct Koopman operator K = K_XY @ (K_XX + epsilon*I)^{-1}
# epsilon_reg = 1e-3  # regularization parameter
# I_mat = np.eye(K_XX.shape[0])
# K_koopman = K_XY @ np.linalg.inv(K_XX + epsilon_reg * I_mat)

# # Compute ALL eigenvalues
# eigenvalues_all, eigenvectors_all = np.linalg.eig(K_koopman)

# # ============================================================
# # Visualization: Eigenvalues on Unit Circle
# # ============================================================

# fig, ax = plt.subplots(figsize=(5, 5))

# # Plot unit circle
# theta = np.linspace(0, 2*np.pi, 100)
# ax.plot(np.cos(theta), np.sin(theta), 'k--', linewidth=1, label='Unit Circle')

# # Plot eigenvalues
# ax.scatter(eigenvalues_all.real, eigenvalues_all.imag, c='red', s=50, marker='o', 
#            edgecolors='black', linewidths=1, label='Eigenvalues', zorder=3)

# # Set equal aspect ratio and labels
# ax.set_aspect('equal')
# ax.grid(True, alpha=0.3)
# ax.axhline(y=0, color='k', linewidth=0.5)
# ax.axvline(x=0, color='k', linewidth=0.5)
# ax.set_xlabel('Real', fontsize=12)
# ax.set_ylabel('Imaginary', fontsize=12)
# ax.set_title('Deeptime Kernel EDMD: Koopman Eigenvalues on Unit Circle', fontsize=14)
# ax.legend(fontsize=10)

# plt.tight_layout()
# plt.show()

# print(f"Number of eigenvalues: {len(eigenvalues_all)}")
# print(f"Max magnitude: {np.max(np.abs(eigenvalues_all)):.4f}")
# print(f"Eigenvalues outside unit circle: {np.sum(np.abs(eigenvalues_all) > 1)}")

# # Sort eigenvalues by real part (descending order)
# sorted_idx = np.argsort(np.abs(eigenvalues_all))[::-1]
# sorted_real_deeptime = np.sort(eigenvalues_all.real)[::-1]

# print("\n" + "="*50)
# print("Eigenvalues (Real part, sorted from large to small):")
# print("="*50)
# for i, real_part in enumerate(sorted_real_deeptime):
#     print(f"{i+1:3d}. {real_part:+.6f}")

In [None]:
# # ============================================================
# # Kernel EDMD: Extended Dynamic Mode Decomposition with RBF Kernel
# # ============================================================

# # Preconditions
# if 'X_tar' not in globals() or 'X_tar_next' not in globals():
#     raise RuntimeError('X_tar and X_tar_next must be computed before running Kernel EDMD.')

# n = X_tar.shape[0]

# # Step 1: Build Gaussian RBF kernel matrices
# # Compute bandwidth epsilon using median heuristic
# sq_tar = np.sum(X_tar ** 2, axis=1)
# H_tar = sq_tar[:, None] + sq_tar[None, :] - 2 * (X_tar @ X_tar.T)
# epsilon_kedmd = 0.5 * np.median(H_tar) / np.log(n + 1)

# def rbf_kernel(X, Y, eps):
#     """Gaussian RBF kernel k(x,y) = exp(-||x-y||²/(2ε))"""
#     sq_x = np.sum(X ** 2, axis=1)
#     sq_y = np.sum(Y ** 2, axis=1)
#     H = sq_x[:, None] + sq_y[None, :] - 2 * (X @ Y.T)
#     return np.exp(-H / (2 * eps))

# # Kernel matrices: K_xx = K(X_tar, X_tar), K_xy = K(X_tar, X_tar_next)
# K_xx = rbf_kernel(X_tar, X_tar, 1)
# K_xy = rbf_kernel(X_tar, X_tar_next, 1)

# # Step 2: Compute Koopman operator via kernel matrices
# # K = K_xy @ (K_xx + γI)^{-1}
# reg_kedmd = 1e-3  # Use strong regularization like deeptime
# I_kedmd = np.eye(n)
# K_kedmd = K_xy @ np.linalg.inv(K_xx + reg_kedmd * I_kedmd)

# # Step 3: Compute eigenvalues and eigenvectors
# eigenvalues_kedmd, eigenvectors_kedmd = np.linalg.eig(K_kedmd)

# # Step 4: Construct generator eigenvalues and inverse weights
# # Extract real part of eigenvalues (ignore imaginary part)
# lambda_ns_kedmd = eigenvalues_kedmd.real

# # Construct generator eigenvalues: λ_gen = (λ_K - 1) / dt
# lambda_gen_kedmd = (lambda_ns_kedmd - 1.0) / dt

# # Build inverse generator weights (for LAWGD)
# tol_kedmd = 1e-6
# lambda_ns_inv_kedmd = np.zeros_like(lambda_ns_kedmd)
# mask_kedmd = lambda_ns_kedmd >= tol_kedmd
# lambda_ns_inv_kedmd[mask_kedmd] = 1.0 / (lambda_ns_kedmd[mask_kedmd] + 0.001)

# # Store results for LAWGD
# eigvals_K_kedmd = lambda_ns_kedmd.copy()
# eigvecs_K_kedmd = eigenvectors_kedmd.copy()
# lambda_gen_full_kedmd = lambda_gen_kedmd.copy()

# # ============================================================
# # Visualization: Eigenvalues on Unit Circle
# # ============================================================

# fig, ax = plt.subplots(figsize=(5, 5))

# # Plot unit circle
# theta = np.linspace(0, 2*np.pi, 100)
# ax.plot(np.cos(theta), np.sin(theta), 'k--', linewidth=1, label='Unit Circle')

# # Plot eigenvalues
# ax.scatter(eigenvalues_kedmd.real, eigenvalues_kedmd.imag, c='red', s=50, marker='o', 
#            edgecolors='black', linewidths=1, label='Eigenvalues', zorder=3)

# # Set equal aspect ratio and labels
# ax.set_aspect('equal')
# ax.grid(True, alpha=0.3)
# ax.axhline(y=0, color='k', linewidth=0.5)
# ax.axvline(x=0, color='k', linewidth=0.5)
# ax.set_xlabel('Real', fontsize=12)
# ax.set_ylabel('Imaginary', fontsize=12)
# ax.set_title('Kernel EDMD: Koopman Eigenvalues on Unit Circle', fontsize=14)
# ax.legend(fontsize=10)

# plt.tight_layout()
# plt.show()

# print(f"Number of eigenvalues: {len(eigenvalues_kedmd)}")
# print(f"Max magnitude: {np.max(np.abs(eigenvalues_kedmd)):.4f}")
# print(f"Eigenvalues outside unit circle: {np.sum(np.abs(eigenvalues_kedmd) > 1)}")

# # Sort eigenvalues by real part (descending order)
# sorted_real_kedmd = np.sort(eigenvalues_kedmd.real)[::-1]

# print("\n" + "="*50)
# print("Eigenvalues (Real part, sorted from large to small):")
# print("="*50)
# for i, real_part in enumerate(sorted_real_kedmd):
#     print(f"{i+1:3d}. {real_part:+.6f}")

In [None]:
# ============================================================
# DM Method: Diffusion Maps for Langevin Generator Construction
# ============================================================

# Preconditions
if 'X_tar' not in globals() or 'X_tar_next' not in globals():
    raise RuntimeError('X_tar and X_tar_next must be computed before running DM method.')

n = X_tar.shape[0]

# Step 1: Build Gaussian kernel
sq_tar = np.sum(X_tar ** 2, axis=1)
H = sq_tar[:, None] + sq_tar[None, :] - 2 * (X_tar @ X_tar.T)
# epsilon = 0.5 * np.median(H) / np.log(n + 1)
epsilon = 2 * dt

def ker(X):
    """Gaussian kernel k(x,y) = exp(-||x-y||²/(2ε))"""
    sq = np.sum(X ** 2, axis=1)
    return np.exp(-(sq[:, None] + sq[None, :] - 2 * (X @ X.T)) / (2 * epsilon))

data_kernel = ker(X_tar)

# Step 2: Anisotropic normalization
p_x = np.sqrt(np.sum(data_kernel, axis=1))
p_y = p_x.copy()
data_kernel_norm = data_kernel / p_x[:, None] / p_y[None, :]
D_y = np.sum(data_kernel_norm, axis=0)

# Step 3: Random-walk symmetric normalization
# rw_kernel = 0.5 * (data_kernel_norm / D_y + data_kernel_norm / D_y[:, None])
rw_kernel = data_kernel_norm / D_y[:, None]

# Step 4: SVD to get spectrum
phi, s, _ = svd(rw_kernel)
lambda_ns = s  # Singular values (eigenvalues of symmetric matrix)

# Step 5: Construct generator eigenvalues
lambda_gen_dm = (lambda_ns - 1.0) / epsilon

# Step 6: Build inverse generator weights (for LAWGD)
tol = 1e-6
lambda_ns_inv = np.zeros_like(lambda_ns)
mask = lambda_ns >= tol
lambda_ns_inv[mask] = epsilon / (lambda_ns[mask] + 0.001)

# Store results for LAWGD
eigvals_K_dm = lambda_ns.copy()
eigvecs_K_dm = phi.copy()
lambda_gen_full = lambda_gen_dm.copy()

# ============================================================
# Visualization: Eigenvalues on Unit Circle
# ============================================================

fig, ax = plt.subplots(figsize=(5, 5))

# Plot unit circle
theta = np.linspace(0, 2*np.pi, 100)
ax.plot(np.cos(theta), np.sin(theta), 'k--', linewidth=1, label='Unit Circle')

# Plot eigenvalues
ax.scatter(eigvals_K_dm.real, eigvals_K_dm.imag, c='red', s=50, marker='o', 
           edgecolors='black', linewidths=1, label='Eigenvalues', zorder=3)

# Set equal aspect ratio and labels
ax.set_aspect('equal')
ax.grid(True, alpha=0.3)
ax.axhline(y=0, color='k', linewidth=0.5)
ax.axvline(x=0, color='k', linewidth=0.5)
ax.set_xlabel('Real', fontsize=12)
ax.set_ylabel('Imaginary', fontsize=12)
ax.set_title('DM Method: Koopman Eigenvalues on Unit Circle', fontsize=14)
ax.legend(fontsize=10)

plt.tight_layout()
plt.show()

print(f"Number of eigenvalues: {len(eigvals_K_dm)}")
print(f"Max magnitude: {np.max(np.abs(eigvals_K_dm)):.4f}")
print(f"Eigenvalues outside unit circle: {np.sum(np.abs(eigvals_K_dm) > 1)}")

# Sort eigenvalues by real part (descending order)
sorted_real_dm = np.sort(eigvals_K_dm.real)[::-1]

print("\n" + "="*50)
print("Eigenvalues (Real part, sorted from large to small):")
print("="*50)
for i, real_part in enumerate(sorted_real_dm):
    print(f"{i+1:3d}. {real_part:+.6f}")


In [None]:
# ============================================================
# LAWGD: Langevin-Adjusted Wasserstein Gradient Descent (DMPS)
# ============================================================

from matplotlib.patches import Circle

print("="*60)
print("Step 1: Generate initial particles outside latent anchors ...")
print("="*60)

n_init_particles = 500

def sample_particles_outside(target, radius_scale=1.1, batch_size=2000):
    collected = []
    total_needed = target
    radius = well_radius * radius_scale
    while sum(chunk.shape[0] for chunk in collected) < total_needed:
        batch, _, _ = sample_latent(latent_embeddings, mnist_labels, n_samples=batch_size)
        dist = np.linalg.norm(batch[:, None, :] - well_centers[None, :, :], axis=2)
        mask = dist.min(axis=1) > radius
        if np.any(mask):
            collected.append(batch[mask])
    stacked = np.vstack(collected)
    return stacked[:total_needed]

X_lawgd_init = sample_particles_outside(n_init_particles)
print(f"Initial particles (outside anchors): {X_lawgd_init.shape[0]}")
print(f"Anchor radius threshold: {well_radius:.3f}")

# Step 2: Prepare LAWGD using DM-based Koopman spectrum
print(f"\n{'='*60}")
print("Step 2: Preparing LAWGD with DM Koopman spectrum...")
print(f"{'='*60}")

# Use DM method results (computed in previous cell)
eigvals_K = eigvals_K_dm.copy()
eigvecs_K = eigvecs_K_dm.copy()
lambda_gen = lambda_gen_full.copy()

# Mode selection strategy: use all valid modes
n_skip = 1  # Skip first 1 eigenvalues (constant modes)
eig_threshold = 0.01  # Threshold for valid eigenvalues (exclude near-zero eigenvalues)

# Find valid eigenvalues: skip first n_skip, and keep those > threshold
eigvals_after_skip = eigvals_K[n_skip:]
valid_mask = eigvals_after_skip.real > eig_threshold
n_valid = np.sum(valid_mask)

# Use ALL valid modes
k_modes = n_valid
mode_start = n_skip
mode_end = mode_start + k_modes

print(f"\n  Mode selection strategy: use all valid modes")
print(f"    - Total eigenvalues: {len(eigvals_K)}")
print(f"    - Skipping first {n_skip} eigenvalues (constant modes)")
print(f"    - Eigenvalue threshold: λ > {eig_threshold} (exclude near-zero)")
print(f"    - Skipped due to threshold: {len(eigvals_K) - n_skip - k_modes}")
print(f"    - **Using {k_modes} valid modes** (modes {mode_start+1} to {mode_end})")
print(f"\n  Selected eigenvalues for LAWGD:")
for i in range(mode_start, mode_end):
    print(f"    Mode {i+1}: λ = {eigvals_K[i].real:.6f}, λ_gen = {lambda_gen[i].real:.2f}")

# Build inverse generator weights for selected modes
lambda_gen_selected = lambda_gen[mode_start:mode_end]
tol_gen = 1e-6
lambda_gen_inv_selected = np.zeros(k_modes, dtype=complex)
mask_nonzero = np.abs(lambda_gen_selected) > tol_gen
lambda_gen_inv_selected[mask_nonzero] = 1.0 / lambda_gen_selected[mask_nonzero]

print(f"  Generator eigenvalue range: [{lambda_gen_selected.real.min():.2f}, {lambda_gen_selected.real.max():.2f}]")

# Step 4: Helper functions for kernel evaluation and gradient
def evaluate_kernel_at_points(X_query, X_data):
    """
    Evaluate normalized kernel k(x_query, x_data) for all pairs
    Returns: (n_query, n_data) kernel matrix
    """
    sq_query = np.sum(X_query ** 2, axis=1)
    sq_data = np.sum(X_data ** 2, axis=1)
    H = sq_query[:, None] + sq_data[None, :] - 2 * (X_query @ X_data.T)
    K_raw = np.exp(-H / (2 * epsilon))
    
    # Apply anisotropic normalization (using p_x from training data)
    # For query points, estimate density using kernel with training data
    p_query = np.sqrt(np.sum(K_raw, axis=1))
    K_norm = K_raw / p_query[:, None] / p_x[None, :]
    
    # Apply random-walk normalization (using D_y from training)
    # K_rw[i,j] = 0.5 * (K_norm[i,j] / D_y + K_norm[j,i] / D_y[i])
    # Since K_norm is (n_query, n_data) and not symmetric, we simplify:
    K_rw = K_norm / D_y[None, :]  # Divide each column by D_y
    
    return K_rw

def compute_kernel_gradient(X_query, X_data):
    """
    Compute gradient of kernel ∇_x k(x, y) w.r.t. x (query points)
    Returns: (n_query, n_data, 2) array where [:, :, d] is ∂k/∂x_d
    """
    n_query = X_query.shape[0]
    n_data = X_data.shape[0]
    
    # Compute pairwise differences: X_query - X_data
    # Shape: (n_query, n_data, 2)
    diff = X_query[:, None, :] - X_data[None, :, :]  # Broadcasting
    
    # Compute base kernel
    sq_query = np.sum(X_query ** 2, axis=1)
    sq_data = np.sum(X_data ** 2, axis=1)
    H = sq_query[:, None] + sq_data[None, :] - 2 * (X_query @ X_data.T)
    K_raw = np.exp(-H / (2 * epsilon))
    
    # Gradient of Gaussian kernel: ∂k/∂x = -k(x,y) * (x-y) / ε
    # Shape: (n_query, n_data, 2)
    grad_K_raw = -K_raw[:, :, None] * diff / epsilon
    
    # For simplicity, apply normalization to gradient (approximate)
    # This is a simplified version; full gradient would include normalization terms
    p_query = np.sqrt(np.sum(K_raw, axis=1, keepdims=True))
    grad_K_norm = grad_K_raw / p_query[:, :, None] / p_x[None, :, None]
    
    return grad_K_norm

# Step 5: LAWGD iteration with DM kernel
n_particles = X_lawgd_init.shape[0]
n_iter_lawgd = 1000
h_lawgd = 1  # Step size

X_lawgd_traj = np.zeros((n_particles, 2, n_iter_lawgd))
X_lawgd_traj[:, :, 0] = X_lawgd_init.copy()

print(f"\n{'='*60}")
print("Step 3: Running LAWGD iterations (DM kernel-based)...")
print(f"{'='*60}")
print(f"Particles: {n_particles}")
print(f"Iterations: {n_iter_lawgd}")
print(f"Step size: {h_lawgd}")
print(f"Active modes: {k_modes}")

for t in range(n_iter_lawgd - 1):
    X_curr = X_lawgd_traj[:, :, t]
    
    # Evaluate kernel between current particles and training data
    K_curr = evaluate_kernel_at_points(X_curr, X_tar)  # (n_particles, n_data)
    
    # Project onto eigenmodes: c = Φ^T @ K(x, X_tar)
    # eigvecs_K has shape (n_data, n_data), we use selected modes
    eigvecs_selected = eigvecs_K[:, mode_start:mode_end]  # (n_data, k_modes)
    c = eigvecs_selected.T @ K_curr.T  # (k_modes, n_particles)
    
    # Apply inverse generator: c_inv = Λ_gen^{-1} @ c
    c_inv = lambda_gen_inv_selected[:, None] * c  # (k_modes, n_particles)
    
    # Project back to data space: f = Φ @ c_inv
    f_inv = eigvecs_selected @ c_inv  # (n_data, n_particles)
    
    # Compute gradient using kernel gradient
    grad_K = compute_kernel_gradient(X_curr, X_tar)  # (n_particles, n_data, 2)
    
    # Gradient update: ∇_x f = Σ_i f_inv[i] * ∇_x k(x, x_i)
    grad_update = np.zeros((n_particles, 2))
    for d_idx in range(2):
        grad_update[:, d_idx] = np.sum(
            grad_K[:, :, d_idx] * f_inv.T.real,  # Use real part
            axis=1
        )
    
    # Gradient descent step
    X_lawgd_traj[:, :, t+1] = X_curr - h_lawgd * grad_update
    
    if (t+1) % 100 == 0 or t == 0:
        print(f"\r  [DM] Iteration {t+1}/{n_iter_lawgd-1}  ", end='', flush=True)

print()  # Print newline after loop
print("LAWGD iteration complete!")

# Step 6: Compute final metrics
dist_final_to_wells = np.array([np.linalg.norm(X_lawgd_traj[:, :, -1] - center, axis=1) for center in well_centers])
min_dist_final = np.min(dist_final_to_wells, axis=0)
in_well_final = min_dist_final <= well_radius

print(f"\n{'='*60}")
print("LAWGD Results:")
print(f"{'='*60}")
print(f"Initial particles (outside anchors): {n_particles}")
print(f"Final particles near anchors: {np.sum(in_well_final)} ({100*np.sum(in_well_final)/n_particles:.1f}%)")
print(f"Final particles still outside: {n_particles - np.sum(in_well_final)} ({100*(n_particles - np.sum(in_well_final))/n_particles:.1f}%)")
print(f"{'='*60}")

In [None]:
# ============================================================
# Visualization 1: Initial vs Final Positions (DMPS)
# ============================================================

fig, ax = plt.subplots(figsize=(8, 7))

ax.contourf(X_grid, Y_grid, density_grid, levels=30, cmap='Blues', alpha=0.7)
ax.scatter(X_tar[:, 0], X_tar[:, 1], s=3, c='lightgray', alpha=0.2)
ax.scatter(X_lawgd_traj[:, 0, 0], X_lawgd_traj[:, 1, 0],
           s=25, c='red', marker='o', label='Initial (outside anchors)', zorder=5)
ax.scatter(X_lawgd_traj[:, 0, -1], X_lawgd_traj[:, 1, -1],
           s=35, facecolors='none', edgecolors='magenta', linewidth=1.5,
           label='Final', zorder=15)

# Draw anchor boundaries
for idx, center in enumerate(well_centers):
    circle = Circle(center, well_radius, fill=False, edgecolor='green',
                   linewidth=2, linestyle='--')
    ax.add_patch(circle)

ax.scatter(well_centers[:, 0], well_centers[:, 1], s=100, c='green',
           marker='*', zorder=10)
ax.set_xlabel('Latent dim 1', fontsize=12)
ax.set_ylabel('Latent dim 2', fontsize=12)
ax.set_title('DMPS: Initial vs Final Positions', fontsize=14, fontweight='bold')
ax.set_xlim(latent_bounds[0])
ax.set_ylim(latent_bounds[1])
ax.set_aspect('equal')
ax.legend(loc='upper right', fontsize=9)
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# ============================================================
# Visualization 2: Trajectories (DMPS)
# ============================================================

fig, ax = plt.subplots(figsize=(8, 7))

ax.contourf(X_grid, Y_grid, density_grid, levels=30, cmap='Blues', alpha=0.7)

n_show_traj = min(15, n_particles)
for i in range(n_show_traj):
    ax.plot(X_lawgd_traj[i, 0, :], X_lawgd_traj[i, 1, :],
            alpha=0.5, linewidth=1, color='gray')

ax.scatter(X_lawgd_traj[:n_show_traj, 0, 0], X_lawgd_traj[:n_show_traj, 1, 0],
           s=40, c='red', marker='o', zorder=5, label='Start')
ax.scatter(X_lawgd_traj[:n_show_traj, 0, -1], X_lawgd_traj[:n_show_traj, 1, -1],
           s=50, facecolors='none', edgecolors='magenta',
           linewidth=1.5, zorder=15, label='End')

for idx, center in enumerate(well_centers):
    circle = Circle(center, well_radius, fill=False, edgecolor='green',
                   linewidth=2, linestyle='--')
    ax.add_patch(circle)

ax.scatter(well_centers[:, 0], well_centers[:, 1], s=100, c='green',
           marker='*', zorder=10)
ax.set_xlabel('Latent dim 1', fontsize=12)
ax.set_ylabel('Latent dim 2', fontsize=12)
ax.set_title(f'DMPS Trajectories (first {n_show_traj} particles)',
             fontsize=14, fontweight='bold')
ax.set_xlim(latent_bounds[0])
ax.set_ylim(latent_bounds[1])
ax.set_aspect('equal')
ax.legend(loc='upper right', fontsize=10)
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## Convergence Analysis: DMPS with Metrics Tracking (IID Data)

Run DMPS for 1000 steps and track convergence metrics every 100 steps

In [None]:
# ============================================================
# DMPS with Metrics Tracking (0-1000 steps, every 100 steps)
# ============================================================

from matplotlib.patches import Circle

print("="*60)
print("DMPS: Running 1000 iterations with metrics tracking...")
print("="*60)

# Step 1: Setup parameters
n_particles = X_lawgd_init.shape[0]
n_iter_lawgd = 1000
h_lawgd = 1  # Step size
record_interval = 100  # Record metrics every 100 steps
n_records = n_iter_lawgd // record_interval + 1  # Include step 0

# Initialize trajectory storage
X_lawgd_traj_dm = np.zeros((n_particles, 2, n_iter_lawgd + 1))
X_lawgd_traj_dm[:, :, 0] = X_lawgd_init.copy()

# Initialize metrics storage
metrics_dm = {
    'steps': [],
    'well_coverage': [],  # Percentage of particles in wells
    'avg_potential': [],  # Average potential energy
    'movement_rate': []   # Average displacement per step
}

print(f"Particles: {n_particles}")
print(f"Iterations: {n_iter_lawgd}")
print(f"Recording interval: {record_interval} steps")
print(f"Total records: {n_records}")

# Helper function to compute metrics
def compute_metrics(X_current, X_previous=None):
    """Compute convergence metrics for current particle positions"""
    # 1. Well coverage: percentage in wells
    dist_to_wells = np.array([np.linalg.norm(X_current - center, axis=1) 
                              for center in well_centers])
    min_dist = np.min(dist_to_wells, axis=0)
    in_well = min_dist <= well_radius
    coverage = 100.0 * np.sum(in_well) / len(X_current)
    
    # 2. Average potential energy
    avg_V = np.mean(potential.V(X_current))
    
    # 3. Movement rate (only if previous positions available)
    if X_previous is not None:
        displacement = np.linalg.norm(X_current - X_previous, axis=1)
        movement = np.mean(displacement)
    else:
        movement = 0.0
    
    return coverage, avg_V, movement

# Record initial state (step 0)
coverage_0, avg_V_0, _ = compute_metrics(X_lawgd_traj_dm[:, :, 0])
metrics_dm['steps'].append(0)
metrics_dm['well_coverage'].append(coverage_0)
metrics_dm['avg_potential'].append(avg_V_0)
metrics_dm['movement_rate'].append(0.0)

print(f"\nInitial state (step 0):")
print(f"  Well coverage: {coverage_0:.2f}%")
print(f"  Avg potential: {avg_V_0:.4f}")

# Step 2: Run LAWGD iterations with DM kernel
print(f"\n{'='*60}")
print("Running DMPS iterations...")
print(f"{'='*60}")

for t in range(n_iter_lawgd):
    X_curr = X_lawgd_traj_dm[:, :, t]
    
    # Evaluate kernel between current particles and training data
    K_curr = evaluate_kernel_at_points(X_curr, X_tar)  # (n_particles, n_data)
    
    # Project onto eigenmodes: c = Φ^T @ K(x, X_tar)
    eigvecs_selected = eigvecs_K[:, mode_start:mode_end]  # (n_data, k_modes)
    c = eigvecs_selected.T @ K_curr.T  # (k_modes, n_particles)
    
    # Apply inverse generator: c_inv = Λ_gen^{-1} @ c
    c_inv = lambda_gen_inv_selected[:, None] * c  # (k_modes, n_particles)
    
    # Project back to data space: f = Φ @ c_inv
    f_inv = eigvecs_selected @ c_inv  # (n_data, n_particles)
    
    # Compute gradient using kernel gradient
    grad_K = compute_kernel_gradient(X_curr, X_tar)  # (n_particles, n_data, 2)
    
    # Gradient update: ∇_x f = Σ_i f_inv[i] * ∇_x k(x, x_i)
    grad_update = np.zeros((n_particles, 2))
    for d_idx in range(2):
        grad_update[:, d_idx] = np.sum(
            grad_K[:, :, d_idx] * f_inv.T.real,  # Use real part
            axis=1
        )
    
    # Gradient descent step
    X_lawgd_traj_dm[:, :, t+1] = X_curr - h_lawgd * grad_update
    
    # Record metrics every record_interval steps
    if (t + 1) % record_interval == 0:
        X_prev = X_lawgd_traj_dm[:, :, t]
        X_next = X_lawgd_traj_dm[:, :, t+1]
        coverage, avg_V, movement = compute_metrics(X_next, X_prev)
        
        metrics_dm['steps'].append(t + 1)
        metrics_dm['well_coverage'].append(coverage)
        metrics_dm['avg_potential'].append(avg_V)
        metrics_dm['movement_rate'].append(movement)
    
    # Progress indicator
    if (t+1) % 200 == 0 or t == 0:
        print(f"\r  [DM] Iteration {t+1}/{n_iter_lawgd}  ", end='', flush=True)

print()  # Newline after loop
print("DMPS iteration complete!")

# Step 3: Final summary
final_coverage = metrics_dm['well_coverage'][-1]
final_potential = metrics_dm['avg_potential'][-1]
final_movement = metrics_dm['movement_rate'][-1]

print(f"\n{'='*60}")
print("DMPS Final Results:")
print(f"{'='*60}")
print(f"Final well coverage: {final_coverage:.2f}%")
print(f"Final avg potential: {final_potential:.4f}")
print(f"Final movement rate: {final_movement:.6f}")
print(f"Total records: {len(metrics_dm['steps'])}")
print(f"{'='*60}")

# Convert metrics to numpy arrays for easier plotting
for key in ['steps', 'well_coverage', 'avg_potential', 'movement_rate']:
    metrics_dm[key] = np.array(metrics_dm[key])

print("\n✓ DMPS metrics saved in 'metrics_dm' dictionary")

In [None]:
# ============================================================
# SDMD: Stochastic Dynamic Mode Decomposition with Dictionary Learning
# ============================================================

import torch
torch.cuda.empty_cache()
# import torch.nn.functional as F
# from numpy import linalg as la
# from numpy import *
from solver_sdmd_torch_gpu import KoopmanNNTorch, KoopmanSolverTorch


print (torch.__version__, torch.cuda.is_available())
print(torch.version.cuda)
print (torch.cuda.get_device_name())
# device= 'cpu'
device= 'cuda'

# Reshape data_X and data_Y into a single column
X = X_tar  # 2D latent features
Y = X_tar_next  # 2D latent targets
print(f"Shape of X: {X.shape}")
print(f"Shape of Y: {Y.shape}")

# Separate data into two parts: train and validation
len_all = X.shape[0]
data_x_train = X[:int(0.7*len_all)]
data_x_valid = X[int(0.7*len_all)+1:]

data_y_train = Y[:int(0.7*len_all)]
data_y_valid = Y[int(0.7*len_all)+1:]

data_train = [data_x_train, data_y_train]
data_valid = [data_x_valid, data_y_valid]

print(data_x_train.shape)



#### SDMD Test ####
checkpoint_file= 'well2d_example_ckpt001.torch'

basis_function = KoopmanNNTorch(input_size= 2, layer_sizes=[10], n_psi_train=7).to(device)  # basis number would be 20


solver = KoopmanSolverTorch(dic=basis_function, # Replace 'koopman_nn' by 'dic' if you use the original solver_edmdvar
                       target_dim=np.shape(data_x_train)[-1],
                                                   reg=0.1,  checkpoint_file= checkpoint_file, fnn_checkpoint_file= 'example_fnn001.torch', 
                            a_b_file= 'a_b_example_3ple_well.jbl', 
                        generator_batch_size= 2, fnn_batch_size= 32, delta_t= dt)

solver.build_with_generator(
    data_train=data_train,
    data_valid=data_valid,
    epochs=6,
    batch_size=256,
    lr=1e-5,
    log_interval=10,
    lr_decay_factor=.8
    )

# Results from solver_edmd/solver_resdmd
evalues_sdmd = solver.eigenvalues.T
efuns_sdmd = solver.eigenfunctions(X)
evectors_sdmd = solver.eigenvectors.T
# kpm_modes = solver.compute_mode().T
N_dict_sdmd = np.shape(evalues_sdmd)[0]
Psi_X_sdmd = solver.get_Psi_X()
Psi_Y_sdmd = solver.get_Psi_Y()
Koopman_matrix_K_sdmd = solver.K

outputs_sdmd = {
    'efuns': efuns_sdmd,
    'evalues': evalues_sdmd,
    'evectors': evectors_sdmd,
    # 'kpm_modes': kpm_modes,
    'N_dict': N_dict_sdmd,
    'K': Koopman_matrix_K_sdmd,
    # 'Psi_X': Psi_X_sdmd,
    # 'Psi_Y': Psi_Y_sdmd,
    }


# ============================================================
# Visualization: Eigenvalues on Unit Circle
# ============================================================

# Assuming 'efuns_sdmd' is a 2D numpy array with shape (n_samples, n_eigenfunctions)
# and 'X' is a 2D numpy array with shape (n_samples, 2) representing your input data

# Assuming evalues_sdmd is a numpy array of complex numbers
print("SDMD eigenvalues shape", evalues_sdmd.shape)
print("SDMD eigenvalues", evalues_sdmd)
print("SDMD eigenvectors shape", evectors_sdmd.shape)
print("SDMD eigenvectors", evectors_sdmd)

# Plot eigenvalues on unit circle
real_parts = evalues_sdmd.real
imag_parts = evalues_sdmd.imag

# Create the plot
plt.figure(figsize=(6, 4))
plt.scatter(real_parts, imag_parts, color='blue', label='Eigenvalues')

# Draw a unit circle for reference
theta = np.linspace(0, 2 * np.pi, 100)
plt.plot(np.cos(theta), np.sin(theta), linestyle='--', color='grey', label='Unit Circle')

plt.title('SDMD')
plt.xlabel('Real Part')
plt.ylabel('Imaginary Part')
plt.axis('equal')  # Ensure the aspect ratio is equal to make the unit circle round
plt.grid(True, which='both', linestyle='--', linewidth=0.5)
plt.legend()
plt.show()

# Plot eigenfunctions
fig, axs = plt.subplots(1, 4, figsize=(9, 2))

# Plot for the 1st eigenfunction
# scatter1 = axs[0].scatter(*X.T, c=np.real(efuns_sdmd)[:, 0], cmap='coolwarm')
scatter1 = axs[0].scatter(*X.T, c=np.real(efuns_sdmd)[:, 0], cmap='coolwarm', vmin=1.091, vmax=1.093)
axs[0].set_title('1st eigenfunction')
axs[0].set_xlim(latent_bounds[0])
axs[0].set_ylim(latent_bounds[1])
cbar1 = fig.colorbar(scatter1, ax=axs[0], shrink=0.7, aspect=20)

# Plot for the 2nd eigenfunction
scatter2 = axs[1].scatter(*X.T, c=np.real(efuns_sdmd)[:, 1], cmap='coolwarm')
axs[1].set_title('2nd eigenfunction')
axs[1].set_xlim(latent_bounds[0])
axs[1].set_ylim(latent_bounds[1])
cbar2 = fig.colorbar(scatter2, ax=axs[1])

# Plot for the 3rd eigenfunction
scatter3 = axs[2].scatter(*X.T, c=np.real(efuns_sdmd)[:, 2], cmap='coolwarm')
axs[2].set_title('3rd eigenfunction')
axs[2].set_xlim(latent_bounds[0])
axs[2].set_ylim(latent_bounds[1])
cbar3 = fig.colorbar(scatter3, ax=axs[2])

# Plot for the 4th eigenfunction
scatter4 = axs[3].scatter(*X.T, c=np.real(efuns_sdmd)[:, 3], cmap='coolwarm')
axs[3].set_title('4th eigenfunction')
axs[3].set_xlim(latent_bounds[0])
axs[3].set_ylim(latent_bounds[1])
cbar4 = fig.colorbar(scatter4, ax=axs[3])

fig.suptitle('SDMD', fontsize=16)
plt.tight_layout()
plt.show()

# ============================================================
# Construct Generator Eigenvalues and Inverse Weights (for LAWGD)
# ============================================================
print("\n" + "="*60)
print("Constructing Generator Inverse for LAWGD...")
print("="*60)

# Extract real part of eigenvalues (SDMD eigenvalues are already computed)
lambda_ns_sdmd = evalues_sdmd.real

# Construct generator eigenvalues: λ_gen = (λ_K - 1) / dt
lambda_gen_sdmd = (lambda_ns_sdmd - 1.0) / dt

# Build inverse generator weights (for LAWGD)
tol_sdmd = 1e-6
lambda_ns_inv_sdmd = np.zeros_like(lambda_ns_sdmd)
mask_sdmd = lambda_ns_sdmd >= tol_sdmd
lambda_ns_inv_sdmd[mask_sdmd] = dt / (lambda_ns_sdmd[mask_sdmd] + 0.001)

# Store results for LAWGD (following DM method naming convention)
eigvals_K_sdmd = lambda_ns_sdmd.copy()
eigvecs_K_sdmd = efuns_sdmd.copy()  # ✓ Use eigenfunctions (values on data points)!
lambda_gen_full_sdmd = lambda_gen_sdmd.copy()

print(f"  - eigvals_K_sdmd shape: {eigvals_K_sdmd.shape}")
print(f"  - eigvecs_K_sdmd shape: {eigvecs_K_sdmd.shape}")
print("="*60)

In [None]:
# ============================================================
# LAWGD: Langevin-Adjusted Wasserstein Gradient Descent (SDMD)
# ============================================================

from matplotlib.patches import Circle

print("="*60)
print("LAWGD with SDMD Koopman Spectrum")
print("="*60)
print("Strategy: Use SDMD eigenvalues/eigenvectors + DM kernel/normalization")
print("="*60)

# Step 1: Reuse initial particles from DMPS (or generate new ones)
if 'X_lawgd_init' not in globals():
    print("\nGenerating initial particles outside anchors ...")
    n_init_particles = 500
    X_lawgd_init = sample_particles_outside(n_init_particles)
    print(f"Particles outside anchors: {X_lawgd_init.shape[0]}")
else:
    print(f"\nReusing existing initial particles: {X_lawgd_init.shape[0]}")

# Step 2: Prepare LAWGD using SDMD-based Koopman spectrum
print(f"\n{'='*60}")
print("Step 2: Preparing LAWGD with SDMD Koopman spectrum...")
print(f"{'='*60}")

# Use SDMD method results (computed in Cell 13)
eigvals_K = eigvals_K_sdmd.copy()
eigvecs_K = eigvecs_K_sdmd.copy()
lambda_gen = lambda_gen_full_sdmd.copy()

print(f"SDMD eigenvalues shape: {eigvals_K.shape}")
print(f"SDMD eigenvectors shape: {eigvecs_K.shape}")

# Mode selection strategy: use all valid modes
n_skip = 1  # Skip first 1 eigenvalues (constant modes)
eig_threshold = 0.01  # Threshold for valid eigenvalues (exclude near-zero eigenvalues)

# Find valid eigenvalues: skip first n_skip, and keep those > threshold
eigvals_after_skip = eigvals_K[n_skip:]
valid_mask = eigvals_after_skip.real > eig_threshold
n_valid = np.sum(valid_mask)

# Use ALL valid modes
k_modes = n_valid
mode_start = n_skip
mode_end = mode_start + k_modes

print(f"\n  Mode selection strategy: use all valid modes")
print(f"    - Total eigenvalues: {len(eigvals_K)}")
print(f"    - Skipping first {n_skip} eigenvalues (constant modes)")
print(f"    - Eigenvalue threshold: λ > {eig_threshold} (exclude near-zero)")
print(f"    - Skipped due to threshold: {len(eigvals_K) - n_skip - k_modes}")
print(f"    - **Using {k_modes} valid modes** (modes {mode_start+1} to {mode_end})")
print(f"\n  Selected eigenvalues for LAWGD:")
for i in range(mode_start, min(mode_end, mode_start + 10)):
    print(f"    Mode {i+1}: λ = {eigvals_K[i].real:.6f}, λ_gen = {lambda_gen[i].real:.2f}")
if mode_end > mode_start + 10:
    print(f"    ... ({mode_end - mode_start - 10} more modes)")

# Build inverse generator weights for selected modes
lambda_gen_selected = lambda_gen[mode_start:mode_end]
tol_gen = 1e-6
lambda_gen_inv_selected = np.zeros(k_modes, dtype=complex)
mask_nonzero = np.abs(lambda_gen_selected) > tol_gen
lambda_gen_inv_selected[mask_nonzero] = 1.0 / lambda_gen_selected[mask_nonzero]

print(f"\n  Generator eigenvalue range: [{lambda_gen_selected.real.min():.2f}, {lambda_gen_selected.real.max():.2f}]")

# Step 3: Reuse DM's kernel functions and normalization parameters
print(f"\n{'='*60}")
print("Step 3: Reusing DM kernel functions and normalization...")
print(f"{'='*60}")
print(f"  - Kernel bandwidth (epsilon): {epsilon:.6f}")
print(f"  - Normalization parameters: p_x, D_y (from DM method)")
print(f"  - Kernel functions: evaluate_kernel_at_points(), compute_kernel_gradient()")

if 'epsilon' not in globals() or 'p_x' not in globals() or 'D_y' not in globals():
    raise RuntimeError("DM normalization parameters not found! Please run Cell 18 (DM method) first.")

# Step 4: LAWGD iteration with SDMD spectrum + DM kernel
n_particles = X_lawgd_init.shape[0]
n_iter_lawgd = 1000
h_lawgd = 1  # Step size

X_lawgd_traj_sdmd = np.zeros((n_particles, 2, n_iter_lawgd))
X_lawgd_traj_sdmd[:, :, 0] = X_lawgd_init.copy()

print(f"\n{'='*60}")
print("Step 4: Running LAWGD iterations (SDMD spectrum + DM kernel)...")
print(f"{'='*60}")
print(f"Particles: {n_particles}")
print(f"Iterations: {n_iter_lawgd}")
print(f"Step size: {h_lawgd}")
print(f"Active modes: {k_modes}")

for t in range(n_iter_lawgd - 1):
    X_curr = X_lawgd_traj_sdmd[:, :, t]
    
    # Evaluate kernel between current particles and training data (DM kernel)
    K_curr = evaluate_kernel_at_points(X_curr, X_tar)  # (n_particles, n_data)
    
    # Project onto SDMD eigenmodes: c = Φ_SDMD^T @ K(x, X_tar)
    eigvecs_selected = eigvecs_K[:, mode_start:mode_end]  # (n_data, k_modes)
    c = eigvecs_selected.T @ K_curr.T  # (k_modes, n_particles)
    
    # Apply inverse generator: c_inv = Λ_gen^{-1} @ c
    c_inv = lambda_gen_inv_selected[:, None] * c  # (k_modes, n_particles)
    
    # Project back to data space: f = Φ_SDMD @ c_inv
    f_inv = eigvecs_selected @ c_inv  # (n_data, n_particles)
    
    # Compute gradient using DM kernel gradient
    grad_K = compute_kernel_gradient(X_curr, X_tar)  # (n_particles, n_data, 2)
    
    # Gradient update: ∇_x f = Σ_i f_inv[i] * ∇_x k(x, x_i)
    grad_update = np.zeros((n_particles, 2))
    for d_idx in range(2):
        grad_update[:, d_idx] = np.sum(
            grad_K[:, :, d_idx] * f_inv.T.real,  # Use real part
            axis=1
        )
    
    # Gradient descent step
    X_lawgd_traj_sdmd[:, :, t+1] = X_curr - h_lawgd * grad_update
    
    if (t+1) % 100 == 0 or t == 0:
        print(f"\r  [SDMD] Iteration {t+1}/{n_iter_lawgd-1}  ", end='', flush=True)

print()  # Print newline after loop
print("LAWGD iteration complete!")

# Step 5: Compute final metrics
dist_final_to_wells = np.array([np.linalg.norm(X_lawgd_traj_sdmd[:, :, -1] - center, axis=1) for center in well_centers])
min_dist_final = np.min(dist_final_to_wells, axis=0)
in_well_final = min_dist_final <= well_radius

print(f"\n{'='*60}")
print("LAWGD Results (SDMD Spectrum):")
print(f"{'='*60}")
print(f"Initial particles (outside anchors): {n_particles}")
print(f"Final particles near anchors: {np.sum(in_well_final)} ({100*np.sum(in_well_final)/n_particles:.1f}%)")
print(f"Final particles still outside: {n_particles - np.sum(in_well_final)} ({100*(n_particles - np.sum(in_well_final))/n_particles:.1f}%)")
print(f"{'='*60}")

print("\n" + "="*60)
print("Summary:")
print("="*60)
print("This LAWGD implementation uses:")
print("  ✓ SDMD spectrum (eigenvalues + eigenvectors from Cell 13)")
print("  ✓ DM kernel functions (evaluate_kernel_at_points, compute_kernel_gradient)")
print("  ✓ DM normalization parameters (epsilon, p_x, D_y from Cell 18)")
print("="*60)

In [None]:
# ============================================================
# Visualization 1: Initial vs Final Positions (Koopman(SDMD))
# ============================================================

fig, ax = plt.subplots(figsize=(8, 7))

ax.contourf(X_grid, Y_grid, density_grid, levels=30, cmap='Blues', alpha=0.7)
ax.scatter(X_tar[:, 0], X_tar[:, 1], s=3, c='lightgray', alpha=0.2, label='Training data')
ax.scatter(X_lawgd_traj_sdmd[:, 0, 0], X_lawgd_traj_sdmd[:, 1, 0],
           s=25, c='red', marker='o', label='Initial (outside anchors)', zorder=5)
ax.scatter(X_lawgd_traj_sdmd[:, 0, -1], X_lawgd_traj_sdmd[:, 1, -1],
           s=35, facecolors='none', edgecolors='magenta', linewidth=1.5,
           label='Final', zorder=15)

for idx, center in enumerate(well_centers):
    circle = Circle(center, well_radius, fill=False, edgecolor='green',
                   linewidth=2, linestyle='--')
    ax.add_patch(circle)

ax.scatter(well_centers[:, 0], well_centers[:, 1], s=100, c='green',
           marker='*', zorder=10)
ax.set_xlabel('Latent dim 1', fontsize=12)
ax.set_ylabel('Latent dim 2', fontsize=12)
ax.set_title('KSWGD: Initial vs Final Positions', fontsize=14, fontweight='bold')
ax.set_xlim(latent_bounds[0])
ax.set_ylim(latent_bounds[1])
ax.set_aspect('equal')
ax.legend(loc='upper right', fontsize=9)
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# ============================================================
# Visualization 2: Trajectories (Koopman(SDMD))
# ============================================================

fig, ax = plt.subplots(figsize=(8, 7))

ax.contourf(X_grid, Y_grid, density_grid, levels=30, cmap='Blues', alpha=0.7)

n_show_traj = min(15, n_particles)
for i in range(n_show_traj):
    ax.plot(X_lawgd_traj_sdmd[i, 0, :], X_lawgd_traj_sdmd[i, 1, :],
            alpha=0.5, linewidth=1, color='gray')

ax.scatter(X_lawgd_traj_sdmd[:n_show_traj, 0, 0], X_lawgd_traj_sdmd[:n_show_traj, 1, 0],
           s=40, c='red', marker='o', zorder=5, label='Start')
ax.scatter(X_lawgd_traj_sdmd[:n_show_traj, 0, -1], X_lawgd_traj_sdmd[:n_show_traj, 1, -1],
           s=50, facecolors='none', edgecolors='magenta',
           linewidth=1.5, zorder=15, label='End')

for idx, center in enumerate(well_centers):
    circle = Circle(center, well_radius, fill=False, edgecolor='green',
                   linewidth=2, linestyle='--')
    ax.add_patch(circle)

ax.scatter(well_centers[:, 0], well_centers[:, 1], s=100, c='green',
           marker='*', zorder=10)
ax.set_xlabel('Latent dim 1', fontsize=12)
ax.set_ylabel('Latent dim 2', fontsize=12)
ax.set_title(f'KSWGD Trajectories (first {n_show_traj} particles)',
             fontsize=14, fontweight='bold')
ax.set_xlim(latent_bounds[0])
ax.set_ylim(latent_bounds[1])
ax.set_aspect('equal')
ax.legend(loc='upper right', fontsize=10)
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## Convergence Analysis: Koopman(SDMD) with Metrics Tracking (IID Data)

Run Koopman(SDMD) for 1000 steps and track convergence metrics every 100 steps

In [None]:
# ============================================================
# Koopman(SDMD) with Metrics Tracking (0-1000 steps, every 100 steps)
# ============================================================

from matplotlib.patches import Circle

print("="*60)
print("Koopman(SDMD): Running 1000 iterations with metrics tracking...")
print("="*60)

# Step 1: Prepare SDMD Koopman spectrum (reuse from previous cell)
eigvals_K_sdmd_analysis = eigvals_K_sdmd.copy()
eigvecs_K_sdmd_analysis = eigvecs_K_sdmd.copy()
lambda_gen_sdmd_analysis = lambda_gen_full_sdmd.copy()

# Mode selection (same as before)
n_skip = 1
eig_threshold = 0.01
eigvals_after_skip = eigvals_K_sdmd_analysis[n_skip:]
valid_mask = eigvals_after_skip.real > eig_threshold
n_valid = np.sum(valid_mask)

k_modes_sdmd = n_valid
mode_start_sdmd = n_skip
mode_end_sdmd = mode_start_sdmd + k_modes_sdmd

print(f"SDMD modes: {k_modes_sdmd} (from {mode_start_sdmd+1} to {mode_end_sdmd})")

# Build inverse generator weights
lambda_gen_selected_sdmd = lambda_gen_sdmd_analysis[mode_start_sdmd:mode_end_sdmd]
tol_gen_sdmd = 1e-6
lambda_gen_inv_selected_sdmd = np.zeros(k_modes_sdmd, dtype=complex)
mask_nonzero_sdmd = np.abs(lambda_gen_selected_sdmd) > tol_gen_sdmd
lambda_gen_inv_selected_sdmd[mask_nonzero_sdmd] = 1.0 / lambda_gen_selected_sdmd[mask_nonzero_sdmd]

# Step 2: Setup parameters
n_particles = X_lawgd_init.shape[0]
n_iter_lawgd = 1000
h_lawgd = 1  # Step size
record_interval = 100  # Record metrics every 100 steps
n_records = n_iter_lawgd // record_interval + 1  # Include step 0

# Initialize trajectory storage
X_lawgd_traj_sdmd = np.zeros((n_particles, 2, n_iter_lawgd + 1))
X_lawgd_traj_sdmd[:, :, 0] = X_lawgd_init.copy()

# Initialize metrics storage
metrics_sdmd = {
    'steps': [],
    'well_coverage': [],  # Percentage of particles in wells
    'avg_potential': [],  # Average potential energy
    'movement_rate': []   # Average displacement per step
}

print(f"Particles: {n_particles}")
print(f"Iterations: {n_iter_lawgd}")
print(f"Recording interval: {record_interval} steps")
print(f"Total records: {n_records}")

# Record initial state (step 0)
coverage_0, avg_V_0, _ = compute_metrics(X_lawgd_traj_sdmd[:, :, 0])
metrics_sdmd['steps'].append(0)
metrics_sdmd['well_coverage'].append(coverage_0)
metrics_sdmd['avg_potential'].append(avg_V_0)
metrics_sdmd['movement_rate'].append(0.0)

print(f"\nInitial state (step 0):")
print(f"  Well coverage: {coverage_0:.2f}%")
print(f"  Avg potential: {avg_V_0:.4f}")

# Step 3: Run LAWGD iterations with SDMD spectrum
print(f"\n{'='*60}")
print("Running Koopman(SDMD) iterations...")
print(f"{'='*60}")

for t in range(n_iter_lawgd):
    X_curr = X_lawgd_traj_sdmd[:, :, t]
    
    # Evaluate kernel between current particles and training data (DM kernel)
    K_curr = evaluate_kernel_at_points(X_curr, X_tar)  # (n_particles, n_data)
    
    # Project onto SDMD eigenmodes: c = Φ_SDMD^T @ K(x, X_tar)
    eigvecs_selected_sdmd = eigvecs_K_sdmd_analysis[:, mode_start_sdmd:mode_end_sdmd]  # (n_data, k_modes)
    c = eigvecs_selected_sdmd.T @ K_curr.T  # (k_modes, n_particles)
    
    # Apply inverse generator: c_inv = Λ_gen^{-1} @ c
    c_inv = lambda_gen_inv_selected_sdmd[:, None] * c  # (k_modes, n_particles)
    
    # Project back to data space: f = Φ_SDMD @ c_inv
    f_inv = eigvecs_selected_sdmd @ c_inv  # (n_data, n_particles)
    
    # Compute gradient using DM kernel gradient
    grad_K = compute_kernel_gradient(X_curr, X_tar)  # (n_particles, n_data, 2)
    
    # Gradient update: ∇_x f = Σ_i f_inv[i] * ∇_x k(x, x_i)
    grad_update = np.zeros((n_particles, 2))
    for d_idx in range(2):
        grad_update[:, d_idx] = np.sum(
            grad_K[:, :, d_idx] * f_inv.T.real,  # Use real part
            axis=1
        )
    
    # Gradient descent step
    X_lawgd_traj_sdmd[:, :, t+1] = X_curr - h_lawgd * grad_update
    
    # Record metrics every record_interval steps
    if (t + 1) % record_interval == 0:
        X_prev = X_lawgd_traj_sdmd[:, :, t]
        X_next = X_lawgd_traj_sdmd[:, :, t+1]
        coverage, avg_V, movement = compute_metrics(X_next, X_prev)
        
        metrics_sdmd['steps'].append(t + 1)
        metrics_sdmd['well_coverage'].append(coverage)
        metrics_sdmd['avg_potential'].append(avg_V)
        metrics_sdmd['movement_rate'].append(movement)
    
    # Progress indicator
    if (t+1) % 200 == 0 or t == 0:
        print(f"\r  [SDMD] Iteration {t+1}/{n_iter_lawgd}  ", end='', flush=True)

print()  # Newline after loop
print("Koopman(SDMD) iteration complete!")

# Step 4: Final summary
final_coverage = metrics_sdmd['well_coverage'][-1]
final_potential = metrics_sdmd['avg_potential'][-1]
final_movement = metrics_sdmd['movement_rate'][-1]

print(f"\n{'='*60}")
print("Koopman(SDMD) Final Results:")
print(f"{'='*60}")
print(f"Final well coverage: {final_coverage:.2f}%")
print(f"Final avg potential: {final_potential:.4f}")
print(f"Final movement rate: {final_movement:.6f}")
print(f"Total records: {len(metrics_sdmd['steps'])}")
print(f"{'='*60}")

# Convert metrics to numpy arrays for easier plotting
for key in ['steps', 'well_coverage', 'avg_potential', 'movement_rate']:
    metrics_sdmd[key] = np.array(metrics_sdmd[key])

print("\n✓ Koopman(SDMD) metrics saved in 'metrics_sdmd' dictionary")

## Visual Inspection of Generated Digits

Instead of potential-well coverage curves, we decode the final LATENT particles from DMPS and Koopman(SDMD) back to 28×28 pixel space and check whether the synthesized digits resemble the MNIST targets. Each column shows a generated digit (top) and its nearest neighbor among the latent training samples (bottom).

In [None]:
# ============================================================
# Generated Digit Grid: DMPS final particles
# ============================================================

def _decode_latents(latent_batch):
    if latent_batch.ndim == 1:
        latent_batch = latent_batch[None, :]
    return mnist_pipeline.inverse_transform(latent_batch)

def _nearest_latent_neighbors(latent_batch, reference_latent):
    diff = latent_batch[:, None, :] - reference_latent[None, :, :]
    dist = np.linalg.norm(diff, axis=2)
    nearest_idx = np.argmin(dist, axis=1)
    return reference_latent[nearest_idx]

def plot_generated_vs_reference(latent_batch, title, reference_latent=X_tar, n_cols=8):
    if latent_batch.ndim == 3:
        latent_final = latent_batch[:, :, -1]
    else:
        latent_final = latent_batch

    n_cols = min(n_cols, latent_final.shape[0])
    if n_cols == 0:
        raise ValueError("No latent samples available for visualization.")

    sample_idx = np.random.choice(latent_final.shape[0], size=n_cols, replace=False)
    latent_samples = latent_final[sample_idx]

    generated_imgs = _decode_latents(latent_samples)
    reference_latent_batch = _nearest_latent_neighbors(latent_samples, reference_latent)
    reference_imgs = _decode_latents(reference_latent_batch)

    fig, axes = plt.subplots(2, n_cols, figsize=(1.6 * n_cols, 3.6))
    for col in range(n_cols):
        axes[0, col].imshow(generated_imgs[col], cmap='gray')
        axes[0, col].set_title(f'gen #{col+1}', fontsize=9)
        axes[0, col].axis('off')
        axes[1, col].imshow(reference_imgs[col], cmap='gray')
        axes[1, col].set_title('ref', fontsize=8, pad=2)
        axes[1, col].axis('off')

    fig.suptitle(title + "\nTop: generated | Bottom: nearest training digit", fontsize=14)
    plt.tight_layout()
    plt.show()

plot_generated_vs_reference(X_lawgd_traj_dm[:, :, -1], "DMPS final decoded digits")

In [None]:
# ============================================================
# Generated Digit Grid: Koopman(SDMD) final particles
# ============================================================

plot_generated_vs_reference(
    X_lawgd_traj_sdmd[:, :, -1],
    "Koopman(SDMD) final decoded digits"
)

In [None]:
# ============================================================
# Pixel-space quality diagnostics
# ============================================================

def compute_pixel_stats(latent_batch, label, reference_latent=X_tar, n_eval=128):
    if latent_batch.ndim == 3:
        latent_final = latent_batch[:, :, -1]
    else:
        latent_final = latent_batch

    n_eval = min(n_eval, latent_final.shape[0])
    idx = np.random.choice(latent_final.shape[0], size=n_eval, replace=False)
    samples = latent_final[idx]

    generated_imgs = _decode_latents(samples)
    reference_latent_batch = _nearest_latent_neighbors(samples, reference_latent)
    reference_imgs = _decode_latents(reference_latent_batch)

    mse = np.mean((generated_imgs - reference_imgs) ** 2)
    mae = np.mean(np.abs(generated_imgs - reference_imgs))

    return {
        'method': label,
        'n_eval': n_eval,
        'mse': mse,
        'mae': mae
    }

stats_dm = compute_pixel_stats(X_lawgd_traj_dm, 'DMPS')
stats_sdmd = compute_pixel_stats(X_lawgd_traj_sdmd, 'Koopman(SDMD)')

print("\n" + "="*70)
print("PIXEL-SPACE QUALITY SUMMARY (lower is better)")
print("="*70)
for stats in [stats_dm, stats_sdmd]:
    print(f"{stats['method']}: n_eval={stats['n_eval']}")
    print(f"  • Mean squared error:     {stats['mse']:.6f}")
    print(f"  • Mean absolute error:   {stats['mae']:.6f}")
print("="*70)