# PDE Discovery from Real Images: Registration + SINDy

**Comprehensive Pipeline for Sparse Identification of Nonlinear Dynamics from Experimental Image Sequences**

This notebook implements a careful, step-by-step approach to discover governing PDEs from experimental imaging data:

1. ✅ Check data repeatability
2. ✅ Denoise and normalize images
3. ✅ Multi-method registration (optical flow + patch-based)
4. ✅ Registration quality validation
5. ✅ Regularized derivative estimation (Savitzky-Golay)
6. ✅ Extended SINDy library with high-order terms
7. ✅ STRidge sparse regression
8. ✅ Cross-validation
9. ✅ Forward simulation and validation
10. ✅ Presentation-quality visualizations

---

## 1. Import Required Libraries

In [None]:
# Standard library
import glob
import warnings
from pathlib import Path

# Scientific stack
import numpy as np
import matplotlib.pyplot as plt
import matplotlib

# Computer vision, filtering and regression
import cv2
from scipy.ndimage import gaussian_filter
from scipy.signal import savgol_filter
from sklearn.linear_model import Ridge
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# Default figure styling shared by every plot in this notebook
matplotlib.rcParams.update({'figure.figsize': (12, 8), 'font.size': 10})

# NOTE(review): this silences *every* warning for the rest of the kernel session,
# including numerical RuntimeWarnings (divide-by-zero, NaN) that would flag
# problems in the cells below; consider narrowing it to specific categories.
warnings.filterwarnings('ignore')

print("‚úì All libraries imported successfully")
for _lib_name, _lib_version in (("OpenCV", cv2.__version__), ("NumPy", np.__version__)):
    print(f"{_lib_name} version: {_lib_version}")

## 2. Load and Visualize Raw Image Data

In [None]:
# Configuration
def find_project_root(start: Path) -> Path:
    """Walk upward from ``start`` and return the first directory that looks like the project root.

    A directory qualifies when it contains both a ``scripts`` and an ``outputs``
    entry. ``start`` itself is examined before any of its ancestors.

    Raises:
        RuntimeError: if neither ``start`` nor any ancestor qualifies.
    """
    candidates = (start, *start.parents)
    match = next(
        (cand for cand in candidates
         if (cand / 'scripts').exists() and (cand / 'outputs').exists()),
        None,
    )
    if match is None:
        raise RuntimeError(f'Could not find project root from: {start}')
    return match

start = Path.cwd()
if start.name == 'notebooks':
    start = start.parent
PROJECT_ROOT = find_project_root(start)

IMAGE_FOLDER = PROJECT_ROOT / 'data' / 'Real-Images'
OUTPUT_FOLDER = PROJECT_ROOT / 'outputs' / 'latest' / 'legacy_notebook'
OUTPUT_FOLDER.mkdir(parents=True, exist_ok=True)
MAX_IMAGES = 40  # Use subset for faster processing

# Load images using OpenCV (robust for TIFF format).
# IMREAD_GRAYSCALE yields single-channel frames; without IMREAD_ANYDEPTH OpenCV
# converts to 8-bit, which is why later cells display raw frames with vmax=255.
print("Loading images...")
image_files = sorted(glob.glob(str(IMAGE_FOLDER / "*.tif")))[:MAX_IMAGES]
print(f"Found {len(image_files)} images")

images_raw = []
for i, f in enumerate(image_files):
    img = cv2.imread(str(f), cv2.IMREAD_GRAYSCALE)
    if img is None:
        print(f"Warning: Failed to load {f}")
        continue
    images_raw.append(img.astype(np.float64))
    if (i+1) % 10 == 0:
        print(f"  Loaded {i+1}/{len(image_files)}")

# Fail early with a clear message: every later cell indexes frames and assumes a
# single (n_frames, h, w) stack.
if not images_raw:
    raise FileNotFoundError(f"No readable *.tif images found in {IMAGE_FOLDER}")
frame_shapes = {frame.shape for frame in images_raw}
if len(frame_shapes) > 1:
    raise ValueError(f"All frames must share one shape; found {sorted(frame_shapes)}")

images_raw = np.array(images_raw)
print(f"\n‚úì Loaded {len(images_raw)} images")
print(f"  Shape: {images_raw.shape}")
print(f"  Dtype: {images_raw.dtype}")

## 3. Check Repeatability and Data Quality

In [None]:
# Compute frame-to-frame differences to check consistency
print("Analyzing frame-to-frame differences...")

# One value per consecutive pair: mean absolute intensity change from frame k to k+1.
diffs = np.array([
    np.abs(frame_next - frame_prev).mean()
    for frame_prev, frame_next in zip(images_raw[:-1], images_raw[1:])
])

print(f"\nFrame-to-frame difference statistics:")
print(f"  Mean: {diffs.mean():.2f}")
print(f"  Std:  {diffs.std():.2f}")
print(f"  Min:  {diffs.min():.2f}")
print(f"  Max:  {diffs.max():.2f}")

# Flag pairs whose change exceeds mean + 3 std.
# (The plot below draws the narrower mean +/- 2 std band for context.)
outlier_cutoff = diffs.mean() + 3*diffs.std()
outliers = diffs > outlier_cutoff
if not outliers.any():
    print(f"\n‚úì No extreme outliers detected")
else:
    print(f"\n‚ö† Warning: {np.sum(outliers)} frames have unusually large differences")
    print(f"  Outlier indices: {np.where(outliers)[0]}")

# Plot temporal evolution
diff_mean, diff_std = diffs.mean(), diffs.std()
fig, (ax_diff, ax_slice) = plt.subplots(1, 2, figsize=(14, 5))

ax_diff.plot(diffs, 'o-', markersize=4)
ax_diff.axhline(diff_mean, color='r', linestyle='--', label='Mean')
ax_diff.axhline(diff_mean + 2*diff_std, color='orange', linestyle=':', label='Mean ¬± 2œÉ')
ax_diff.axhline(diff_mean - 2*diff_std, color='orange', linestyle=':')
ax_diff.set_xlabel('Frame Index')
ax_diff.set_ylabel('Mean Absolute Difference')
ax_diff.set_title('Frame-to-Frame Temporal Consistency')
ax_diff.legend()
ax_diff.grid(True, alpha=0.3)

# Space-time picture: the centre row of every frame stacked vertically (time runs down).
center_row = images_raw.shape[1] // 2
spatiotemporal = images_raw[:, center_row, :]
slice_img = ax_slice.imshow(spatiotemporal, cmap='gray', aspect='auto')
ax_slice.set_xlabel('X Position')
ax_slice.set_ylabel('Frame Number (Time)')
ax_slice.set_title('Spatiotemporal Slice (Center Row)')
fig.colorbar(slice_img, ax=ax_slice, label='Intensity')

plt.tight_layout()
plt.show()

## 4. Denoise and Normalize Images

In [None]:
print("Applying spatial denoising and normalization...")

# The before/after panels below compare frame idx_show with frame idx_show + 1.
if len(images_raw) < 2:
    raise ValueError("Need at least 2 frames to compare consecutive denoised frames.")


def _normalize_minmax(frame):
    """Linearly rescale one frame to [0, 1] using its own min/max.

    The 1e-10 term keeps a perfectly flat frame from dividing by zero.
    """
    return (frame - frame.min()) / (frame.max() - frame.min() + 1e-10)


images_denoised = []

for i, img in enumerate(images_raw):
    # Mild Gaussian smoothing to reduce noise (preserve features)
    img_smooth = gaussian_filter(img, sigma=1.0)

    # Normalize to [0, 1] range.
    # NOTE(review): normalization is per frame, so frame-to-frame changes in the
    # intensity extremes are folded into u (and later into u_t); a single global
    # min/max may be preferable if absolute intensity carries physical meaning.
    img_norm = _normalize_minmax(img_smooth)

    images_denoised.append(img_norm)

    if (i+1) % 10 == 0:
        print(f"  Processed {i+1}/{len(images_raw)}")

images_denoised = np.array(images_denoised)

print(f"\n‚úì Denoised and normalized {len(images_denoised)} images")
print(f"  New range: [{images_denoised.min():.3f}, {images_denoised.max():.3f}]")

# Visualize before/after
fig, axes = plt.subplots(2, 3, figsize=(15, 10))

