In [None]:
import cv2
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

def analyze_real_images(image_folder):
    """Analyze the characteristics of your real laser-matter images"""
    image_paths = sorted(Path(image_folder).glob("*.tif"))  
    images = []
    
    for path in image_paths:
        img = cv2.imread(str(path), cv2.IMREAD_GRAYSCALE)  # Or RGB if needed
        images.append(img)
    
    images = np.array(images)  # Shape: (n_time, height, width)
    
    print(f"Total images: {len(images)}")
    print(f"Image shape: {images[0].shape}")
    print(f"Time dimension: {len(images)}")
    
    # Analyze statistics
    print("\n--- Statistics ---")
    print(f"Min intensity: {images.min()}")
    print(f"Max intensity: {images.max()}")
    print(f"Mean intensity: {images.mean():.2f}")
    print(f"Std intensity: {images.std():.2f}")
    
    # Check if images are registered
    print("\n--- Registration Check ---")
    diff_means = []
    for i in range(len(images)-1):
        diff = np.abs(images[i+1] - images[i])
        diff_means.append(diff.mean())
    
    print(f"Mean absolute difference between frames: {np.mean(diff_means):.2f}")

    
    # Visualize sample frames
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    for i, ax in enumerate(axes.flat):
        if i < len(images):
            im = ax.imshow(images[i],  cmap='gray')
            ax.set_title(f'Frame {i}')
            plt.colorbar(im, ax=ax)
    
    # Visualize differences
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    axes[0].imshow(images[31],  cmap='gray')
    axes[0].set_title('Frame 31')
    axes[1].imshow(images[32],  cmap='gray')
    axes[1].set_title('Frame 32')
    axes[2].imshow(np.abs(images[32] - images[31]),  cmap='gray')
    axes[2].set_title('Absolute Difference')
    
    return images

# Run analysis
# NOTE(review): the default is a machine-specific absolute path; set the
# LASER_IMAGE_FOLDER environment variable to run this notebook elsewhere.
import os

IMAGE_FOLDER = os.environ.get(
    "LASER_IMAGE_FOLDER",
    "C:/Users/Sinjini/Documents/Semester 3/MLDM Project/Real_Images",
)
images = analyze_real_images(IMAGE_FOLDER)

In [None]:
# Mean squared difference between consecutive frames. Frames are cast to
# float32 first so the uint8 subtraction cannot wrap around.
diff_mse = []
for prev_frame, next_frame in zip(images[:-1], images[1:]):
    diff = next_frame.astype(np.float32) - prev_frame.astype(np.float32)
    diff_mse.append(np.mean(diff**2))
print("Mean squared difference between frames:", np.mean(diff_mse))


In [None]:
import cv2
import numpy as np
import matplotlib.pyplot as plt

# `images` is expected to be loaded already: shape (T, H, W), uint8 or similar.
# Dense Farneback optical flow of every frame against frame 0; the mean flow
# magnitude tracks how far each frame has drifted from the reference.
FARNEBACK_PYR_SCALE = 0.5   # image scale between consecutive pyramid levels
FARNEBACK_LEVELS = 3        # number of pyramid levels
FARNEBACK_WINSIZE = 15      # averaging window size
FARNEBACK_ITERATIONS = 3    # iterations at each pyramid level
FARNEBACK_POLY_N = 5        # pixel neighbourhood for the polynomial expansion
FARNEBACK_POLY_SIGMA = 1.2  # Gaussian std smoothing the derivatives of the expansion
FARNEBACK_FLAGS = 0

ref = images[0].astype(np.uint8)   # reference frame
mean_disp = []

for i, frame in enumerate(images[1:], start=1):
    cur = frame.astype(np.uint8)
    flow = cv2.calcOpticalFlowFarneback(
        ref, cur, None,
        FARNEBACK_PYR_SCALE, FARNEBACK_LEVELS, FARNEBACK_WINSIZE,
        FARNEBACK_ITERATIONS, FARNEBACK_POLY_N, FARNEBACK_POLY_SIGMA,
        FARNEBACK_FLAGS,
    )
    mag, ang = cv2.cartToPolar(flow[..., 0], flow[..., 1])
    mean_disp.append(mag.mean())

frames = np.arange(1, len(images))

plt.figure(figsize=(6, 4))
plt.plot(frames, mean_disp, marker='o')
plt.xlabel("Frame index")
plt.ylabel("Mean displacement from frame 0 (pixels)")
plt.title("Average displacement vs frame")
plt.grid(True)
plt.show()


In [None]:
import cv2
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

def analyze_and_register(image_folder):
    # 1. Load Images
    image_paths = sorted(Path(image_folder).glob("*.tif"))
    images = [cv2.imread(str(p), cv2.IMREAD_GRAYSCALE) for p in image_paths]
    images = np.array(images)
    
    # 2. Setup Reference (Frame 0)
    ref_img = images[0]
    height, width = ref_img.shape
    
    # Initialize storage for motion data
    translations_x = [0]
    translations_y = [0]
    rotations = [0] # In degrees
    
    registered_images = [ref_img] # List to store corrected images
    
    # Initialize ORB detector (Robust feature detection)
    orb = cv2.ORB_create(nfeatures=2000)
    
    # Find keypoints in the clean Frame 0
    kp1, des1 = orb.detectAndCompute(ref_img, None)
    
    print(f"Processing {len(images)} frames...")
    
    # 3. Loop through frames to Calculate & Correct
    for i in range(1, len(images)):
        curr_img = images[i]
        
        # Detect features in current frame
        kp2, des2 = orb.detectAndCompute(curr_img, None)
        
        # Match features
        matcher = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True)
        matches = matcher.match(des1, des2)
        
        # Extract location of good matches
        src_pts = np.float32([kp1[m.queryIdx].pt for m in matches]).reshape(-1, 1, 2)
        dst_pts = np.float32([kp2[m.trainIdx].pt for m in matches]).reshape(-1, 1, 2)
        
        # --- DIAGNOSIS STEP ---
        # Estimate a "Partial Affine" transform (Rotation + Translation + Scale)
        # RANSAC will ignore the chaotic laser pixels
        matrix, inliers = cv2.estimateAffinePartial2D(dst_pts, src_pts, method=cv2.RANSAC)
        
        if matrix is None:
            # Fallback if detection fails (rare): assume no movement
            matrix = np.array([[1, 0, 0], [0, 1, 0]], dtype=np.float32)
            
        # Extract components for analysis
        # Matrix is [[cos, -sin, tx], [sin, cos, ty]]
        tx = matrix[0, 2]
        ty = matrix[1, 2]
        
        # Calculate rotation angle from cosine/sine
        # arctan2(sin, cos)
        angle_rad = np.arctan2(matrix[1, 0], matrix[0, 0])
        angle_deg = np.degrees(angle_rad)
        
        translations_x.append(tx)
        translations_y.append(ty)
        rotations.append(angle_deg)
        
        # --- CORRECTION STEP ---
        # Apply the matrix to warp the current image back to match Frame 0
        #registered_img = cv2.warpAffine(curr_img, matrix, (width, height))
        #registered_images.append(registered_img)

    # 4. Visualization of the Diagnosis
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 8), sharex=True)
    
    # Plot Translation
    ax1.plot(translations_x, label='X Shift (pixels)', marker='.')
    ax1.plot(translations_y, label='Y Shift (pixels)', marker='.')
    ax1.set_title("Diagnosis: Detected Translation")
    ax1.set_ylabel("Pixels")
    ax1.legend()
    ax1.grid(True)
    
    # Plot Rotation
    ax2.plot(rotations, color='r', marker='.')
    ax2.set_title("Diagnosis: Detected Rotation")
    ax2.set_ylabel("Degrees")
    ax2.set_xlabel("Frame Index")
    ax2.grid(True)
    
    plt.tight_layout()
    plt.show()
    
    #return np.array(registered_images)

# --- Run the Code ---
# Replace with your actual path
folder_path = "C:/Users/Sinjini/Documents/Semester 3/MLDM Project/Real_Images"
analyze_and_register(folder_path)

# Optional: visualize before vs after registration for one frame (37 below;
# pick any frame where the movement is high). This needs the registered stack
# `clean_images`, which the call above does not produce. The code is kept as
# comments rather than a bare string, which Jupyter would echo as cell output.
#
# raw_images = np.array([cv2.imread(str(p), cv2.IMREAD_GRAYSCALE)
#                        for p in sorted(Path(folder_path).glob("*.tif"))])
# frame_idx = 37  # pick a frame where the movement is high
# plt.figure(figsize=(10, 5))
# plt.subplot(1, 2, 1)
# plt.imshow(raw_images[frame_idx], cmap='gray')
# plt.title(f"Original Frame {frame_idx}")
# plt.subplot(1, 2, 2)
# plt.imshow(clean_images[frame_idx], cmap='gray')
# plt.title(f"Registered Frame {frame_idx}")
# plt.show()
# print(f"Registered data shape: {clean_images.shape}")

In [None]:
import numpy as np
import cv2
from scipy.ndimage import gaussian_filter

# Assuming 'images' is your (51, H, W) numpy array of grayscale images.
# Check frame 0 of the raw stack.
frame_idx = 0
raw_frame = images[frame_idx]

# 1. Function to calculate 2nd spatial derivative (u_xx) using finite difference
def calculate_uxx(img):
    """Second derivative along x (columns) with the stencil [1, -2, 1], unit spacing.

    The input is cast to float32 so the result can be negative. Borders use
    cv2.filter2D's default border mode (BORDER_REFLECT_101).
    """
    uxx = cv2.filter2D(img.astype(np.float32), -1, np.array([[1, -2, 1]], dtype=np.float32))
    return uxx

# --- RAW DATA DIAGNOSIS ---
raw_uxx = calculate_uxx(raw_frame)
raw_uxx_variance = np.var(raw_uxx)
print(f"RAW u_xx Variance: {raw_uxx_variance:.2f}")

# --- SMOOTHED DATA DIAGNOSIS ---
# Apply a mild Gaussian filter (simulating denoising). Smooth in float32:
# gaussian_filter returns the input dtype, so smoothing the uint8 frame
# directly would re-quantize the result to integers and add spurious
# high-frequency error to u_xx.
smoothed_frame = gaussian_filter(raw_frame.astype(np.float32), sigma=1)
smoothed_uxx = calculate_uxx(smoothed_frame)
smoothed_uxx_variance = np.var(smoothed_uxx)
print(f"SMOOTHED u_xx Variance: {smoothed_uxx_variance:.2f}")

# --- VISUAL PROOF ---
import matplotlib.pyplot as plt
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# Display the magnitude of the derivatives; the colour scale is capped at the
# 98th percentile so a few extreme pixels do not wash out the map.
axes[0].imshow(np.abs(raw_uxx), cmap='hot', vmax=np.percentile(np.abs(raw_uxx), 98))
axes[0].set_title(f'Raw |u_xx| (Variance: {raw_uxx_variance:.0f})')
axes[1].imshow(np.abs(smoothed_uxx), cmap='hot', vmax=np.percentile(np.abs(smoothed_uxx), 98))
axes[1].set_title(f'Smoothed |u_xx| (Variance: {smoothed_uxx_variance:.0f})')
plt.show()

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import cv2

# 'images' is your loaded (51, H, W) numpy array
from scipy.stats import norm  # imported once, not on every loop iteration

for i in range(len(images)):
    # Work in float32: blurring the uint8 frame directly would round the blur
    # back to integers and leave that rounding error in the residual.
    frame_i = images[i].astype(np.float32)

    # 1. Estimate the noise component as the residual between the frame and a
    # 5x5 Gaussian blur of it. The residual keeps mostly high-frequency content
    # (noise, but also sharp image edges).
    blurred_frame = cv2.GaussianBlur(frame_i, (5, 5), 0)
    noise_estimate = frame_i - blurred_frame

    # 2. Flatten the residual for the histogram
    noise_flat = noise_estimate.flatten()

    # 3. Plot the histogram of the noise
    plt.figure(figsize=(7, 5))
    plt.hist(noise_flat, bins=50, density=True, color='gray', alpha=0.7)

    # Fit a Gaussian curve for visual comparison
    xmin, xmax = plt.xlim()
    x = np.linspace(xmin, xmax, 100)
    mu, std = norm.fit(noise_flat)
    p = norm.pdf(x, mu, std)

    plt.plot(x, p, 'r', linewidth=2, label=f'Gaussian Fit (μ={mu:.2f}, σ={std:.2f})')
    plt.title(f"Histogram of Estimated Noise in Frame {i}")
    plt.xlabel("Pixel Value Residual (Noise)")
    plt.ylabel("Frequency (Normalized)")
    plt.legend()
    plt.grid(True)
    plt.show()

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import cv2

def poisson_check(images, frame_index=15, block_size=10):
    """
    Performs the Variance vs. Mean analysis to check for Poisson noise.
    
    Args:
        images (np.ndarray): The (T, H, W) stack of grayscale images.
        frame_index (int): Index of the brightest frame to analyze (e.g., 40).
        block_size (int): Size of the square block to analyze (e.g., 10x10 pixels).
    """
    
    # 1. Use an active, bright frame to capture potential shot noise
    active_frame = images[frame_index].astype(np.float32)
    
    # 2. Initialize lists to store mean and variance of each block
    means = []
    variances = []
    
    height, width = active_frame.shape
    
    # 3. Loop through the image in non-overlapping blocks
    for y in range(0, height, block_size):
        for x in range(0, width, block_size):
            # Define the block
            block = active_frame[y:y+block_size, x:x+block_size]
            
            # Ensure the block is full size (avoids edge effects if image size isn't divisible)
            if block.shape[0] == block_size and block.shape[1] == block_size:
                
                # Calculate mean and variance for the block
                block_mean = np.mean(block)
                block_variance = np.var(block)
                
                means.append(block_mean)
                variances.append(block_variance)

    # 4. Create the Diagnostic Plot
    plt.figure(figsize=(7, 5))
    plt.scatter(means, variances, s=15, alpha=0.6, color='blue')
    
    # Calculate and plot a linear trend line for easy visualization
    coefficients = np.polyfit(means, variances, 1)
    poly_fn = np.poly1d(coefficients)
    plt.plot(means, poly_fn(means), "r-", 
             label=f'Linear Fit (Slope: {coefficients[0]:.3f})')
    
    plt.title(f"Noise Variance vs. Signal Mean (Frame {frame_index})")
    plt.xlabel("Local Mean Intensity (Signal)")
    plt.ylabel("Local Variance (Noise Power)")
    plt.legend()
    plt.grid(True)
    plt.show()
    
    return coefficients[0] # Return the slope of the fit

# --- EXECUTE THE CHECK ---

# Variance-vs-mean test on the raw (not yet denoised) stack; report the fitted slope.
slope = poisson_check(images)
print("Slope of the fit: {:.3f}".format(slope))

In [None]:
from skimage.restoration import denoise_tv_chambolle



def preprocess_images_corrected(images_raw, weight=0.01):
    """Denoise a (T, H, W) stack: Anscombe -> TV (Chambolle) -> inverse Anscombe.

    The Anscombe transform ``2*sqrt(x + 3/8)`` approximately turns Poisson
    noise into unit-variance Gaussian noise, so one TV weight suits both dark
    and bright regions. The inverse used here is the plain algebraic one,
    which is biased at very low intensities.
    NOTE(review): this assumes pixel values are proportional to photon counts
    with unit gain; confirm with the variance-vs-mean slope.

    Args:
        images_raw: (T, H, W) array of non-negative intensities.
        weight: TV denoising weight in the Anscombe domain; larger smooths
            more. Values tried so far: 0.01, 0.005, 0.001.

    Returns:
        float32 array of the same shape, clipped to [0, 255] (not normalized).
    """
    n_frames, _height, _width = images_raw.shape  # also enforces a 3-D stack
    images_clean = np.zeros_like(images_raw, dtype=np.float32)

    for t in range(n_frames):
        img = images_raw[t].astype(np.float32)

        # Step 1: Anscombe transform (variance stabilization)
        img_anscombe = 2.0 * np.sqrt(img + 3.0/8.0)

        # Step 2: TV denoising in the stabilized domain
        img_denoised = denoise_tv_chambolle(img_anscombe, weight=weight)

        # Step 3: Inverse Anscombe, clipped to the 8-bit range
        img_clean = (img_denoised / 2.0)**2 - 3.0/8.0
        images_clean[t] = np.clip(img_clean, 0, 255)

    return images_clean