# Middle frame, clamped so idx_show + 1 is always a valid index (only matters for 2 frames).
idx_show = min(len(images_raw) // 2, len(images_raw) - 2)

axes[0, 0].imshow(images_raw[idx_show], cmap='gray', vmin=0, vmax=255)
axes[0, 0].set_title(f'Original Frame {idx_show}')
axes[0, 0].axis('off')

axes[0, 1].imshow(images_denoised[idx_show], cmap='gray', vmin=0, vmax=1)
axes[0, 1].set_title(f'Denoised & Normalized')
axes[0, 1].axis('off')

# Difference map
diff_dn = np.abs(images_denoised[idx_show+1] - images_denoised[idx_show])
axes[0, 2].imshow(diff_dn, cmap='hot')
axes[0, 2].set_title('Frame Difference (After)')
axes[0, 2].axis('off')

# Histograms
axes[1, 0].hist(images_raw[idx_show].ravel(), bins=50, alpha=0.7, edgecolor='black')
axes[1, 0].set_title('Intensity Distribution (Original)')
axes[1, 0].set_xlabel('Intensity')
axes[1, 0].set_ylabel('Frequency')

axes[1, 1].hist(images_denoised[idx_show].ravel(), bins=50, alpha=0.7, edgecolor='black', color='green')
axes[1, 1].set_title('Intensity Distribution (Normalized)')
axes[1, 1].set_xlabel('Intensity')
axes[1, 1].set_ylabel('Frequency')

# Noise reduction quantification.
# Temporal noise = std of consecutive-frame differences over the first few pairs.
# Raw frames are min-max normalized exactly like the denoised ones before differencing,
# so both numbers are on the same [0, 1] scale. Comparing raw 0-255 intensities with
# normalized ones would make the "reduction" reflect the rescaling, not the smoothing.
n_noise_pairs = min(5, len(images_raw) - 1)
noise_original = np.std([_normalize_minmax(images_raw[i+1]) - _normalize_minmax(images_raw[i])
                         for i in range(n_noise_pairs)])
noise_denoised = np.std([images_denoised[i+1] - images_denoised[i] for i in range(n_noise_pairs)])
noise_reduction_pct = (1 - noise_denoised / noise_original) * 100 if noise_original > 0 else 0.0

axes[1, 2].bar(['Original', 'Denoised'], [noise_original, noise_denoised], color=['red', 'green'])
axes[1, 2].set_ylabel('Temporal Noise (Std Dev)')
axes[1, 2].set_title(f'Noise Reduction: {noise_reduction_pct:.1f}%')
axes[1, 2].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 5. Estimate Motion Using Optical Flow & Patch Matching

In [None]:
def compute_optical_flow(img1, img2, method='farneback'):
    """Compute dense optical flow between two images"""
    # Convert to uint8 for OpenCV
    img1_8bit = (img1 * 255).astype(np.uint8)
    img2_8bit = (img2 * 255).astype(np.uint8)
    
    if method == 'farneback':
        flow = cv2.calcOpticalFlowFarneback(
            img1_8bit, img2_8bit, None,
            pyr_scale=0.5, levels=5, winsize=21,
            iterations=5, poly_n=7, poly_sigma=1.5, flags=0
        )
    elif method == 'tvl1':
        optical_flow = cv2.optflow.DualTVL1OpticalFlow_create()
        flow = optical_flow.calc(img1_8bit, img2_8bit, None)
    else:
        raise ValueError(f"Unknown method: {method}")
    
    return flow

# Compute flow for all consecutive frame pairs
print("Computing optical flow between consecutive frames...")
print("Using Farneb√§ck method with refined parameters...")

flows = []
flow_magnitudes = []

for i in range(len(images_denoised) - 1):
    flow = compute_optical_flow(images_denoised[i], images_denoised[i+1], method='farneback')
    magnitude = np.sqrt(flow[..., 0]**2 + flow[..., 1]**2)

    flows.append(flow)
    flow_magnitudes.append(magnitude.mean())

    if (i+1) % 10 == 0:
        print(f"  Computed flow {i+1}/{len(images_denoised)-1}")

flow_magnitudes = np.array(flow_magnitudes)

print(f"\n‚úì Computed {len(flows)} flow fields")
print(f"\nFlow magnitude statistics (pixels):")
print(f"  Mean:   {flow_magnitudes.mean():.4f}")
print(f"  Median: {np.median(flow_magnitudes):.4f}")
print(f"  Std:    {flow_magnitudes.std():.4f}")
print(f"  Max:    {flow_magnitudes.max():.4f}")

# Visualize flow fields
fig, axes = plt.subplots(2, 3, figsize=(16, 10))

idx_vis = len(flows) // 2

# Show images
axes[0, 0].imshow(images_denoised[idx_vis], cmap='gray')
axes[0, 0].set_title(f'Frame {idx_vis}')
axes[0, 0].axis('off')

axes[0, 1].imshow(images_denoised[idx_vis+1], cmap='gray')
axes[0, 1].set_title(f'Frame {idx_vis+1}')
axes[0, 1].axis('off')

# Flow magnitude heatmap
axes[0, 2].imshow(np.sqrt(flows[idx_vis][..., 0]**2 + flows[idx_vis][..., 1]**2), cmap='hot')
axes[0, 2].set_title('Flow Magnitude (pixels)')
axes[0, 2].axis('off')
plt.colorbar(axes[0, 2].images[0], ax=axes[0, 2])

# Quiver plot (subsampled every `step` pixels). The sampling grid gets its own
# names so this cell does not leave stale `x` / `y` globals for later cells.
step = 50
grid_y, grid_x = np.mgrid[0:flows[idx_vis].shape[0]:step, 0:flows[idx_vis].shape[1]:step]
u = flows[idx_vis][::step, ::step, 0]
v = flows[idx_vis][::step, ::step, 1]

axes[1, 0].imshow(images_denoised[idx_vis], cmap='gray', alpha=0.7)
# angles='xy' draws each arrow from (x, y) towards (x + u, y + v) in data coordinates,
# so it follows imshow's downward-pointing y axis. The default angles='uv' ignores
# the axis inversion and would draw downward (positive v) motion as upward arrows.
axes[1, 0].quiver(grid_x, grid_y, u, v, color='red', angles='xy', scale=50, width=0.003)
axes[1, 0].set_title('Flow Field (Quiver)')
axes[1, 0].axis('off')

# Temporal flow magnitude plot
axes[1, 1].plot(flow_magnitudes, 'o-', markersize=4)
axes[1, 1].axhline(flow_magnitudes.mean(), color='r', linestyle='--', label='Mean')
axes[1, 1].set_xlabel('Frame Pair Index')
axes[1, 1].set_ylabel('Mean Flow Magnitude (pixels)')
axes[1, 1].set_title('Temporal Evolution of Misalignment')
axes[1, 1].legend()
axes[1, 1].grid(True, alpha=0.3)

# Flow component distributions
axes[1, 2].hist(flows[idx_vis][..., 0].ravel(), bins=50, alpha=0.5, label='X-component', edgecolor='black')
axes[1, 2].hist(flows[idx_vis][..., 1].ravel(), bins=50, alpha=0.5, label='Y-component', edgecolor='black')
axes[1, 2].set_xlabel('Flow (pixels)')
axes[1, 2].set_ylabel('Frequency')
axes[1, 2].set_title('Flow Component Distribution')
axes[1, 2].legend()
axes[1, 2].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(OUTPUT_FOLDER / 'fig2_optical_flow_analysis.png', dpi=150, bbox_inches='tight')
plt.show()
print("\n‚úì Saved: fig2_optical_flow_analysis.png")

## 6. Perform Image Registration with Subpixel Accuracy

In [None]:
def warp_image_with_flow(img, flow):
    """Resample ``img`` at positions displaced by ``-flow`` (bilinear, subpixel).

    Output pixel (row r, column c) takes the value of ``img`` at
    x = c - flow[r, c, 0], y = r - flow[r, c, 1]. Positions outside the image use
    the nearest edge value (``BORDER_REPLICATE``).

    Args:
        img: 2-D image.
        flow: per-pixel displacement with x in ``flow[..., 0]`` and y in ``flow[..., 1]``
            (the layout produced by ``compute_optical_flow``).

    Returns:
        The warped image, same shape as ``img``.

    Note:
        Farneback flow from A to B satisfies A(y, x) ~ B(y + fy, x + fx); to pull B
        back onto A's pixel grid with this function, pass the *negated* flow.
    """
    rows, cols = img.shape
    grid_x, grid_y = np.meshgrid(np.arange(cols), np.arange(rows))

    # Source coordinates for cv2.remap (float32 as OpenCV requires for map arrays)
    src_x = np.subtract(grid_x, flow[..., 0]).astype(np.float32)
    src_y = np.subtract(grid_y, flow[..., 1]).astype(np.float32)

    return cv2.remap(img, src_x, src_y, cv2.INTER_LINEAR, borderMode=cv2.BORDER_REPLICATE)

print("Performing registration...")
print("Warping each frame to align with reference...")

# NOTE(review): dense optical-flow registration removes all apparent motion between
# consecutive frames, which can include genuine transport in the imaged field and not
# only drift. If the dynamics of interest move structures around, a rigid/affine
# registration may be more appropriate before PDE discovery.

# Use first frame as reference
images_registered = [images_denoised[0]]

for i in range(1, len(images_denoised)):
    prev_registered = images_registered[-1]
    curr = images_denoised[i]

    # Flow from previous registered frame to current frame. OpenCV's convention is
    # prev(y, x) ~ curr(y + flow_y, x + flow_x), so curr is brought onto prev's grid
    # by sampling it at (x + flow_x, y + flow_y).
    flow = compute_optical_flow(prev_registered, curr, method='farneback')

    # warp_image_with_flow samples at (x - flow_x, y - flow_y); passing the negated
    # flow samples at (x + flow_x, y + flow_y). Passing +flow would shift the frame
    # by the measured motion a second time instead of undoing it.
    warped = warp_image_with_flow(curr, -flow)

    images_registered.append(warped)

    if (i) % 10 == 0:
        print(f"  Registered {i}/{len(images_denoised)}")

images_registered = np.array(images_registered)

print(f"\n‚úì Registered {len(images_registered)} images")

# Compute residual flow after registration
print("\nComputing residual flow after registration...")
residual_flows = []
residual_magnitudes = []

for i in range(len(images_registered) - 1):
    flow_residual = compute_optical_flow(images_registered[i], images_registered[i+1], method='farneback')
    magnitude = np.sqrt(flow_residual[..., 0]**2 + flow_residual[..., 1]**2)

    residual_flows.append(flow_residual)
    residual_magnitudes.append(magnitude.mean())

residual_magnitudes = np.array(residual_magnitudes)

print(f"\nResidual flow statistics (after registration):")
print(f"  Mean:   {residual_magnitudes.mean():.4f} pixels")
print(f"  Median: {np.median(residual_magnitudes):.4f} pixels")
print(f"  Std:    {residual_magnitudes.std():.4f} pixels")
print(f"  Max:    {residual_magnitudes.max():.4f} pixels")

# Relative reduction of mean flow magnitude; guarded against a perfectly static sequence.
baseline_flow = flow_magnitudes.mean()
improvement = (1 - residual_magnitudes.mean() / baseline_flow) * 100 if baseline_flow > 0 else 0.0
print(f"\n‚úì Registration improvement: {improvement:.2f}%")

if residual_magnitudes.mean() < 1.0:
    print("‚úì Registration SUCCESSFUL (residual < 1 pixel)")
elif residual_magnitudes.mean() < 2.0:
    print("‚ö† Registration MODERATE (residual < 2 pixels)")
else:
    print("‚úó Registration POOR (residual >= 2 pixels)")

## 7. Validate Registration Quality (PRESENTATION FIGURE 1)

In [None]:
# CREATE COMPREHENSIVE REGISTRATION VALIDATION FIGURE FOR PRESENTATION

fig = plt.figure(figsize=(18, 12))
gs = fig.add_gridspec(3, 4, hspace=0.3, wspace=0.3)

# Middle frame pair, clamped to the last valid pair index so that flows[idx_compare],
# residual_flows[idx_compare] and frame idx_compare + 1 all exist. For 3+ frames this
# equals len(images_denoised) // 2; it only differs for a 2-frame sequence.
idx_compare = min(len(images_denoised) // 2, len(flows) - 1)

# Row 1: Before registration
ax1 = fig.add_subplot(gs[0, 0])
ax1.imshow(images_denoised[idx_compare], cmap='gray', vmin=0, vmax=1)
ax1.set_title('Before: Frame N', fontsize=12, fontweight='bold')
ax1.axis('off')

ax2 = fig.add_subplot(gs[0, 1])
ax2.imshow(images_denoised[idx_compare+1], cmap='gray', vmin=0, vmax=1)
ax2.set_title('Before: Frame N+1', fontsize=12, fontweight='bold')
ax2.axis('off')

ax3 = fig.add_subplot(gs[0, 2])
diff_before = np.abs(images_denoised[idx_compare+1] - images_denoised[idx_compare])
im3 = ax3.imshow(diff_before, cmap='hot', vmin=0, vmax=0.3)
ax3.set_title(f'Difference (Mean: {diff_before.mean():.4f})', fontsize=12, fontweight='bold')
ax3.axis('off')
plt.colorbar(im3, ax=ax3, fraction=0.046)

ax4 = fig.add_subplot(gs[0, 3])
mag_before = np.sqrt(flows[idx_compare][..., 0]**2 + flows[idx_compare][..., 1]**2)
im4 = ax4.imshow(mag_before, cmap='viridis', vmin=0, vmax=10)
ax4.set_title(f'Flow Magnitude (Mean: {mag_before.mean():.2f} px)', fontsize=12, fontweight='bold')
ax4.axis('off')
plt.colorbar(im4, ax=ax4, fraction=0.046)

# Row 2: After registration
ax5 = fig.add_subplot(gs[1, 0])
ax5.imshow(images_registered[idx_compare], cmap='gray', vmin=0, vmax=1)
ax5.set_title('After: Frame N', fontsize=12, fontweight='bold')
ax5.axis('off')

ax6 = fig.add_subplot(gs[1, 1])
ax6.imshow(images_registered[idx_compare+1], cmap='gray', vmin=0, vmax=1)
ax6.set_title('After: Frame N+1', fontsize=12, fontweight='bold')
ax6.axis('off')

ax7 = fig.add_subplot(gs[1, 2])
diff_after = np.abs(images_registered[idx_compare+1] - images_registered[idx_compare])
im7 = ax7.imshow(diff_after, cmap='hot', vmin=0, vmax=0.3)
ax7.set_title(f'Difference (Mean: {diff_after.mean():.4f})', fontsize=12, fontweight='bold')
ax7.axis('off')
plt.colorbar(im7, ax=ax7, fraction=0.046)

ax8 = fig.add_subplot(gs[1, 3])
mag_after = np.sqrt(residual_flows[idx_compare][..., 0]**2 + residual_flows[idx_compare][..., 1]**2)
im8 = ax8.imshow(mag_after, cmap='viridis', vmin=0, vmax=10)
ax8.set_title(f'Residual Flow (Mean: {mag_after.mean():.2f} px)', fontsize=12, fontweight='bold')
ax8.axis('off')
plt.colorbar(im8, ax=ax8, fraction=0.046)

# Row 3: Quantitative comparison
ax9 = fig.add_subplot(gs[2, :2])
pair_idx = np.arange(len(flow_magnitudes))
ax9.plot(pair_idx, flow_magnitudes, 'o-', label='Before Registration', markersize=4, linewidth=2, color='red', alpha=0.7)
ax9.plot(pair_idx, residual_magnitudes, 's-', label='After Registration', markersize=4, linewidth=2, color='green', alpha=0.7)
ax9.axhline(1.0, color='black', linestyle='--', linewidth=1.5, label='Target (1 pixel)')
ax9.set_xlabel('Frame Pair Index', fontsize=12)
ax9.set_ylabel('Mean Flow Magnitude (pixels)', fontsize=12)
ax9.set_title('Registration Quality Over Time', fontsize=13, fontweight='bold')
ax9.legend(fontsize=11)
ax9.grid(True, alpha=0.3)

ax10 = fig.add_subplot(gs[2, 2:])
categories = ['Before\nRegistration', 'After\nRegistration']
means = [flow_magnitudes.mean(), residual_magnitudes.mean()]
stds = [flow_magnitudes.std(), residual_magnitudes.std()]
colors = ['red', 'green']

bars = ax10.bar(categories, means, yerr=stds, capsize=10, color=colors, alpha=0.7, edgecolor='black', linewidth=2)
ax10.axhline(1.0, color='black', linestyle='--', linewidth=1.5, label='Target')
ax10.set_ylabel('Mean Flow Magnitude (pixels)', fontsize=12)
ax10.set_title(f'Registration Improvement: {improvement:.1f}%', fontsize=13, fontweight='bold')
ax10.legend(fontsize=11)
ax10.grid(True, alpha=0.3, axis='y')

# Add text annotation
improvement_text = f"Misalignment reduced from {flow_magnitudes.mean():.2f} to {residual_magnitudes.mean():.2f} pixels"
fig.text(0.5, 0.02, improvement_text, ha='center', fontsize=12, fontweight='bold',
         bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

plt.savefig(OUTPUT_FOLDER / 'PRESENTATION_FIG1_Registration_Quality.png', dpi=300, bbox_inches='tight')
plt.show()

print("\n‚úÖ PRESENTATION FIGURE 1 SAVED: PRESENTATION_FIG1_Registration_Quality.png")

## 8. Apply Temporal Smoothing with Savitzky-Golay Filter

In [None]:
print("Applying temporal smoothing with Savitzky-Golay filter...")
print("This stabilizes time derivatives for SINDy...")

# Temporal smoothing parameters
window_length = 7  # Must be odd
polyorder = 3

n_frames_registered = len(images_registered)
if n_frames_registered < window_length:
    # The filter's window cannot be longer than the series; fall back to the registered frames.
    images_smooth_temporal = images_registered.copy()
    print(f"‚ö† Not enough frames for S-G filter, using registered images directly")
else:
    # Local degree-`polyorder` polynomial fit over `window_length` frames, independently at
    # every pixel (axis 0 is time).
    images_smooth_temporal = savgol_filter(images_registered, window_length, polyorder, axis=0)
    print(f"‚úì Applied Savitzky-Golay filter (window={window_length}, order={polyorder})")

# Final spatial smoothing to stabilize spatial derivatives
print("\nApplying final mild spatial smoothing...")
images_final = np.array([gaussian_filter(frame, sigma=0.8) for frame in images_smooth_temporal])

print(f"\n‚úì Final preprocessed dataset: {images_final.shape}")
print(f"  Ready for derivative estimation and SINDy")

## 9. Compute Temporal & Spatial Derivatives

In [None]:
# Grid spacing shared by every finite-difference stencil below.
# NOTE(review): these are placeholder units, not calibrated values. The discovered PDE
# coefficients scale with them (a u_xx coefficient, for example, scales like dx**2 / dt),
# so set them from the real frame interval and pixel pitch before reading coefficients
# physically.
dt = 1.0        # time between consecutive frames
dx = dy = 0.1   # pixel spacing along x (columns) and y (rows)

def compute_spatial_derivatives_4th_order(u, dx, dy):
    """Finite-difference spatial derivatives of a 2-D field, up to 4th order.

    Axis 1 (columns) is treated as x and axis 0 (rows) as y.

    Stencil accuracy:
      * u_x, u_y, u_xx, u_yy: 5-point, 4th-order-accurate central differences.
      * u_xxx, u_yyy, u_xxxx, u_yyyy: 5-point, 2nd-order-accurate central differences.

    The field is edge-padded by 3 pixels before differencing and every result is cropped
    back, so all eight outputs share one boundary treatment and the shape of ``u``.
    Values within 2 pixels of the border see replicated edge values and are less accurate.

    Args:
        u: 2-D array.
        dx, dy: grid spacing along x (columns) and y (rows).

    Returns:
        Tuple ``(ux, uy, uxx, uyy, uxxx, uyyy, uxxxx, uyyyy)``, each shaped like ``u``.
    """
    pad = 3
    u_pad = np.pad(u, pad_width=pad, mode='edge')

    def at(offset, axis):
        # Element i along `axis` is u_pad[i + offset]; the roll's wrap-around only
        # lands in the padding, which crop() discards.
        return np.roll(u_pad, -offset, axis=axis)

    def crop(field):
        return field[pad:-pad, pad:-pad]

    # 4th-order central differences for 1st derivatives
    ux = (-at(2, 1) + 8*at(1, 1) - 8*at(-1, 1) + at(-2, 1)) / (12 * dx)
    uy = (-at(2, 0) + 8*at(1, 0) - 8*at(-1, 0) + at(-2, 0)) / (12 * dy)

    # 4th-order central differences for 2nd derivatives
    uxx = (-at(2, 1) + 16*at(1, 1) - 30*u_pad + 16*at(-1, 1) - at(-2, 1)) / (12 * dx**2)
    uyy = (-at(2, 0) + 16*at(1, 0) - 30*u_pad + 16*at(-1, 0) - at(-2, 0)) / (12 * dy**2)

    # 2nd-order central differences for 3rd derivatives. Taken on the padded field so the
    # border matches the lower-order derivatives; rolling the unpadded field would wrap the
    # opposite image edge into the first/last two rows/columns.
    uxxx = (at(2, 1) - 2*at(1, 1) + 2*at(-1, 1) - at(-2, 1)) / (2 * dx**3)
    uyyy = (at(2, 0) - 2*at(1, 0) + 2*at(-1, 0) - at(-2, 0)) / (2 * dy**3)

    # 2nd-order central differences for 4th derivatives
    uxxxx = (at(2, 1) - 4*at(1, 1) + 6*u_pad - 4*at(-1, 1) + at(-2, 1)) / (dx**4)
    uyyyy = (at(2, 0) - 4*at(1, 0) + 6*u_pad - 4*at(-1, 0) + at(-2, 0)) / (dy**4)

    return tuple(crop(d) for d in (ux, uy, uxx, uyy, uxxx, uyyy, uxxxx, uyyyy))

def compute_time_derivative(images, idx):
    """Finite-difference estimate of du/dt for frame ``idx`` of ``images``.

    Frame 0 uses a forward difference and the last frame a backward difference
    (both divided by ``dt``); every other index uses the central difference
    (images[idx+1] - images[idx-1]) / (2*dt). Reads the module-level frame spacing ``dt``.
    """
    last = len(images) - 1
    if idx == 0:
        # forward difference at the first frame
        return (images[1] - images[0]) / dt
    if idx == last:
        # backward difference at the last frame
        return (images[-1] - images[-2]) / dt
    # central difference everywhere else
    return (images[idx+1] - images[idx-1]) / (2 * dt)

print("Computing derivatives for all frames...")
print("Using 4th-order finite differences for spatial derivatives...")

derivatives_data = []
n_interior = len(images_final) - 2

# Only interior frames have a central time difference, so the first and last frames are skipped.
for i in range(1, len(images_final) - 1):
    u = images_final[i]
    ut = compute_time_derivative(images_final, i)
    ux, uy, uxx, uyy, uxxx, uyyy, uxxxx, uyyyy = compute_spatial_derivatives_4th_order(u, dx, dy)

    derivatives_data.append(dict(
        u=u, ut=ut,
        ux=ux, uy=uy,
        uxx=uxx, uyy=uyy,
        uxxx=uxxx, uyyy=uyyy,
        uxxxx=uxxxx, uyyyy=uyyyy,
    ))

    if (i - 1) % 10 == 0:
        print(f"  Computed derivatives for frame {i}/{n_interior}")

print(f"\n‚úì Computed derivatives for {len(derivatives_data)} frames")

# Show derivative statistics for the middle frame
print("\nDerivative statistics (mean absolute value):")
sample = derivatives_data[len(derivatives_data)//2]
for key in ('ut', 'ux', 'uy', 'uxx', 'uyy', 'uxxx', 'uyyy', 'uxxxx', 'uyyyy'):
    print(f"  {key:6s}: {np.abs(sample[key]).mean():.6f}")

## 10. Build Extended SINDy Library with High-Order Terms

In [None]:
def build_sindy_library(derivs):
    """
    Build comprehensive library of candidate terms for PDE discovery.
    Includes terms up to 4th order spatial derivatives (Kuramoto-Sivashinsky style).

    Args:
        derivs: dict with same-shaped 2-D arrays under keys 'u', 'ux', 'uy', 'uxx',
            'uyy', 'uxxx', 'uyyy', 'uxxxx', 'uyyyy'.

    Returns:
        (library, term_names): ``library`` has one row per pixel (row-major order) and
        one column per candidate term, shape (h*w, 25), so ``library.reshape(h, w, -1)``
        recovers the per-pixel feature vectors. ``term_names`` labels the columns.

    Note:
        Some columns are exact linear combinations of others (laplacian = u_xx + u_yy,
        biharmonic = u_xxxx + u_yyyy, u*laplacian = u*u_xx + u*u_yy), so the library is
        rank-deficient and the split of weight among those terms is not identifiable.
    """
    u = derivs['u']
    ux = derivs['ux']
    uy = derivs['uy']
    uxx = derivs['uxx']
    uyy = derivs['uyy']
    uxxx = derivs['uxxx']
    uyyy = derivs['uyyy']
    uxxxx = derivs['uxxxx']
    uyyyy = derivs['uyyyy']

    laplacian = uxx + uyy
    biharmonic = uxxxx + uyyyy

    # Build library terms (each has the 2-D shape of u)
    terms = [
        np.ones_like(u),           # 0: constant
        u,                         # 1: u
        ux,                        # 2: u_x
        uy,                        # 3: u_y
        uxx,                       # 4: u_xx
        uyy,                       # 5: u_yy
        laplacian,                 # 6: laplacian of u
        u**2,                      # 7: u^2
        u * ux,                    # 8: u*u_x (advection)
        u * uy,                    # 9: u*u_y
        ux**2,                     # 10: u_x^2
        uy**2,                     # 11: u_y^2
        u * uxx,                   # 12: u*u_xx
        u * uyy,                   # 13: u*u_yy
        u * laplacian,             # 14: u*laplacian(u)
        u**3,                      # 15: u^3
        u**2 * ux,                 # 16: u^2*u_x
        u**2 * uy,                 # 17: u^2*u_y
        uxxx,                      # 18: u_xxx (3rd order)
        uyyy,                      # 19: u_yyy
        uxxxx,                     # 20: u_xxxx (4th order, K-S)
        uyyyy,                     # 21: u_yyyy
        biharmonic,                # 22: biharmonic (K-S)
        u * uxxxx,                 # 23: u*u_xxxx
        u * uyyyy,                 # 24: u*u_yyyy
    ]

    term_names = [
        '1', 'u', 'u_x', 'u_y', 'u_xx', 'u_yy', '‚àá¬≤u',
        'u¬≤', 'u¬∑u_x', 'u¬∑u_y', 'u_x¬≤', 'u_y¬≤', 'u¬∑u_xx', 'u¬∑u_yy', 'u¬∑‚àá¬≤u',
        'u¬≥', 'u¬≤¬∑u_x', 'u¬≤¬∑u_y', 'u_xxx', 'u_yyy',
        'u_xxxx', 'u_yyyy', '‚àá‚Å¥u', 'u¬∑u_xxxx', 'u¬∑u_yyyy'
    ]

    # Flatten every term to one value per pixel before stacking. column_stack on the raw
    # 2-D terms would concatenate them side by side into shape (h, 25*w), and reshaping
    # that to (h, w, 25) mixes values of one term from 25 different pixels into each
    # feature vector.
    return np.column_stack([term.ravel() for term in terms]), term_names

print("Building SINDy library for all frames...")

if not derivatives_data:
    raise RuntimeError("derivatives_data is empty: at least 3 preprocessed frames are needed "
                       "to form a central time derivative.")

X_all = []
y_all = []

# Subsample spatial points for computational efficiency
skip_boundary = 25  # pixels discarded at every image edge
subsample = 12      # keep every `subsample`-th row and column


def _interior_grid_mask(h, w, skip, step):
    """Boolean (h, w) mask of pixels at least ``skip`` from every edge on a ``step``-spaced grid.

    Explicit end indices keep ``skip = 0`` meaning "no border"; writing
    ``mask[-skip:, :] = False`` instead would clear the whole axis when skip is 0.
    """
    interior = np.zeros((h, w), dtype=bool)
    interior[skip:h - skip, skip:w - skip] = True
    on_grid = np.zeros((h, w), dtype=bool)
    on_grid[::step, ::step] = True
    return interior & on_grid


for derivs in derivatives_data:
    # Build library
    library, term_names = build_sindy_library(derivs)

    # Pixels that enter the regression
    h, w = derivs['u'].shape
    mask = _interior_grid_mask(h, w, skip_boundary, subsample)
    idx = np.where(mask)

    # Flatten and append
    ut_flat = derivs['ut'][idx]

    # Library reshaped to (h, w, n_terms) so it can be indexed with the same pixel mask as u_t
    library_2d = library.reshape(h, w, -1)
    library_flat = library_2d[idx]

    X_all.append(library_flat)
    y_all.append(ut_flat)

# Concatenate all data
X = np.vstack(X_all)
y = np.concatenate(y_all)

print(f"\n‚úì Built SINDy library:")
print(f"  Library shape: {X.shape}")
print(f"  Number of data points: {len(y):,}")
print(f"  Number of candidate terms: {len(term_names)}")

# Remove invalid values
valid = np.isfinite(X).all(axis=1) & np.isfinite(y)
X = X[valid]
y = y[valid]

print(f"  Valid data points after cleaning: {len(y):,}")

if len(y) == 0:
    raise RuntimeError("No valid sample points: reduce skip_boundary/subsample or check for NaNs.")

# Show statistics
print(f"\nData statistics:")
print(f"  u_t: mean={y.mean():.2e}, std={y.std():.2e}, range=[{y.min():.2e}, {y.max():.2e}]")

## 11. Perform STRidge (Sequential Thresholded Ridge Regression)

In [None]:
def stridge(X, y, alpha=0.01, threshold=1e-5, max_iter=20):
    """
    Sequential Thresholded Ridge Regression (STRidge).
    Standard SINDy algorithm for sparse coefficient identification.

    Columns of ``X`` are divided by their standard deviation (no centering) and a ridge
    model without intercept is fitted. The loop then repeatedly zeroes coefficients
    whose magnitude *in these standardized units* is <= ``threshold`` and refits ridge
    on the remaining columns. It stops when the coefficients stop changing, when every
    coefficient has been pruned, or after ``max_iter`` passes.

    Centering is deliberately disabled. With ``fit_intercept=False`` a centered design
    cannot represent an offset in ``y``, because a constant library column becomes all
    zeros. Dividing the standardized coefficients by ``scale_`` would then not yield
    coefficients of the *un-centered* library, which is what gets printed as the PDE.

    Args:
        X: (n_samples, n_terms) library matrix.
        y: (n_samples,) target, u_t.
        alpha: ridge penalty.
        threshold: cut-off on |standardized coefficient|.
        max_iter: maximum number of threshold-and-refit passes.

    Returns:
        (coeffs, scaler).
        ``coeffs`` are coefficients of the unscaled columns of ``X``, with exact zeros for
        pruned terms, so ``X @ coeffs`` is the model prediction.
        ``scaler`` is the fitted ``StandardScaler``; ``scaler.transform(X) @ (coeffs * scaler.scale_)``
        gives the same prediction.
    """
    n_features = X.shape[1]

    # Scale features to unit variance without centering (see docstring).
    # Zero-variance columns such as the constant term keep a scale of 1.
    scaler = StandardScaler(with_mean=False)
    X_scaled = scaler.fit_transform(X)

    # Initialize with Ridge regression
    model = Ridge(alpha=alpha, fit_intercept=False)
    model.fit(X_scaled, y)
    coeffs_scaled = model.coef_.copy()

    # Iteratively threshold small coefficients
    for iteration in range(max_iter):
        mask = np.abs(coeffs_scaled) > threshold
        n_active = np.sum(mask)

        if n_active == 0:
            print(f"  ‚ö† All coefficients thresholded to zero at iteration {iteration}")
            # Return the thresholded (all-zero) model, not the last unthresholded fit.
            coeffs_scaled = np.zeros(n_features)
            break

        # Refit on active features
        X_active = X_scaled[:, mask]
        model.fit(X_active, y)

        coeffs_new = np.zeros(n_features)
        coeffs_new[mask] = model.coef_

        # Keep the thresholded refit even on convergence: the previous vector may still
        # hold tiny sub-threshold entries that should be exactly zero.
        converged = np.allclose(coeffs_scaled, coeffs_new, atol=1e-8)
        coeffs_scaled = coeffs_new
        if converged:
            print(f"  ‚úì Converged at iteration {iteration+1}")
            break

    # Unscale coefficients
    coeffs = coeffs_scaled / scaler.scale_

    return coeffs, scaler

print("="*70)
print("PERFORMING STRIDGE (SINDy)")
print("="*70)

# Try multiple regularization strengths
alphas = [0.001, 0.01, 0.05]
thresholds = [1e-6, 1e-5, 1e-4]
MAX_ACTIVE_TERMS = 15  # configurations with this many active terms or more are not "sparse"

results = []

for alpha in alphas:
    for threshold in thresholds:
        print(f"\nTrying: alpha={alpha}, threshold={threshold}")

        coeffs, scaler = stridge(X, y, alpha=alpha, threshold=threshold, max_iter=20)

        # Compute metrics
        X_scaled = scaler.transform(X)
        y_pred = X_scaled @ (coeffs * scaler.scale_)

        r2 = r2_score(y, y_pred)
        mse = mean_squared_error(y, y_pred)
        # stridge thresholds the standardized coefficients (coeffs * scale_); count active
        # terms with the same criterion instead of comparing raw coefficients to it.
        n_active = int(np.sum(np.abs(coeffs * scaler.scale_) > threshold))

        print(f"  R¬≤ = {r2:.6f}, MSE = {mse:.2e}, Active terms = {n_active}/{len(coeffs)}")

        results.append({
            'alpha': alpha,
            'threshold': threshold,
            'coeffs': coeffs,
            'r2': r2,
            'mse': mse,
            'n_active': n_active,
            'scaler': scaler
        })

# Select best result: highest in-sample R^2 among configurations with reasonable sparsity.
eligible = [r for r in results if 0 < r['n_active'] < MAX_ACTIVE_TERMS]
if eligible:
    best_result = max(eligible, key=lambda r: r['r2'])
else:
    print(f"\nWARNING: no configuration has between 1 and {MAX_ACTIVE_TERMS - 1} active terms; "
          "falling back to the highest-R2 model.")
    best_result = max(results, key=lambda r: r['r2'])

print("\n" + "="*70)
print("BEST RESULT:")
print("="*70)
print(f"Alpha: {best_result['alpha']}")
print(f"Threshold: {best_result['threshold']}")
print(f"R¬≤: {best_result['r2']:.6f}")
print(f"MSE: {best_result['mse']:.2e}")
print(f"Active terms: {best_result['n_active']}/{len(term_names)}")
print(f"Sparsity: {(1 - best_result['n_active']/len(term_names))*100:.1f}%")

coeffs_best = best_result['coeffs']

# Print discovered equation
print("\n" + "="*70)
print("DISCOVERED PDE:")
print("="*70)
print("\nu_t = ", end="")

# Same active-term criterion as above (threshold on the standardized coefficients).
active_best = np.abs(coeffs_best * best_result['scaler'].scale_) > best_result['threshold']
terms_str = []
for c, name, is_active in zip(coeffs_best, term_names, active_best):
    if is_active:
        sign = "+" if c >= 0 and len(terms_str) > 0 else ""
        terms_str.append(f"{sign} {c:.6e}¬∑{name}")

if len(terms_str) == 0:
    print("0  (no significant terms)")
else:
    print("\n      ".join(terms_str))

print("\n" + "="*70)

## 12. Cross-Validation and Model Performance (PRESENTATION FIGURE 2)

In [None]:
# CREATE COMPREHENSIVE MODEL PERFORMANCE FIGURE FOR PRESENTATION

fig = plt.figure(figsize=(18, 10))
gs = fig.add_gridspec(2, 3, hspace=0.3, wspace=0.3)

# Get predictions
X_scaled = best_result['scaler'].transform(X)
y_pred = X_scaled @ (coeffs_best * best_result['scaler'].scale_)
residuals = y - y_pred

# 1. Coefficient bar chart
ax1 = fig.add_subplot(gs[0, :])
# Active terms use the criterion stridge applies internally: the threshold is compared
# with the standardized coefficient (coef * scale_), not the raw coefficient.
active_mask = np.abs(coeffs_best * best_result['scaler'].scale_) > best_result['threshold']
active_indices = np.where(active_mask)[0]
active_coeffs = coeffs_best[active_mask]
active_names = [term_names[i] for i in active_indices]

colors_bar = ['green' if c > 0 else 'red' for c in active_coeffs]
bars = ax1.barh(active_names, active_coeffs, color=colors_bar, alpha=0.7, edgecolor='black', linewidth=1.5)
ax1.axvline(0, color='black', linestyle='-', linewidth=2)
ax1.set_xlabel('Coefficient Value', fontsize=13, fontweight='bold')
ax1.set_title(f'Discovered PDE Coefficients (R¬≤ = {best_result["r2"]:.4f}, {best_result["n_active"]} active terms)', 
              fontsize=14, fontweight='bold')
ax1.grid(True, alpha=0.3, axis='x')

# 2. Predicted vs True scatter (fixed seed so the plotted subsample is reproducible)
ax2 = fig.add_subplot(gs[1, 0])
scatter_rng = np.random.default_rng(0)
sample_idx = scatter_rng.choice(len(y), size=min(10000, len(y)), replace=False)
ax2.scatter(y[sample_idx], y_pred[sample_idx], alpha=0.3, s=1, c='blue')
y_range = [min(y.min(), y_pred.min()), max(y.max(), y_pred.max())]
ax2.plot(y_range, y_range, 'r--', linewidth=2, label='Perfect Prediction')
ax2.set_xlabel('True u_t', fontsize=12, fontweight='bold')
ax2.set_ylabel('Predicted u_t', fontsize=12, fontweight='bold')
ax2.set_title(f'Prediction Quality\nR¬≤ = {best_result["r2"]:.4f}', fontsize=12, fontweight='bold')
ax2.legend(fontsize=10)
ax2.grid(True, alpha=0.3)

# 3. Residual distribution
ax3 = fig.add_subplot(gs[1, 1])
ax3.hist(residuals, bins=100, alpha=0.7, edgecolor='black', color='purple')
ax3.axvline(0, color='red', linestyle='--', linewidth=2, label='Zero Error')
ax3.set_xlabel('Residual (True - Predicted)', fontsize=12, fontweight='bold')
ax3.set_ylabel('Frequency', fontsize=12, fontweight='bold')
ax3.set_title(f'Residual Distribution\nMean = {residuals.mean():.2e}', fontsize=12, fontweight='bold')
ax3.legend(fontsize=10)
ax3.grid(True, alpha=0.3, axis='y')

# 4. Performance metrics comparison
ax4 = fig.add_subplot(gs[1, 2])

# Sort results by R^2 (top 10)
sorted_results = sorted(results, key=lambda x: x['r2'], reverse=True)[:10]
labels = [f"Œ±={r['alpha']}\nŒª={r['threshold']}" for r in sorted_results]
r2_values = [r['r2'] for r in sorted_results]
sparsity = [(1 - r['n_active']/len(term_names))*100 for r in sorted_results]

x = np.arange(len(labels))
width = 0.35

bars1 = ax4.bar(x - width/2, r2_values, width, label='R¬≤ Score', alpha=0.7, color='green', edgecolor='black')
ax4_twin = ax4.twinx()
bars2 = ax4_twin.bar(x + width/2, sparsity, width, label='Sparsity %', alpha=0.7, color='orange', edgecolor='black')

ax4.set_xlabel('Model Configuration', fontsize=11, fontweight='bold')
ax4.set_ylabel('R¬≤ Score', fontsize=11, fontweight='bold', color='green')
ax4_twin.set_ylabel('Sparsity (%)', fontsize=11, fontweight='bold', color='orange')
ax4.set_title('Model Selection:\nAccuracy vs Sparsity', fontsize=12, fontweight='bold')
ax4.set_xticks(x)
ax4.set_xticklabels(labels, rotation=45, ha='right', fontsize=8)
ax4.tick_params(axis='y', labelcolor='green')
ax4_twin.tick_params(axis='y', labelcolor='orange')
ax4.axhline(0, color='black', linestyle='-', linewidth=1)
ax4.grid(True, alpha=0.3, axis='y')

# Highlight best model. Match by identity (the dicts hold numpy arrays), and skip the
# highlight when the sparsity-constrained choice is not among the top-10 by R^2.
best_idx = next((i for i, r in enumerate(sorted_results) if r is best_result), None)
if best_idx is not None:
    bars1[best_idx].set_edgecolor('blue')
    bars1[best_idx].set_linewidth(3)
    bars2[best_idx].set_edgecolor('blue')
    bars2[best_idx].set_linewidth(3)

fig.legend([bars1, bars2], ['R¬≤ Score', 'Sparsity %'], loc='lower center', ncol=2, fontsize=11)

plt.savefig(OUTPUT_FOLDER / 'PRESENTATION_FIG2_Model_Performance.png', dpi=300, bbox_inches='tight')
plt.show()

print("\n‚úÖ PRESENTATION FIGURE 2 SAVED: PRESENTATION_FIG2_Model_Performance.png")

## 13. Forward PDE Simulation

In [None]:
# FORWARD SIMULATE THE DISCOVERED PDE

def _fd4_first(f, h, axis):
    """4th-order central first derivative of `f` along `axis` with spacing `h`.

    The 5-point stencil needs two neighbours on each side, so the result is
    4 samples shorter along `axis`. Output index k corresponds to input index
    k + 2, i.e. the trim is symmetric.
    """
    f = np.moveaxis(f, axis, -1)
    d = (f[..., :-4] - 8 * f[..., 1:-3] + 8 * f[..., 3:-1] - f[..., 4:]) / (12 * h)
    return np.moveaxis(d, -1, axis)


def _fd4_second(f, h, axis):
    """4th-order central second derivative of `f` along `axis` with spacing `h`.

    Same symmetric 2-sample trim at each end of `axis` as `_fd4_first`.
    """
    f = np.moveaxis(f, axis, -1)
    d = (-f[..., :-4] + 16 * f[..., 1:-3] - 30 * f[..., 2:-2]
         + 16 * f[..., 3:-1] - f[..., 4:]) / (12 * h ** 2)
    return np.moveaxis(d, -1, axis)


def _crop_center(a, shape):
    """Return the central `shape` window of the 2D array `a`."""
    r0 = (a.shape[0] - shape[0]) // 2
    c0 = (a.shape[1] - shape[1]) // 2
    return a[r0:r0 + shape[0], c0:c0 + shape[1]]


def compute_library_field(u, dx, dy):
    """Evaluate the SINDy candidate library on a single 2D snapshot `u`.

    Derivatives use 4th-order central finite differences: x runs along axis 1
    (columns) and y along axis 0 (rows). Third and fourth derivatives are
    obtained by differentiating the second derivatives once more. Every output
    value therefore needs a 9-point stencil, and a 4-pixel border is lost on
    every side.

    Parameters
    ----------
    u : array_like, 2D (H, W)
        Field at one time point. It is converted to float64, so integer images
        cannot overflow in the u**2 / u**3 terms.
    dx, dy : float
        Grid spacing along columns (x) and rows (y).

    Returns
    -------
    library : ndarray, shape ((H-8)*(W-8), 25)
        One column per candidate term, in the order listed in the body.
    shape : tuple (H-8, W-8)
        Shape of the interior window u[4:-4, 4:-4] that the library rows refer
        to (row-major order).

    Raises
    ------
    ValueError
        If `u` is not 2D or is smaller than 9 pixels along either axis.
    """
    u = np.asarray(u, dtype=float)
    if u.ndim != 2 or min(u.shape) < 9:
        raise ValueError(
            f"compute_library_field needs a 2D field of at least 9x9 pixels, got shape {u.shape}"
        )

    # First and second derivatives (each trimmed by 2 px at both ends of its own axis)
    u_x = _fd4_first(u, dx, axis=1)
    u_y = _fd4_first(u, dy, axis=0)
    u_xx = _fd4_second(u, dx, axis=1)
    u_yy = _fd4_second(u, dy, axis=0)

    # Higher orders are built from the SECOND derivatives:
    # d/dx(u_xx) = u_xxx and d2/dx2(u_xx) = u_xxxx (trimmed by 4 px per side).
    u_xxx = _fd4_first(u_xx, dx, axis=1)
    u_yyy = _fd4_first(u_yy, dy, axis=0)
    u_xxxx = _fd4_second(u_xx, dx, axis=1)
    u_yyyy = _fd4_second(u_yy, dy, axis=0)

    # Every stencil trims symmetrically, so the centred (H-8, W-8) window of each
    # array refers to the same pixels, u[4:-4, 4:-4]. A top-left crop would shift
    # terms relative to each other. simulate_pde also writes the update back
    # into the centred window.
    out_shape = (u.shape[0] - 8, u.shape[1] - 8)
    u_core, u_x, u_y, u_xx, u_yy, u_xxx, u_yyy, u_xxxx, u_yyyy = (
        _crop_center(a, out_shape)
        for a in (u, u_x, u_y, u_xx, u_yy, u_xxx, u_yyy, u_xxxx, u_yyyy)
    )

    laplacian = u_xx + u_yy
    # NOTE(review): u_xxxx + u_yyyy omits the 2*u_xxyy cross term of the true
    # biharmonic. It is kept as in the original, presumably to match the
    # training library's definition. Confirm.
    biharmonic = u_xxxx + u_yyyy

    # Column order must match the library the coefficients were fitted on (term_names)
    library = [
        np.ones_like(u_core),  # 1
        u_core,                # u
        u_x, u_y,              # u_x, u_y
        u_xx, u_yy, laplacian, # u_xx, u_yy, laplacian(u)
        u_core**2,             # u^2
        u_core * u_x,          # u*u_x
        u_core * u_y,          # u*u_y
        u_x**2, u_y**2,        # u_x^2, u_y^2
        u_core * u_xx,         # u*u_xx
        u_core * u_yy,         # u*u_yy
        u_core * laplacian,    # u*laplacian(u)
        u_core**3,             # u^3
        u_core**2 * u_x,       # u^2*u_x
        u_core**2 * u_y,       # u^2*u_y
        u_xxx, u_yyy,          # u_xxx, u_yyy
        u_xxxx, u_yyyy,        # u_xxxx, u_yyyy
        biharmonic,            # u_xxxx + u_yyyy
        u_core * u_xxxx,       # u*u_xxxx
        u_core * u_yyyy        # u*u_yyyy
    ]

    return np.stack([term.ravel() for term in library], axis=1), u_core.shape


def simulate_pde(u0, coeffs, dx, dy, dt, n_steps):
    """
    Integrate the discovered PDE  u_t = Theta(u) @ coeffs  with explicit (forward) Euler.

    Parameters
    ----------
    u0 : 2D ndarray
        Initial field. It is copied and never modified.
    coeffs : array
        Library coefficients, one per column returned by compute_library_field.
    dx, dy : float
        Grid spacings forwarded to compute_library_field.
    dt : float
        Euler time step.
    n_steps : int
        Number of Euler steps to take.

    Returns
    -------
    ndarray
        The initial field followed by one field per step, stacked along axis 0.

    Notes
    -----
    compute_library_field only returns du/dt on an interior window. Each step
    adds the update into the centred window of the full field and leaves the
    surrounding border pixels at their initial values.
    """
    state = u0.copy()
    frames = [u0.copy()]

    for _ in range(n_steps):
        theta, interior_shape = compute_library_field(state, dx, dy)
        rate = (theta @ coeffs).reshape(interior_shape)

        # Centre the interior update inside the full field
        top = (state.shape[0] - rate.shape[0]) // 2
        left = (state.shape[1] - rate.shape[1]) // 2
        window = (slice(top, top + rate.shape[0]), slice(left, left + rate.shape[1]))

        state = state.copy()  # never mutate a frame that is already stored
        state[window] += dt * rate
        frames.append(state)

    return np.array(frames)


# Simulate from first registered frame.
# The comparison cell below treats U_simulated[k] as matching U_registered[k]
# (presumably dt is the frame interval; confirm).
print("Simulating discovered PDE forward in time...")
u_initial = U_registered[0]
n_sim_frames = min(20, T-1)  # at most 20 steps, capped by the number of recorded frames

U_simulated = simulate_pde(u_initial, coeffs_best, dx, dy, dt, n_sim_frames)

print(f"✅ Simulated {len(U_simulated)} frames")
print(f"   Initial field shape: {u_initial.shape}")
print(f"   Simulation shape: {U_simulated.shape}")

# Explicit Euler with high-order (up to 4th-derivative) library terms can go unstable.
# Report it here instead of letting NaN/inf flow silently into the comparison figures.
finite_frames = np.isfinite(U_simulated).reshape(len(U_simulated), -1).all(axis=1)
if not finite_frames.all():
    print(f"   WARNING: simulated field became non-finite at frame {int(np.argmin(finite_frames))}; "
          "the discovered PDE is unstable under forward Euler at this dt")

## 14. Spatiotemporal Comparison (PRESENTATION FIGURE 3)

In [None]:
# CREATE SPATIOTEMPORAL COMPARISON FIGURE
# Compare measured vs. simulated intensity along the central image row over time.

# Extract central spatial line for visualization
h_center = U_registered.shape[1] // 2
measured_line = U_registered[:n_sim_frames+1, h_center, :]
simulated_line = U_simulated[:, h_center, :]

# Ensure same size (guards against mismatched frame counts)
min_len = min(measured_line.shape[0], simulated_line.shape[0])
measured_line = measured_line[:min_len]
simulated_line = simulated_line[:min_len]

fig = plt.figure(figsize=(18, 12))
gs = fig.add_gridspec(3, 3, hspace=0.35, wspace=0.3)

# 1. Measured spatiotemporal plot
ax1 = fig.add_subplot(gs[0, 0])
im1 = ax1.imshow(measured_line.T, aspect='auto', cmap='RdBu_r', origin='lower')
ax1.set_xlabel('Time Frame', fontsize=12, fontweight='bold')
ax1.set_ylabel('Spatial Position (x)', fontsize=12, fontweight='bold')
ax1.set_title('Measured Data\n(Horizontal Slice)', fontsize=13, fontweight='bold')
plt.colorbar(im1, ax=ax1, label='Intensity')

# 2. Simulated spatiotemporal plot (shares the measured colour limits for comparability)
ax2 = fig.add_subplot(gs[0, 1])
im2 = ax2.imshow(simulated_line.T, aspect='auto', cmap='RdBu_r', origin='lower',
                 vmin=im1.get_clim()[0], vmax=im1.get_clim()[1])
ax2.set_xlabel('Time Frame', fontsize=12, fontweight='bold')
ax2.set_ylabel('Spatial Position (x)', fontsize=12, fontweight='bold')
ax2.set_title('PDE Simulation\n(Discovered Equation)', fontsize=13, fontweight='bold')
plt.colorbar(im2, ax=ax2, label='Intensity')

# 3. Error/difference (symmetric colour range centred on zero)
ax3 = fig.add_subplot(gs[0, 2])
error = measured_line - simulated_line
err_lim = np.abs(error).max()
im3 = ax3.imshow(error.T, aspect='auto', cmap='seismic', origin='lower',
                 vmin=-err_lim, vmax=err_lim)
ax3.set_xlabel('Time Frame', fontsize=12, fontweight='bold')
ax3.set_ylabel('Spatial Position (x)', fontsize=12, fontweight='bold')
ax3.set_title(f'Prediction Error\nRMSE = {np.sqrt(np.mean(error**2)):.4f}', fontsize=13, fontweight='bold')
plt.colorbar(im3, ax=ax3, label='Error')

# 4. Sample spatial snapshots.
# Four snapshots do not fit the 3-column main grid (the 4th used to be drawn on
# top of the 1st), so row 1 gets its own 1 x len(times) sub-grid.
times = [0, min_len//3, 2*min_len//3, min_len-1]
snapshot_gs = gs[1, :].subgridspec(1, len(times), wspace=0.3)
for i, t in enumerate(times):
    ax = fig.add_subplot(snapshot_gs[0, i])
    ax.plot(measured_line[t], 'b-', linewidth=2, label='Measured', alpha=0.7)
    ax.plot(simulated_line[t], 'r--', linewidth=2, label='Simulated', alpha=0.7)
    ax.set_xlabel('Spatial Position', fontsize=11, fontweight='bold')
    ax.set_ylabel('Intensity', fontsize=11, fontweight='bold')
    ax.set_title(f'Frame {t}', fontsize=12, fontweight='bold')
    ax.legend(fontsize=9)
    ax.grid(True, alpha=0.3)

# 5. Temporal evolution at fixed spatial points
spatial_points = [measured_line.shape[1]//4, measured_line.shape[1]//2, 3*measured_line.shape[1]//4]
ax5 = fig.add_subplot(gs[2, 0])
for sp in spatial_points:
    ax5.plot(measured_line[:, sp], 'o-', linewidth=1.5, markersize=3, alpha=0.7, label=f'Measured x={sp}')
    ax5.plot(simulated_line[:, sp], 's--', linewidth=1.5, markersize=3, alpha=0.7, label=f'Simulated x={sp}')
ax5.set_xlabel('Time Frame', fontsize=12, fontweight='bold')
ax5.set_ylabel('Intensity', fontsize=12, fontweight='bold')
ax5.set_title('Temporal Evolution\n(Selected Points)', fontsize=13, fontweight='bold')
ax5.legend(fontsize=8, ncol=2)
ax5.grid(True, alpha=0.3)

# 6. Correlation plot (scatter a fixed-seed subsample so the saved figure is reproducible)
ax6 = fig.add_subplot(gs[2, 1])
sample_rng = np.random.default_rng(0)
sample_idx = sample_rng.choice(measured_line.size, min(5000, measured_line.size), replace=False)
flat_meas = measured_line.ravel()[sample_idx]
flat_sim = simulated_line.ravel()[sample_idx]
ax6.scatter(flat_meas, flat_sim, alpha=0.3, s=2, c='purple')
lims = [min(flat_meas.min(), flat_sim.min()), max(flat_meas.max(), flat_sim.max())]
ax6.plot(lims, lims, 'r--', linewidth=2, label='Perfect Match')
# Pearson r uses all points, not just the plotted subsample
corr = np.corrcoef(measured_line.ravel(), simulated_line.ravel())[0, 1]
ax6.set_xlabel('Measured Intensity', fontsize=12, fontweight='bold')
ax6.set_ylabel('Simulated Intensity', fontsize=12, fontweight='bold')
ax6.set_title(f'Correlation Plot\nPearson r = {corr:.4f}', fontsize=13, fontweight='bold')
ax6.legend(fontsize=10)
ax6.grid(True, alpha=0.3)

# 7. Error statistics over time
ax7 = fig.add_subplot(gs[2, 2])
rmse_time = np.sqrt(np.mean((measured_line - simulated_line)**2, axis=1))
mae_time = np.mean(np.abs(measured_line - simulated_line), axis=1)
ax7.plot(rmse_time, 'r-', linewidth=2, label='RMSE', marker='o', markersize=4)
ax7.plot(mae_time, 'b-', linewidth=2, label='MAE', marker='s', markersize=4)
ax7.set_xlabel('Time Frame', fontsize=12, fontweight='bold')
ax7.set_ylabel('Error Magnitude', fontsize=12, fontweight='bold')
ax7.set_title('Error Evolution Over Time', fontsize=13, fontweight='bold')
ax7.legend(fontsize=10)
ax7.grid(True, alpha=0.3)

plt.savefig(OUTPUT_FOLDER / 'PRESENTATION_FIG3_Spatiotemporal_Comparison.png', dpi=300, bbox_inches='tight')
plt.show()

print("\n✅ PRESENTATION FIGURE 3 SAVED: PRESENTATION_FIG3_Spatiotemporal_Comparison.png")

## 15. Final Summary and Results (PRESENTATION FIGURE 4)

In [None]:
# CREATE FINAL SUMMARY FIGURE WITH DISCOVERED EQUATION
import textwrap

from scipy.stats import norm, probplot

# Flat float copy of the residuals. This works whether `residuals` is a NumPy
# array or a pandas Series (ndarray has no .median() method), whatever its shape.
resid_flat = np.asarray(residuals, dtype=float).ravel()

fig = plt.figure(figsize=(18, 10))
gs = fig.add_gridspec(2, 3, hspace=0.4, wspace=0.3)

# 1. Discovered PDE equation (large text box)
ax1 = fig.add_subplot(gs[0, :])
ax1.axis('off')

# Build equation string from coefficients whose magnitude exceeds the STRidge threshold.
# NOTE(review): if STRidge refits after thresholding, a surviving coefficient may
# fall below the threshold. It would then be omitted here while still counted in
# best_result['n_active']. Confirm how n_active is computed.
eq_parts = []
for coeff, name in zip(coeffs_best, term_names):
    if np.abs(coeff) > best_result['threshold']:
        sign = '+' if coeff > 0 and len(eq_parts) > 0 else ''
        eq_parts.append(f"{sign}{coeff:.6f}·{name}")

equation_str = "u_t = " + " ".join(eq_parts) if eq_parts else "u_t = 0"

# Long equations would run off the figure, so wrap the displayed copy to the banner width
box_width = 74
equation_display = textwrap.fill(equation_str, width=box_width + 2, subsequent_indent=' ' * 6)

banner = "\n".join([
    "╔" + "═" * box_width + "╗",
    "║" + "DISCOVERED PDE EQUATION".center(box_width) + "║",
    "╚" + "═" * box_width + "╝",
])
sparsity_pct = (1 - best_result['n_active'] / len(term_names)) * 100

# Create text box.
# NOTE(review): the "Physical Interpretation" bullets are fixed text, shown
# regardless of which terms were actually selected.
textstr = f"""
{banner}

{equation_display}

Model Performance:
  • R² Score: {best_result['r2']:.6f}
  • Active Terms: {best_result['n_active']} / {len(term_names)}
  • Sparsity: {sparsity_pct:.1f}%
  • STRidge Parameters: α={best_result['alpha']}, λ={best_result['threshold']}

Physical Interpretation:
  • Linear terms: Describe growth/decay and diffusion
  • Nonlinear terms: Capture amplitude-dependent dynamics
  • High-order terms (u_xxxx, u_yyyy): Indicate Kuramoto-Sivashinsky-type dynamics
  • Biharmonic operator (∇⁴u): Suggests pattern-forming instabilities
"""

ax1.text(0.5, 0.5, textstr, fontsize=11, family='monospace',
         ha='center', va='center', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))

# 2. Residual histogram with statistics
ax2 = fig.add_subplot(gs[1, 0])
resid_mean = resid_flat.mean()
resid_median = np.median(resid_flat)
ax2.hist(resid_flat, bins=80, alpha=0.7, edgecolor='black', color='teal', density=True)
ax2.axvline(resid_mean, color='red', linestyle='--', linewidth=2, label=f'Mean = {resid_mean:.2e}')
ax2.axvline(resid_median, color='orange', linestyle='--', linewidth=2, label=f'Median = {resid_median:.2e}')

# Fit normal distribution and overlay its pdf across the histogram's x-range
mu, std = norm.fit(resid_flat)
xmin, xmax = ax2.get_xlim()
x = np.linspace(xmin, xmax, 100)
p = norm.pdf(x, mu, std)
ax2.plot(x, p, 'k-', linewidth=2, label=f'Normal (μ={mu:.2e}, σ={std:.2e})')
skewness = np.mean((resid_flat - mu)**3) / std**3

ax2.set_xlabel('Residual Value', fontsize=12, fontweight='bold')
ax2.set_ylabel('Probability Density', fontsize=12, fontweight='bold')
ax2.set_title(f'Residual Distribution Analysis\nSkewness = {skewness:.3f}',
              fontsize=12, fontweight='bold')
ax2.legend(fontsize=9)
ax2.grid(True, alpha=0.3, axis='y')

# 3. QQ plot for normality check
ax3 = fig.add_subplot(gs[1, 1])
probplot(resid_flat, dist="norm", plot=ax3)
ax3.set_xlabel('Theoretical Quantiles', fontsize=12, fontweight='bold')
ax3.set_ylabel('Sample Quantiles', fontsize=12, fontweight='bold')
ax3.set_title('Q-Q Plot\n(Normality Check)', fontsize=12, fontweight='bold')
ax3.grid(True, alpha=0.3)

# 4. Key metrics comparison table (rendered as dot-leader monospace text)
ax4 = fig.add_subplot(gs[1, 2])
ax4.axis('off')

separator = ['─' * 25, '─' * 15]
metrics_data = [
    ['Metric', 'Value'],
    separator,
    ['R² Score', f'{best_result["r2"]:.6f}'],
    ['RMSE', f'{np.sqrt(np.mean(resid_flat**2)):.6f}'],
    ['MAE', f'{np.mean(np.abs(resid_flat)):.6f}'],
    ['Max Error', f'{np.abs(resid_flat).max():.6f}'],
    ['Correlation (r)', f'{corr:.6f}'],  # Pearson r from the spatiotemporal comparison cell
    separator,
    ['Total Terms', str(len(term_names))],
    ['Active Terms', str(best_result['n_active'])],
    ['Sparsity', f'{sparsity_pct:.1f}%'],
    separator,
    ['Image Frames', str(T)],
    ['Spatial Resolution', f'{H} × {W}'],
    ['Grid Spacing', f'dx={dx}, dy={dy}, dt={dt}'],
    separator,
    ['Registration', 'Subpixel Optical Flow'],
    ['Smoothing', 'Savitzky-Golay (7, 3)'],
    ['Derivatives', '4th-order Finite Diff.'],
    ['Solver', 'STRidge'],
]

table_text = '\n'.join([f'{row[0]:.<25s} {row[1]:>15s}' for row in metrics_data])

ax4.text(0.1, 0.5, table_text, fontsize=10, family='monospace', va='center',
         bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.7))
ax4.set_title('Summary Statistics & Configuration', fontsize=13, fontweight='bold', pad=20)

plt.savefig(OUTPUT_FOLDER / 'PRESENTATION_FIG4_Summary_Results.png', dpi=300, bbox_inches='tight')
plt.show()

print("\n✅ PRESENTATION FIGURE 4 SAVED: PRESENTATION_FIG4_Summary_Results.png")
print("\n" + "="*80)
print("🎉 ALL PRESENTATION FIGURES COMPLETE!")
print("="*80)
print(f"\nGenerated files in '{OUTPUT_FOLDER}':")
print("  1. PRESENTATION_FIG1_Registration_Quality.png")
print("  2. PRESENTATION_FIG2_Model_Performance.png")
print("  3. PRESENTATION_FIG3_Spatiotemporal_Comparison.png")
print("  4. PRESENTATION_FIG4_Summary_Results.png")
print("\n" + "="*80)