In [None]:
def preprocess_images_gaussian(images_raw, sigma=1.0):
    """Denoise a (T, H, W) stack frame by frame: Anscombe -> Gaussian -> inverse Anscombe.

    Smoothing is purely spatial (each frame independently); no TV step.

    Args:
        images_raw: (T, H, W) array of non-negative intensities.
        sigma: spatial Gaussian standard deviation in pixels, applied in the
            Anscombe domain (scalar or per-axis (sigma_y, sigma_x)).

    Returns:
        float32 array of the same shape, clipped to [0, 255] (not normalized).
    """
    # Imported here so the function does not depend on an import made in an earlier cell.
    from scipy.ndimage import gaussian_filter

    n_frames, _height, _width = images_raw.shape  # also enforces a 3-D stack
    images_clean = np.zeros_like(images_raw, dtype=np.float32)

    for t in range(n_frames):
        img = images_raw[t].astype(np.float32)

        # Anscombe (variance stabilization)
        img_anscombe = 2.0 * np.sqrt(img + 3.0/8.0)

        # Gaussian smoothing (no TV)
        img_denoised = gaussian_filter(img_anscombe, sigma=sigma)

        # Inverse Anscombe, clipped to the 8-bit range
        img_clean = (img_denoised / 2.0)**2 - 3.0/8.0
        images_clean[t] = np.clip(img_clean, 0, 255)

    return images_clean


# Gaussian denoising in the Anscombe domain feeds the next cells; the TV
# variant is kept as an alternative:
#images_clean = preprocess_images_corrected(images)
images_clean = preprocess_images_gaussian(images)

In [None]:
from scipy.stats import norm

# Noise histograms for the first (up to) two frames of the denoised stack.
for i in range(min(2, len(images_clean))):
    frame_i = images_clean[i]

    # 1. Residual between the frame and a 5x5 Gaussian blur of it, used as a
    #    stand-in for the high-frequency noise.
    blurred_frame = cv2.GaussianBlur(frame_i, (5, 5), 0)
    noise_estimate = frame_i.astype(np.float32) - blurred_frame.astype(np.float32)
    noise_flat = noise_estimate.flatten()

    # 2. Density-normalised histogram of the residual
    plt.figure(figsize=(7, 5))
    plt.hist(noise_flat, bins=50, density=True, color='gray', alpha=0.7)

    # 3. Overlay a fitted Gaussian over the histogram's x-range
    xmin, xmax = plt.xlim()
    x = np.linspace(xmin, xmax, 100)
    mu, std = norm.fit(noise_flat)
    p = norm.pdf(x, mu, std)
    plt.plot(x, p, 'r', linewidth=2, label=f'Gaussian Fit (μ={mu:.2f}, σ={std:.2f})')

    plt.title(f"Histogram of Estimated Noise in Frame {i}")
    plt.xlabel("Pixel Value Residual (Noise)")
    plt.ylabel("Frequency (Normalized)")
    plt.legend()
    plt.grid(True)
    plt.show()

In [None]:
def poisson_check(images, frame_index=15, block_size=10):
    """
    Performs the Variance vs. Mean analysis to check for Poisson noise.
    
    Args:
        images (np.ndarray): The (T, H, W) stack of grayscale images.
        frame_index (int): Index of the brightest frame to analyze (e.g., 40).
        block_size (int): Size of the square block to analyze (e.g., 10x10 pixels).
    """
    
    # 1. Use an active, bright frame to capture potential shot noise
    active_frame = images[frame_index].astype(np.float32)
    
    # 2. Initialize lists to store mean and variance of each block
    means = []
    variances = []
    
    height, width = active_frame.shape
    
    # 3. Loop through the image in non-overlapping blocks
    for y in range(0, height, block_size):
        for x in range(0, width, block_size):
            # Define the block
            block = active_frame[y:y+block_size, x:x+block_size]
            
            # Ensure the block is full size (avoids edge effects if image size isn't divisible)
            if block.shape[0] == block_size and block.shape[1] == block_size:
                
                # Calculate mean and variance for the block
                block_mean = np.mean(block)
                block_variance = np.var(block)
                
                means.append(block_mean)
                variances.append(block_variance)

    # 4. Create the Diagnostic Plot
    plt.figure(figsize=(7, 5))
    plt.scatter(means, variances, s=15, alpha=0.6, color='blue')
    
    # Calculate and plot a linear trend line for easy visualization
    coefficients = np.polyfit(means, variances, 1)
    poly_fn = np.poly1d(coefficients)
    plt.plot(means, poly_fn(means), "r-", 
             label=f'Linear Fit (Slope: {coefficients[0]:.3f})')
    
    plt.title(f"Noise Variance vs. Signal Mean (Frame {frame_index})")
    plt.xlabel("Local Mean Intensity (Signal)")
    plt.ylabel("Local Variance (Noise Power)")
    plt.legend()
    plt.grid(True)
    plt.show()
    
    return coefficients[0] # Return the slope of the fit

# --- EXECUTE THE CHECK ---

# Repeat the variance-vs-mean test on the Gaussian-denoised stack.
slope = poisson_check(images_clean)
print("Slope of the fit: {:.3f}".format(slope))

In [None]:
import numpy as np
import cv2
from scipy.ndimage import gaussian_filter

# Inspect u_xx on frame 45 of the Gaussian-denoised stack `images_clean`.
# ("RAW" in the printout means "before the extra smoothing step below".)
frame_idx = 45
raw_frame = images_clean[frame_idx]


def calculate_uxx(img):
    """Return the float32 second x-derivative of `img` via the [1, -2, 1] stencil (unit spacing)."""
    second_diff_kernel = np.array([[1, -2, 1]], dtype=np.float32)
    return cv2.filter2D(img.astype(np.float32), -1, second_diff_kernel)


# --- RAW DATA DIAGNOSIS ---
raw_uxx = calculate_uxx(raw_frame)
raw_uxx_variance = np.var(raw_uxx)
print(f"RAW u_xx Variance: {raw_uxx_variance:.2f}")

# --- SMOOTHED DATA DIAGNOSIS ---
# One extra sigma=1 Gaussian blur, for comparison with the frame as given.
smoothed_frame = gaussian_filter(raw_frame, sigma=1)
smoothed_uxx = calculate_uxx(smoothed_frame)
smoothed_uxx_variance = np.var(smoothed_uxx)
print(f"SMOOTHED u_xx Variance: {smoothed_uxx_variance:.2f}")

# --- VISUAL PROOF ---
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 2, figsize=(12, 5))
uxx_panels = (
    (raw_uxx, f'Raw |u_xx| (Variance: {raw_uxx_variance:.0f})'),
    (smoothed_uxx, f'Smoothed |u_xx| (Variance: {smoothed_uxx_variance:.0f})'),
)
for ax, (uxx_map, panel_title) in zip(axes, uxx_panels):
    uxx_abs = np.abs(uxx_map)
    # Cap the colour scale at the 98th percentile so a few extreme pixels do not wash out the map.
    ax.imshow(uxx_abs, cmap='hot', vmax=np.percentile(uxx_abs, 98))
    ax.set_title(panel_title)
plt.show()

In [None]:
import matplotlib.pyplot as plt

# Display the first, middle and last frames side by side (0, 25, 50 for a
# 51-frame stack): raw on top, cleaned below. Deriving the indices from the
# stack length avoids an IndexError on shorter sequences.
n_frames = len(images)
check_frames = (0, n_frames // 2, n_frames - 1)

fig, axes = plt.subplots(2, 3, figsize=(15, 8))
for col, idx in enumerate(check_frames):
    axes[0, col].imshow(images[idx], cmap='gray')
    axes[0, col].set_title(f'Raw Frame {idx}')
    axes[1, col].imshow(images_clean[idx], cmap='gray')
    axes[1, col].set_title(f'Clean Frame {idx}')

plt.tight_layout()
plt.savefig('visual_check.png', dpi=150)
plt.show()

# Print statistics
for heading, stack in (("Raw images:", images), ("\nCleaned images:", images_clean)):
    print(heading)
    for idx in check_frames:
        label = f"Frame {idx}:"
        print(f"  {label:<10}mean={stack[idx].mean():.2f}, std={stack[idx].std():.2f}")

In [None]:
from skimage.restoration import denoise_tv_chambolle
import numpy as np

def preprocess_images_optimized_tv(images_raw, weight=0.15):
    """
    Denoise a (T, H, W) stack with TV (Chambolle) filtering in the Anscombe domain.

    Per frame: Anscombe transform 2*sqrt(x + 3/8) -> TV denoising ->
    algebraic inverse Anscombe -> clip to [0, 255]. For x in [0, 255] the
    transformed values lie in about [1.2, 32], and the TV weight acts on
    that range.

    NOTE(review): as written, the notebook applies this to `images_clean`
    (the output of the Gaussian Anscombe pipeline), i.e. to data that is
    already smoothed. The Poisson assumption behind a second Anscombe
    transform no longer strictly holds there.

    Args:
        images_raw: (T, H, W) array of non-negative intensities. It is not modified.
        weight: TV weight in the Anscombe domain. Suggested values to try in
            order: 0.15, 0.20, 0.25.

    Returns:
        float32 array of the same shape in [0, 255]. Deliberately NOT
        normalized, so derivatives stay in physical intensity units.
    """
    n_frames, _height, _width = images_raw.shape  # also enforces a 3-D stack
    images_clean = np.zeros_like(images_raw, dtype=np.float32)

    # Work on a float copy; the caller's array is left untouched.
    images_float = images_raw.astype(np.float32)

    print(f"Applying Anscombe + TV Denoising with Weight={weight}...")

    for t in range(n_frames):
        # Step 1: Anscombe transform (variance stabilization)
        img_anscombe = 2.0 * np.sqrt(images_float[t] + 3.0/8.0)

        # Step 2: TV denoising in the stabilized domain
        img_denoised = denoise_tv_chambolle(img_anscombe, weight=weight)

        # Step 3: Inverse Anscombe (scale restoration), clipped to the 8-bit range
        img_clean = (img_denoised / 2.0)**2 - 3.0/8.0
        images_clean[t] = np.clip(img_clean, 0, 255)

    print("Denoising complete. Data is in [0, 255] range.")
    return images_clean


# Second pass: TV denoising on top of the Gaussian-denoised stack.
images_more_clean = preprocess_images_optimized_tv(images_raw=images_clean)

In [None]:
import numpy as np
import cv2
from scipy.ndimage import gaussian_filter


# Inspect u_xx on frame 0 of the TV-denoised stack `images_more_clean`.
# ("RAW" in the printout means "before the extra smoothing step below".)
frame_idx = 0
raw_frame = images_more_clean[frame_idx]


def calculate_uxx(img):
    """Return the float32 second x-derivative of `img` via the [1, -2, 1] stencil (unit spacing)."""
    second_diff_kernel = np.array([[1, -2, 1]], dtype=np.float32)
    return cv2.filter2D(img.astype(np.float32), -1, second_diff_kernel)


# --- RAW DATA DIAGNOSIS ---
raw_uxx = calculate_uxx(raw_frame)
raw_uxx_variance = np.var(raw_uxx)
print(f"RAW u_xx Variance: {raw_uxx_variance:.2f}")

# --- SMOOTHED DATA DIAGNOSIS ---
# One extra sigma=1 Gaussian blur, for comparison with the frame as given.
smoothed_frame = gaussian_filter(raw_frame, sigma=1)
smoothed_uxx = calculate_uxx(smoothed_frame)
smoothed_uxx_variance = np.var(smoothed_uxx)
print(f"SMOOTHED u_xx Variance: {smoothed_uxx_variance:.2f}")

# --- VISUAL PROOF ---
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 2, figsize=(12, 5))
uxx_panels = (
    (raw_uxx, f'Raw |u_xx| (Variance: {raw_uxx_variance:.0f})'),
    (smoothed_uxx, f'Smoothed |u_xx| (Variance: {smoothed_uxx_variance:.0f})'),
)
for ax, (uxx_map, panel_title) in zip(axes, uxx_panels):
    uxx_abs = np.abs(uxx_map)
    # Cap the colour scale at the 98th percentile so a few extreme pixels do not wash out the map.
    ax.imshow(uxx_abs, cmap='hot', vmax=np.percentile(uxx_abs, 98))
    ax.set_title(panel_title)
plt.show()

In [None]:
def compute_derivatives(images_clean, dx=1.0, dy=1.0, dt=1.0, time_scheme="forward"):
    """
    Compute spatial and temporal finite-difference derivatives for SINDy.

    Axis convention: axis 0 is time, axis 1 is y (rows), axis 2 is x (columns).

    Args:
        images_clean: (T, H, W) image stack, e.g. denoised intensities in [0, 255].
            Integer input is converted to float32 first, so differences can be
            negative instead of wrapping around.
        dx, dy: spatial grid spacing (use 1.0 if unknown)
        dt: time step between frames (use 1.0 if unknown, can rescale later)
        time_scheme: "forward" (default) uses (u[t+1] - u[t]) / dt and leaves
            the last frame at 0. "central" uses np.gradient, i.e.
            (u[t+1] - u[t-1]) / (2 dt) in the interior and one-sided
            differences at the first/last frame. It is centred on the same
            time as the spatial derivatives and, for white noise, has a
            quarter of the forward scheme's noise variance.

    Returns:
        u_t, u_x, u_y, u_xx, u_yy: arrays with the same shape as the input.
        Spatial derivatives are left at 0 on the one-pixel image border,
        where the central stencil does not fit. Exclude those samples (and,
        for "forward", the last frame) before regression.

    Raises:
        ValueError: if the input is not 3-D or time_scheme is unknown.
    """
    if time_scheme not in ("forward", "central"):
        raise ValueError(f"time_scheme must be 'forward' or 'central', got {time_scheme!r}")

    u = np.asarray(images_clean)
    if u.ndim != 3:
        raise ValueError(f"expected a (T, H, W) stack, got shape {u.shape}")
    if not np.issubdtype(u.dtype, np.floating):
        # uint8 arithmetic would wrap (0 - 10 == 246) and truncate divisions.
        u = u.astype(np.float32)
    n_frames = u.shape[0]

    # 1. Temporal derivative (du/dt)
    u_t = np.zeros_like(u)
    if time_scheme == "forward":
        u_t[:-1] = (u[1:] - u[:-1]) / dt
    elif n_frames >= 2:
        # np.gradient needs at least two samples along the axis.
        u_t = np.gradient(u, dt, axis=0)

    # 2. First spatial derivatives (du/dx, du/dy) - central difference
    u_x = np.zeros_like(u)
    u_y = np.zeros_like(u)
    u_x[:, :, 1:-1] = (u[:, :, 2:] - u[:, :, :-2]) / (2*dx)
    u_y[:, 1:-1, :] = (u[:, 2:, :] - u[:, :-2, :]) / (2*dy)

    # 3. Second spatial derivatives (d2u/dx2, d2u/dy2) - central difference
    u_xx = np.zeros_like(u)
    u_yy = np.zeros_like(u)
    u_xx[:, :, 1:-1] = (u[:, :, 2:] - 2*u[:, :, 1:-1] + u[:, :, :-2]) / (dx**2)
    u_yy[:, 1:-1, :] = (u[:, 2:, :] - 2*u[:, 1:-1, :] + u[:, :-2, :]) / (dy**2)

    return u_t, u_x, u_y, u_xx, u_yy


# Compute derivatives
print("=" * 60)
print("COMPUTING DERIVATIVES")
print("=" * 60)

u_t, u_x, u_y, u_xx, u_yy = compute_derivatives(
    images_more_clean,
    dx=1.0,  # pixel spacing (or physical spacing if known)
    dy=1.0,
    dt=1.0,  # time step (or physical time if known)
)

print("\n✅ Derivatives computed successfully!")
print("\nShapes:")
for label, arr in (("u_t:  ", u_t), ("u_x:  ", u_x), ("u_xx: ", u_xx)):
    print(f"  {label}{arr.shape}")

print("\nRanges (checking for reasonable values):")
for label, arr in (("u_t:  ", u_t), ("u_x:  ", u_x), ("u_y:  ", u_y),
                   ("u_xx: ", u_xx), ("u_yy: ", u_yy)):
    print(f"  {label}[{arr.min():.3f}, {arr.max():.3f}]")

# Visual check of derivatives
import matplotlib.pyplot as plt

fig, axes = plt.subplots(2, 3, figsize=(15, 8))

frame_idx = 25  # Check active frame

axes[0, 0].imshow(images_more_clean[frame_idx], cmap='gray')
axes[0, 0].set_title(f'Original (Frame {frame_idx})')

# (axes position, field, symmetric colour limit, title)
derivative_panels = [
    ((0, 1), u_t, 20, '∂u/∂t (Temporal)'),
    ((0, 2), u_x, 10, '∂u/∂x (Spatial)'),
    ((1, 0), u_y, 10, '∂u/∂y (Spatial)'),
    ((1, 1), u_xx, 5, '∂²u/∂x² (Laplacian X)'),
    ((1, 2), u_yy, 5, '∂²u/∂y² (Laplacian Y)'),
]
for pos, field, lim, title in derivative_panels:
    axes[pos].imshow(field[frame_idx], cmap='RdBu', vmin=-lim, vmax=lim)
    axes[pos].set_title(title)

plt.tight_layout()
plt.savefig('derivatives_visualization.png', dpi=150)
plt.show()

print("\n✅ Derivative visualization saved to 'derivatives_visualization.png'")

In [None]:
def validate_derivatives(u_t, u_xx, u_yy, frame_idx=25):
    """
    Visual inspection of derivative quality.
    """
    import matplotlib.pyplot as plt
    
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))
    
    axes[0].imshow(u_t[frame_idx], cmap='RdBu', vmin=-0.1, vmax=0.1)
    axes[0].set_title('∂u/∂t')
    
    axes[1].imshow(u_xx[frame_idx], cmap='RdBu', vmin=-0.1, vmax=0.1)
    axes[1].set_title('∂²u/∂x²')
    
    axes[2].imshow(u_yy[frame_idx], cmap='RdBu', vmin=-0.1, vmax=0.1)
    axes[2].set_title('∂²u/∂y²')
    
    plt.tight_layout()
    plt.savefig('derivative_quality_check.png', dpi=150)
    plt.close()
    
    print(f"u_t range: [{u_t.min():.4f}, {u_t.max():.4f}]")
    print(f"u_xx range: [{u_xx.min():.4f}, {u_xx.max():.4f}]")


# Quality check of the derivatives on the default frame.
validate_derivatives(u_t, u_xx, u_yy, frame_idx=25)

## The time derivatives are still noisy. So far, denoising has acted only along x and y (frame by frame), so we now denoise in 3D, smoothing along the time axis as well.

In [None]:
from scipy.ndimage import gaussian_filter

def preprocess_images_spatiotemporal(images_raw, sigma=(1.0, 1.0, 1.0)):
    """
    Denoise a (T, H, W) stack with BOTH spatial AND temporal Gaussian smoothing.

    Pipeline: Anscombe transform 2*sqrt(x + 3/8) -> 3-D Gaussian filter ->
    algebraic inverse Anscombe -> clip to [0, 255].

    Args:
        images_raw: (T, H, W) array of non-negative intensities.
        sigma: Gaussian standard deviations (sigma_t, sigma_y, sigma_x), in
            frames and pixels; a scalar applies the same sigma to all axes.
            A larger sigma_t quiets u_t more but also blurs genuinely fast dynamics.

    Returns:
        float32 array of the same shape, values in [0, 255].

    Raises:
        ValueError: if the input is not 3-D.
    """
    images_raw = np.asarray(images_raw)
    if images_raw.ndim != 3:
        raise ValueError(f"expected a (T, H, W) stack, got shape {images_raw.shape}")

    # Step 1: Anscombe transform (elementwise, so the whole stack at once)
    images_anscombe = 2.0 * np.sqrt(images_raw.astype(np.float32) + 3.0/8.0)

    # Step 2: 3-D Gaussian smoothing; axis order is (time, y, x). The first and
    # last frames are handled by gaussian_filter's default 'reflect' boundary mode.
    images_denoised = gaussian_filter(images_anscombe, sigma=sigma)

    # Step 3: Inverse Anscombe transform, clipped to the 8-bit range
    images_clean = (images_denoised / 2.0)**2 - 3.0/8.0
    return np.clip(images_clean, 0, 255).astype(np.float32, copy=False)


# Apply 3D denoising to the raw stack
print("Applying spatiotemporal denoising...")
images_clean_3d = preprocess_images_spatiotemporal(images_raw=images)






In [None]:
# Inspect u_xx on frame 35 of the spatiotemporally denoised stack `images_clean_3d`.
# ("RAW" in the printout means "before the extra smoothing step below".)
frame_idx = 35
raw_frame = images_clean_3d[frame_idx]


def calculate_uxx(img):
    """Return the float32 second x-derivative of `img` via the [1, -2, 1] stencil (unit spacing)."""
    second_diff_kernel = np.array([[1, -2, 1]], dtype=np.float32)
    return cv2.filter2D(img.astype(np.float32), -1, second_diff_kernel)


# --- RAW DATA DIAGNOSIS ---
raw_uxx = calculate_uxx(raw_frame)
raw_uxx_variance = np.var(raw_uxx)
print(f"RAW u_xx Variance: {raw_uxx_variance:.2f}")

# --- SMOOTHED DATA DIAGNOSIS ---
# One extra sigma=1 Gaussian blur, for comparison with the frame as given.
smoothed_frame = gaussian_filter(raw_frame, sigma=1)
smoothed_uxx = calculate_uxx(smoothed_frame)
smoothed_uxx_variance = np.var(smoothed_uxx)
print(f"SMOOTHED u_xx Variance: {smoothed_uxx_variance:.2f}")

# --- VISUAL PROOF ---
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 2, figsize=(12, 5))
uxx_panels = (
    (raw_uxx, f'Raw |u_xx| (Variance: {raw_uxx_variance:.0f})'),
    (smoothed_uxx, f'Smoothed |u_xx| (Variance: {smoothed_uxx_variance:.0f})'),
)
for ax, (uxx_map, panel_title) in zip(axes, uxx_panels):
    uxx_abs = np.abs(uxx_map)
    # Cap the colour scale at the 98th percentile so a few extreme pixels do not wash out the map.
    ax.imshow(uxx_abs, cmap='hot', vmax=np.percentile(uxx_abs, 98))
    ax.set_title(panel_title)
plt.show()



In [None]:



# Recompute derivatives on the spatiotemporally denoised stack (unit spacing in x, y and t).
u_t, u_x, u_y, u_xx, u_yy = compute_derivatives(images_clean_3d, dx=1.0, dy=1.0, dt=1.0)

print("\nNew u_t range: [{:.3f}, {:.3f}]".format(u_t.min(), u_t.max()))

In [None]:
def compute_derivatives(images_clean, dx=1.0, dy=1.0, dt=1.0, time_scheme="forward"):
    """
    Compute spatial and temporal finite-difference derivatives for SINDy.

    Axis convention: axis 0 is time, axis 1 is y (rows), axis 2 is x (columns).

    Args:
        images_clean: (T, H, W) image stack, e.g. denoised intensities in [0, 255].
            Integer input is converted to float32 first, so differences can be
            negative instead of wrapping around.
        dx, dy: spatial grid spacing (use 1.0 if unknown)
        dt: time step between frames (use 1.0 if unknown, can rescale later)
        time_scheme: "forward" (default) uses (u[t+1] - u[t]) / dt and leaves
            the last frame at 0. "central" uses np.gradient, i.e.
            (u[t+1] - u[t-1]) / (2 dt) in the interior and one-sided
            differences at the first/last frame. It is centred on the same
            time as the spatial derivatives and, for white noise, has a
            quarter of the forward scheme's noise variance.

    Returns:
        u_t, u_x, u_y, u_xx, u_yy: arrays with the same shape as the input.
        Spatial derivatives are left at 0 on the one-pixel image border,
        where the central stencil does not fit. Exclude those samples (and,
        for "forward", the last frame) before regression.

    Raises:
        ValueError: if the input is not 3-D or time_scheme is unknown.
    """
    if time_scheme not in ("forward", "central"):
        raise ValueError(f"time_scheme must be 'forward' or 'central', got {time_scheme!r}")

    u = np.asarray(images_clean)
    if u.ndim != 3:
        raise ValueError(f"expected a (T, H, W) stack, got shape {u.shape}")
    if not np.issubdtype(u.dtype, np.floating):
        # uint8 arithmetic would wrap (0 - 10 == 246) and truncate divisions.
        u = u.astype(np.float32)
    n_frames = u.shape[0]

    # 1. Temporal derivative (du/dt)
    u_t = np.zeros_like(u)
    if time_scheme == "forward":
        u_t[:-1] = (u[1:] - u[:-1]) / dt
    elif n_frames >= 2:
        # np.gradient needs at least two samples along the axis.
        u_t = np.gradient(u, dt, axis=0)

    # 2. First spatial derivatives (du/dx, du/dy) - central difference
    u_x = np.zeros_like(u)
    u_y = np.zeros_like(u)
    u_x[:, :, 1:-1] = (u[:, :, 2:] - u[:, :, :-2]) / (2*dx)
    u_y[:, 1:-1, :] = (u[:, 2:, :] - u[:, :-2, :]) / (2*dy)

    # 3. Second spatial derivatives (d2u/dx2, d2u/dy2) - central difference
    u_xx = np.zeros_like(u)
    u_yy = np.zeros_like(u)
    u_xx[:, :, 1:-1] = (u[:, :, 2:] - 2*u[:, :, 1:-1] + u[:, :, :-2]) / (dx**2)
    u_yy[:, 1:-1, :] = (u[:, 2:, :] - 2*u[:, 1:-1, :] + u[:, :-2, :]) / (dy**2)

    return u_t, u_x, u_y, u_xx, u_yy


# Compute derivatives
print("=" * 60)
print("COMPUTING DERIVATIVES")
print("=" * 60)

u_t, u_x, u_y, u_xx, u_yy = compute_derivatives(
    images_clean_3d,
    dx=1.0,  # pixel spacing (or physical spacing if known)
    dy=1.0,
    dt=1.0,  # time step (or physical time if known)
)

print("\n✅ Derivatives computed successfully!")
print("\nShapes:")
for label, arr in (("u_t:  ", u_t), ("u_x:  ", u_x), ("u_xx: ", u_xx)):
    print(f"  {label}{arr.shape}")

print("\nRanges (checking for reasonable values):")
for label, arr in (("u_t:  ", u_t), ("u_x:  ", u_x), ("u_y:  ", u_y),
                   ("u_xx: ", u_xx), ("u_yy: ", u_yy)):
    print(f"  {label}[{arr.min():.3f}, {arr.max():.3f}]")

# Visual check of derivatives
import matplotlib.pyplot as plt

fig, axes = plt.subplots(2, 3, figsize=(15, 8))

frame_idx = 25  # Check active frame

axes[0, 0].imshow(images_clean_3d[frame_idx], cmap='gray')
axes[0, 0].set_title(f'Original (Frame {frame_idx})')

# (axes position, field, symmetric colour limit, title)
derivative_panels = [
    ((0, 1), u_t, 20, '∂u/∂t (Temporal)'),
    ((0, 2), u_x, 10, '∂u/∂x (Spatial)'),
    ((1, 0), u_y, 10, '∂u/∂y (Spatial)'),
    ((1, 1), u_xx, 5, '∂²u/∂x² (Laplacian X)'),
    ((1, 2), u_yy, 5, '∂²u/∂y² (Laplacian Y)'),
]
for pos, field, lim, title in derivative_panels:
    axes[pos].imshow(field[frame_idx], cmap='RdBu', vmin=-lim, vmax=lim)
    axes[pos].set_title(title)

plt.tight_layout()
plt.savefig('derivatives_visualization.png', dpi=150)
plt.show()

print("\n✅ Derivative visualization saved to 'derivatives_visualization.png'")

The temporal derivative ∂u/∂t is still noisy: forward differences between consecutive frames amplify frame-to-frame intensity noise.

In [None]:
# Show container type, dtype and shape together. The dtype matters: the derivative
# code subtracts frames directly, so unsigned integers (e.g. uint8) would wrap around.
type(images_clean_3d), getattr(images_clean_3d, "dtype", None), getattr(images_clean_3d, "shape", None)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.ndimage import convolve
from sklearn.metrics import r2_score

# ============================================================================
# ASSUME: images_more_clean is already loaded and preprocessed
# Shape: (n_time, height, width) with intensity values [0, 255]
# ============================================================================

def apply_sindy_to_real_images(images, dt=1.0, dx=1.0, dy=1.0, threshold=0.1, 
                               subsample_frames=None, subsample_spatial=4):
    """
    Apply SINDy and Weak-SINDy to real experimental images.

    Both methods regress a forward-difference estimate of ∂u/∂t onto the same
    9-term candidate library (1, u, u², u³, ∇²u, ∇⁴u, |∇u|², u∇²u, u(1-u)) with
    sequentially thresholded least squares (STLSQ). The weak variant first
    convolves the data and every library term with a normalized Gaussian kernel.

    Parameters:
    -----------
    images : array-like, shape (n_time, height, width)
        Preprocessed image sequence. Any numeric dtype is accepted; the data are
        converted to float64 (after subsampling) before any arithmetic.
    dt : float
        Time between frames (if unknown, use 1.0 and interpret results relatively)
    dx, dy : float
        Spatial resolution along the columns (width, axis 1 of a frame) and the
        rows (height, axis 0 of a frame). If unknown, use 1.0 and interpret
        results in pixel units.
    threshold : float
        Sparsity threshold for SINDy (coefficients with |c| < threshold are pruned)
    subsample_frames : int or None
        Use only every Nth frame (None = use all frames)
    subsample_spatial : int
        Downsample spatially by this factor (e.g., 4 means use 1/16 of pixels)

    Returns:
    --------
    dict or None
        {'sindy': {...}, 'weak': {...}, 'u_data': normalized data}. Returns None
        if the high-memory prompt is answered with anything other than 'y'.

    Raises:
    -------
    ValueError
        If `images` is not 3-D, fewer than 2 frames remain after subsampling,
        or every pixel has the same value (normalization is undefined).
    """
    
    print("=" * 70)
    print("APPLYING SINDY TO REAL IMAGES")
    print("=" * 70)
    
    images = np.asarray(images)
    if images.ndim != 3:
        raise ValueError(
            f"images must have shape (n_time, height, width), got {images.shape}")
    
    n_frames, height, width = images.shape
    print(f"\nOriginal data shape: {n_frames} frames, {height}x{width} pixels")
    
    # Subsample frames if specified
    if subsample_frames is not None and subsample_frames > 1:
        images = images[::subsample_frames]
        dt = dt * subsample_frames
        print(f"Subsampling frames: using every {subsample_frames}th frame")
    
    # Subsample spatially to reduce memory
    if subsample_spatial > 1:
        images = images[:, ::subsample_spatial, ::subsample_spatial]
        dx = dx * subsample_spatial
        dy = dy * subsample_spatial
        print(f"Downsampling spatially by factor {subsample_spatial}")
    
    # Promote to float only after subsampling, so an integer (e.g. uint8) stack is
    # copied at the reduced size. Everything below is floating-point arithmetic.
    images = images.astype(np.float64)
    
    n_frames, height, width = images.shape
    if n_frames < 2:
        raise ValueError(
            f"Need at least 2 frames after subsampling to estimate ∂u/∂t, got {n_frames}")
    total_samples = n_frames * height * width
    memory_estimate = total_samples * 9 * 8 / (1024**3)  # 9 features, 8 bytes per float
    
    print(f"\nProcessing data shape: {n_frames} frames, {height}x{width} pixels")
    print(f"Total samples: {total_samples:,}")
    print(f"Estimated memory for library: {memory_estimate:.2f} GB")
    
    if memory_estimate > 8:
        print("\n⚠️  WARNING: This will require a lot of memory!")
        print("   Consider increasing subsample_spatial or subsample_frames")
        response = input("   Continue anyway? (y/n): ")
        if response.lower() != 'y':
            return None
    
    print(f"Time step: {dt}")
    print(f"Spatial resolution: dx={dx}, dy={dy}")
    
    # Normalize to [0, 1] for numerical stability
    u_max = images.max()
    u_min = images.min()
    u_range = u_max - u_min
    if u_range == 0:
        raise ValueError("All pixel values are identical; cannot normalize to [0, 1]")
    u_data = (images - u_min) / u_range
    print(f"\nNormalized data to [0, 1] range")
    print(f"Original range: [{u_min:.2f}, {u_max:.2f}]")
    
    # --- Spatial Derivative Functions ---
    # Frames are indexed [row, column]: rows (axis 0) are spaced by dy and columns
    # (axis 1) by dx, matching compute_derivatives (u_x along width, u_y along height).
    def _pad_edge(f):
        """Pad a 2-D frame by one pixel on every side, repeating border values.

        Experimental images are not periodic. Wrapping (np.roll) would difference
        the left edge against the right edge and create spurious derivatives
        there. Edge padding gives a zero-flux (Neumann-like) boundary instead.
        """
        return np.pad(f, 1, mode='edge')
    
    def get_laplacian(f):
        """5-point finite-difference Laplacian of a 2-D frame (dy on rows, dx on columns)."""
        p = _pad_edge(f)
        d2_rows = (p[2:, 1:-1] - 2 * f + p[:-2, 1:-1]) / (dy**2)
        d2_cols = (p[1:-1, 2:] - 2 * f + p[1:-1, :-2]) / (dx**2)
        return d2_rows + d2_cols
    
    def get_gradients(f):
        """Central-difference gradients (gx along columns/dx, gy along rows/dy)."""
        p = _pad_edge(f)
        gx = (p[1:-1, 2:] - p[1:-1, :-2]) / (2*dx)
        gy = (p[2:, 1:-1] - p[:-2, 1:-1]) / (2*dy)
        return gx, gy
    
    # --- Build Library ---
    def build_library(u_data):
        """Build the (n_samples, 9) library matrix of candidate terms, pixel by pixel."""
        n_frames, nx, ny = u_data.shape
        n_samples = n_frames * nx * ny
        
        u_flat = u_data.reshape(n_samples)
        
        lib_features = []
        feature_names = []
        
        print("\nBuilding function library...")
        
        # Polynomial terms
        lib_features.append(np.ones(n_samples))
        feature_names.append('1')
        
        lib_features.append(u_flat)
        feature_names.append('u')
        
        lib_features.append(u_flat**2)
        feature_names.append('u²')
        
        lib_features.append(u_flat**3)
        feature_names.append('u³')
        
        # Compute spatial derivatives for all frames
        lap_all = np.zeros_like(u_data)
        bilap_all = np.zeros_like(u_data)
        gx_all = np.zeros_like(u_data)
        gy_all = np.zeros_like(u_data)
        
        for i in range(n_frames):
            lap_all[i] = get_laplacian(u_data[i])
            bilap_all[i] = get_laplacian(lap_all[i])
            gx_all[i], gy_all[i] = get_gradients(u_data[i])
        
        # Spatial derivative terms
        lib_features.append(lap_all.reshape(n_samples))
        feature_names.append('∇²u')
        
        lib_features.append(bilap_all.reshape(n_samples))
        feature_names.append('∇⁴u')
        
        lib_features.append((gx_all**2 + gy_all**2).reshape(n_samples))
        feature_names.append('|∇u|²')
        
        lib_features.append((u_data * lap_all).reshape(n_samples))
        feature_names.append('u∇²u')
        
        # Additional terms that might appear in reaction-diffusion.
        # NOTE: u(1-u) == u - u², so this column is exactly collinear with the 'u'
        # and 'u²' columns. lstsq returns a minimum-norm split among the three, so
        # read those coefficients together rather than individually.
        lib_features.append((u_flat * (1 - u_flat)).reshape(-1))
        feature_names.append('u(1-u)')
        
        Theta = np.column_stack(lib_features)
        
        print(f"Library built: {len(feature_names)} candidate terms")
        print(f"Library matrix shape: {Theta.shape}")
        
        return Theta, feature_names
    
    # --- Shared regression / reporting helpers ---
    def _stlsq(Theta, target, threshold, n_iter=10):
        """Sequentially thresholded least squares: prune |c| < threshold, refit the rest."""
        coeffs = np.linalg.lstsq(Theta, target, rcond=None)[0]
        for _ in range(n_iter):
            small_inds = np.abs(coeffs) < threshold
            coeffs[small_inds] = 0
            big_inds = ~small_inds
            if not big_inds.any():
                break
            coeffs[big_inds] = np.linalg.lstsq(Theta[:, big_inds], target, rcond=None)[0]
        return coeffs
    
    def _report_fit(coeffs, feature_names, prediction, target):
        """Print the discovered equation plus MSE/R² and return (mse, r2)."""
        mse = np.mean((prediction - target)**2)
        r2 = r2_score(target, prediction)
        
        print("\nDiscovered equation: ∂u/∂t = ")
        for coef, name in zip(coeffs, feature_names):
            if abs(coef) > 1e-10:
                print(f"  {coef:+.6f} · {name}")
        
        print(f"\nMean Squared Error: {mse:.6e}")
        print(f"R² Score: {r2:.6f}")
        return mse, r2
    
    # --- Standard SINDy ---
    def sindy_method(u_data, threshold):
        """Standard SINDy on pointwise forward-difference ∂u/∂t."""
        print("\n" + "-" * 70)
        print("STANDARD SINDY")
        print("-" * 70)
        
        # Forward difference in time; the last frame has no successor.
        u_dot = (u_data[1:] - u_data[:-1]) / dt
        
        # Build library (exclude last frame)
        Theta, feature_names = build_library(u_data[:-1])
        u_dot_flat = u_dot.reshape(-1)
        
        coeffs = _stlsq(Theta, u_dot_flat, threshold)
        prediction = Theta @ coeffs
        mse, r2 = _report_fit(coeffs, feature_names, prediction, u_dot_flat)
        
        return coeffs, feature_names, mse, r2, prediction, u_dot_flat
    
    # --- Weak SINDy ---
    def weak_sindy_method(u_data, threshold, kernel_size=7):
        """Weak SINDy with a Gaussian test function.

        Each frame and each library term is convolved (mode='reflect') with a
        normalized kernel_size×kernel_size Gaussian (sigma = kernel_size/4) before
        regression. Spatial derivatives are still finite differences of the raw
        frames; they are not moved onto the test function by integration by parts.
        """
        print("\n" + "-" * 70)
        print("WEAK SINDY")
        print("-" * 70)
        
        n_frames, nx, ny = u_data.shape
        
        def gaussian_kernel_2d(size, sigma):
            """Normalized size×size Gaussian centred on the middle pixel."""
            offsets = np.arange(size) - size // 2
            r_sq = offsets[:, None]**2 + offsets[None, :]**2
            kernel = np.exp(-r_sq / (2 * sigma**2))
            return kernel / kernel.sum()
        
        test_func = gaussian_kernel_2d(kernel_size, kernel_size/4)
        print(f"Using Gaussian kernel of size {kernel_size}x{kernel_size}")
        
        def smooth(f):
            """Project a 2-D field onto the test function."""
            return convolve(f, test_func, mode='reflect')
        
        # Apply weak formulation
        u_weak = np.stack([smooth(frame) for frame in u_data])
        
        # Time derivative in weak form
        u_dot_weak = (u_weak[1:] - u_weak[:-1]) / dt
        
        # Build weak library (same term order as build_library)
        feature_names = ['1', 'u', 'u²', 'u³', '∇²u', '∇⁴u', '|∇u|²', 'u∇²u', 'u(1-u)']
        Theta_blocks = []
        for frame_idx in range(n_frames - 1):
            u = u_data[frame_idx]
            lap = get_laplacian(u)
            bilap = get_laplacian(lap)
            gx, gy = get_gradients(u)
            frame_features = [
                np.sum(test_func) * np.ones((nx, ny)),
                u_weak[frame_idx],
                smooth(u**2),
                smooth(u**3),
                smooth(lap),
                smooth(bilap),
                smooth(gx**2 + gy**2),
                smooth(u * lap),
                smooth(u * (1 - u)),
            ]
            Theta_blocks.append(np.column_stack([f.reshape(-1) for f in frame_features]))
        
        Theta_weak = np.vstack(Theta_blocks)
        u_dot_weak_flat = u_dot_weak.reshape(-1)
        
        coeffs = _stlsq(Theta_weak, u_dot_weak_flat, threshold)
        prediction = Theta_weak @ coeffs
        mse, r2 = _report_fit(coeffs, feature_names, prediction, u_dot_weak_flat)
        
        return coeffs, feature_names, mse, r2, prediction, u_dot_weak_flat
    
    # --- Run Both Methods ---
    sindy_coeffs, sindy_names, sindy_mse, sindy_r2, sindy_pred, sindy_target = sindy_method(u_data, threshold)
    weak_coeffs, weak_names, weak_mse, weak_r2, weak_pred, weak_target = weak_sindy_method(u_data, threshold)
    
    # --- Verification Strategy ---
    print("\n" + "=" * 70)
    print("VERIFICATION STRATEGY")
    print("=" * 70)
    
    print("\n1. RECONSTRUCTION QUALITY")
    print("   - R² Score measures how well discovered equation fits data")
    print(f"   - SINDy R²: {sindy_r2:.4f}")
    print(f"   - Weak-SINDy R²: {weak_r2:.4f}")
    print("   - Values close to 1.0 indicate good fit")
    
    print("\n2. CONSISTENCY CHECK")
    print("   - Do both methods identify similar dominant terms?")
    sindy_active = {name for coef, name in zip(sindy_coeffs, sindy_names) if abs(coef) > 1e-6}
    weak_active = {name for coef, name in zip(weak_coeffs, weak_names) if abs(coef) > 1e-6}
    common_terms = sindy_active & weak_active
    print(f"   - Common active terms: {common_terms}")
    
    print("\n3. PHYSICAL PLAUSIBILITY")
    print("   - Check if discovered terms make physical sense")
    print("   - Common PDE patterns:")
    print("     • Diffusion: ∇²u term")
    print("     • Reaction-Diffusion: ∇²u + u(1-u) or ∇²u + u²")
    print("     • Pattern formation: ∇²u - ∇⁴u (Swift-Hohenberg)")
    print("     • Nonlinear advection: |∇u|² or u∇²u")
    
    print("\n4. FORWARD SIMULATION")
    print("   - Use discovered PDE to simulate forward in time")
    print("   - Compare simulation with actual images")
    print("   - This is implemented in the visualization below")
    
    # --- Visualization ---
    fig = plt.figure(figsize=(18, 12))
    
    # 1. Coefficient Comparison
    ax1 = plt.subplot(3, 3, 1)
    x_pos = np.arange(len(sindy_names))
    bar_width = 0.35  # separate name so the image `width` is not shadowed
    ax1.bar(x_pos - bar_width/2, sindy_coeffs, bar_width, label='SINDy', alpha=0.8)
    ax1.bar(x_pos + bar_width/2, [weak_coeffs[weak_names.index(n)] if n in weak_names else 0 for n in sindy_names], 
            bar_width, label='Weak-SINDy', alpha=0.8)
    ax1.set_xlabel('Terms')
    ax1.set_ylabel('Coefficient Value')
    ax1.set_title('Discovered Coefficients', fontweight='bold')
    ax1.set_xticks(x_pos)
    ax1.set_xticklabels(sindy_names, rotation=45, ha='right')
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    ax1.axhline(y=0, color='k', linestyle='-', linewidth=0.5)
    
    # 2. R² Comparison
    ax2 = plt.subplot(3, 3, 2)
    methods = ['SINDy', 'Weak-SINDy']
    r2_scores = [sindy_r2, weak_r2]
    colors = ['#1f77b4', '#ff7f0e']
    bars = ax2.bar(methods, r2_scores, color=colors, alpha=0.8, edgecolor='black', linewidth=2)
    ax2.set_ylabel('R² Score')
    ax2.set_title('Reconstruction Quality', fontweight='bold')
    # R² can be negative for a poor fit; extend the axis so such bars stay visible.
    ax2.set_ylim([min(0.0, *r2_scores), 1.1])
    ax2.axhline(y=1.0, color='red', linestyle='--', linewidth=1, alpha=0.5)
    ax2.grid(True, alpha=0.3, axis='y')
    for bar in bars:
        bar_h = bar.get_height()
        ax2.text(bar.get_x() + bar.get_width()/2., bar_h,
                f'{bar_h:.4f}', ha='center', va='bottom', fontsize=11, fontweight='bold')
    
    # 3. MSE Comparison
    ax3 = plt.subplot(3, 3, 3)
    mse_scores = [sindy_mse, weak_mse]
    bars = ax3.bar(methods, mse_scores, color=colors, alpha=0.8, edgecolor='black', linewidth=2)
    ax3.set_ylabel('Mean Squared Error')
    ax3.set_title('Prediction Error', fontweight='bold')
    ax3.set_yscale('log')
    ax3.grid(True, alpha=0.3, axis='y')
    for bar in bars:
        bar_h = bar.get_height()
        ax3.text(bar.get_x() + bar.get_width()/2., bar_h,
                f'{bar_h:.2e}', ha='center', va='bottom', fontsize=10)
    
    # 4-6. Sample frames from data
    sample_frames = [0, n_frames//2, n_frames-1]
    for i, frame_idx in enumerate(sample_frames):
        ax = plt.subplot(3, 3, 4+i)
        im = ax.imshow(u_data[frame_idx], cmap='viridis', aspect='auto')
        ax.set_title(f'Frame {frame_idx}', fontweight='bold')
        ax.axis('off')
        plt.colorbar(im, ax=ax, fraction=0.046)
    
    # 7-8. Prediction vs Target scatter plots
    # Subsample for visualization; a fixed seed makes the figure reproducible.
    # sindy_target and weak_target have the same length ((n_frames-1)*H*W).
    n_subsample = min(10000, len(sindy_target))
    rng = np.random.default_rng(0)
    indices = rng.choice(len(sindy_target), n_subsample, replace=False)
    
    ax7 = plt.subplot(3, 3, 7)
    ax7.scatter(sindy_target[indices], sindy_pred[indices], alpha=0.3, s=1)
    ax7.plot([sindy_target.min(), sindy_target.max()], 
             [sindy_target.min(), sindy_target.max()], 'r--', linewidth=2)
    ax7.set_xlabel('True ∂u/∂t')
    ax7.set_ylabel('Predicted ∂u/∂t')
    ax7.set_title(f'SINDy: R²={sindy_r2:.4f}', fontweight='bold')
    ax7.grid(True, alpha=0.3)
    
    ax8 = plt.subplot(3, 3, 8)
    ax8.scatter(weak_target[indices], weak_pred[indices], alpha=0.3, s=1, color='orange')
    ax8.plot([weak_target.min(), weak_target.max()], 
             [weak_target.min(), weak_target.max()], 'r--', linewidth=2)
    ax8.set_xlabel('True ∂u/∂t')
    ax8.set_ylabel('Predicted ∂u/∂t')
    ax8.set_title(f'Weak-SINDy: R²={weak_r2:.4f}', fontweight='bold')
    ax8.grid(True, alpha=0.3)
    
    # 9. Temporal evolution of mean intensity
    ax9 = plt.subplot(3, 3, 9)
    mean_intensity = u_data.mean(axis=(1, 2))
    ax9.plot(mean_intensity, linewidth=2)
    ax9.set_xlabel('Frame')
    ax9.set_ylabel('Mean Intensity')
    ax9.set_title('Temporal Evolution', fontweight='bold')
    ax9.grid(True, alpha=0.3)
    
    plt.suptitle('SINDy Analysis of Real Experimental Images', fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.show()
    
    return {
        'sindy': {'coeffs': sindy_coeffs, 'names': sindy_names, 'mse': sindy_mse, 'r2': sindy_r2},
        'weak': {'coeffs': weak_coeffs, 'names': weak_names, 'mse': weak_mse, 'r2': weak_r2},
        'u_data': u_data
    }

# ============================================================================
# USAGE:
# Bind a variant-specific name so that later cells rebinding `results` for other
# preprocessing variants do not discard this run. `results` stays bound for
# downstream cells.
results_more_clean = apply_sindy_to_real_images(images_more_clean, dt=1.0, dx=1.0, threshold=0.1)
results = results_more_clean
# ============================================================================



In [None]:
# Variant-specific name so rerunning another variant does not discard this result.
results_clean_3d = apply_sindy_to_real_images(images_clean_3d, dt=1.0, dx=1.0, threshold=0.1)
results = results_clean_3d


In [None]:
# `images` is the unprocessed stack loaded from disk; keep its result under its own name.
results_raw = apply_sindy_to_real_images(images, dt=1.0, dx=1.0, threshold=0.1)
results = results_raw
