In [1]:
import os
from pathlib import Path
from typing import Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from omegaconf import OmegaConf
import masknmf
from typing import *
from spatial_denoiser import train_spatial_denoiser
from spatial_denoiser import create_pmd_denoiser

No windowing system present. Using surfaceless platform
No config found!
No config found!
Max vertex attribute stride unknown. Assuming it is 2048
Max vertex attribute stride unknown. Assuming it is 2048
Max vertex attribute stride unknown. Assuming it is 2048


Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x01,\x00\x00\x007\x08\x06\x00\x00\x00\xb6\x1bw\x99\x‚Ä¶

Valid,Device,Type,Backend,Driver
❌ (default),NVIDIA A100-PCIE-40GB/PCIe/SSE2,Unknown,OpenGL,3.3.0 NVIDIA 565.57.01


Max vertex attribute stride unknown. Assuming it is 2048
Max vertex attribute stride unknown. Assuming it is 2048


In [2]:
class MotionBinDataset:
    """Memory-mapped view of a suite2p ``data.bin`` movie, shaped from its ops metadata.

    The frame count and field-of-view size are read from the suite2p ``ops``
    dictionary (keys ``'nframes'``, ``'Ly'``, ``'Lx'``); the raw binary file is
    then memory-mapped read-only with shape ``(nframes, Ly, Lx)``.

    Parameters
    ----------
    data_path : str
        Path to the raw binary movie (e.g. suite2p ``data.bin``).
    metadata_path : str
        Path to the ops metadata: an ``.npy`` file holding the pickled ops dict,
        or a ``.zip`` (npz-format) archive with an ``'ops'`` entry. The
        extension is matched case-insensitively.
    dtype : numpy dtype, default ``np.int16``
        Element type of the binary file.
    """

    def __init__(self, data_path: str, metadata_path: str, dtype=np.int16):
        self.bin_path = Path(data_path)
        self.ops_path = Path(metadata_path)
        self._dtype = dtype
        self._shape = self._compute_shape()
        # Read-only memmap: nothing is loaded until frames are indexed.
        self.data = np.memmap(self.bin_path, mode='r', dtype=self._dtype, shape=self._shape)

    def _compute_shape(self) -> Tuple[int, int, int]:
        """Return ``(nframes, Ly, Lx)`` read from the ops metadata file.

        Raises
        ------
        ValueError
            If the metadata extension is neither ``.zip`` nor ``.npy``.
        """
        ext = self.ops_path.suffix.lower()
        # NOTE(review): allow_pickle=True is needed because suite2p stores ops as
        # a pickled dict; only load metadata files from trusted sources.
        if ext == ".zip":
            # np.load on an archive returns an NpzFile that holds the file open;
            # the context manager closes it once 'ops' has been extracted.
            with np.load(self.ops_path, allow_pickle=True) as archive:
                ops = archive['ops'].item()
        elif ext == ".npy":
            ops = np.load(self.ops_path, allow_pickle=True).item()
        else:
            raise ValueError("Metadata file must be .zip or .npy")
        return int(ops['nframes']), int(ops['Ly']), int(ops['Lx'])

    @property
    def shape(self) -> Tuple[int, int, int]:
        """Movie shape as ``(frames, height, width)``."""
        return self._shape

    def __getitem__(self, item):
        """Index the movie; returns an in-memory copy, not a memmap view."""
        return self.data[item].copy()

config = {
    'bin_file_path': '/burg-archive/home/lm3879/plane4/data.bin',
    'ops_file_path': '/burg-archive/home/lm3879/plane4/ops.npy',
    'out_path': '/burg-archive/home/lm3879/masknmf-toolbox/ibl_denoised_output/pmd_spatial_results.npz',
    'block_size_dim1': 32,
    'block_size_dim2': 32,
    'max_components': 20,
    'max_consecutive_failures': 1,
    'spatial_avg_factor': 1,
    'temporal_avg_factor': 1,
    'device': 'cuda',
    'frame_batch_size': 1024,
    'denoiser_max_epochs': 5,
    'denoiser_batch_size': 32,
    'denoiser_lr': 1e-4,
    'patch_h': 40,
    'patch_w': 40,
    'mask_margin': 3,  # border width (pixels) given zero weight in PMD
}

cfg = OmegaConf.create(config)
# Fall back to CPU when CUDA is unavailable. `device` (not `cfg.device`) is what
# must be passed downstream, otherwise the fallback has no effect.
device = cfg.device if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

my_data = MotionBinDataset(cfg.bin_file_path, cfg.ops_file_path)
print(f"Loaded data with shape (T,H,W): {my_data.shape}")
H, W = my_data.shape[1], my_data.shape[2]

# PMD pixel weighting: zero weight on a `mask_margin`-wide border, uniform weight
# inside. NOTE(review): presumably masks motion-correction edge artifacts; confirm.
# Explicit `H - margin` bounds keep margin == 0 meaning "no border" rather
# than the empty slice `[0:-0]`.
margin = cfg.mask_margin
binary_mask = np.zeros((H, W), dtype=np.float32)
binary_mask[margin:H - margin, margin:W - margin] = 1.0

# Baseline PMD: no spatial or temporal denoiser.
pmd_no_denoise = masknmf.compression.pmd_decomposition(
    my_data,
    [cfg.block_size_dim1, cfg.block_size_dim2],
    my_data.shape[0],
    max_components=cfg.max_components,
    max_consecutive_failures=cfg.max_consecutive_failures,
    temporal_avg_factor=cfg.temporal_avg_factor,
    spatial_avg_factor=cfg.spatial_avg_factor,
    device=device,
    temporal_denoiser=None,
    frame_batch_size=cfg.frame_batch_size,
    pixel_weighting=binary_mask
)
print(f"PMD rank: {pmd_no_denoise.pmd_rank}")

# Dense spatial basis U is (H*W, rank); view each column as an H x W image.
# NOTE(review): assumes U's rows are pixels in row-major (H, W) order; confirm
# against masknmf's flattening convention.
u_dense = pmd_no_denoise.u.to_dense()
u_reshaped = u_dense.reshape(H, W, -1)
print(f"Spatial components shape: {u_reshaped.shape}")
print(f"Number of components: {u_reshaped.shape[2]}")
spatial_components = u_reshaped.permute(2, 0, 1)  # (rank, H, W)

Using device: cuda
Loaded data with shape (T,H,W): (11700, 500, 620)
[25-12-11 05:35:26]: Starting compression
[25-12-11 05:35:26]: sampled from the following regions: [0]
[25-12-11 05:35:26]: We are initializing on a total of 11700 frames
[25-12-11 05:35:35]: Loading data to estimate complete spatial basis
[25-12-11 05:35:35]: skipping the pruning step for frame cutoff
[25-12-11 05:35:35]: Finding spatiotemporal roughness thresholds


100%|██████████| 250/250 [00:01<00:00, 234.83it/s]

[25-12-11 05:35:36]: Running Blockwise Decompositions





[25-12-11 05:36:25]: Constructed U matrix. Rank of U is 2563
[25-12-11 05:36:25]: PMD Objected constructed
PMD rank: 2563
Spatial components shape: torch.Size([500, 620, 2563])
Number of components: 2563


In [3]:
# Train the spatial denoiser on the baseline PMD spatial components, then wrap
# it so it can be passed to pmd_decomposition as `spatial_denoiser`.
# The epoch count comes from the config cell so the two cannot drift apart.
spatial_model, _ = train_spatial_denoiser(
    spatial_components,
    config={'max_epochs': cfg.denoiser_max_epochs},
    output_dir='./denoiser_output'
)

spatial_denoiser = create_pmd_denoiser(
    trained_model=spatial_model,
    noise_variance_quantile=0.7,
    padding=12
)


Starting Spatial Denoiser Training
Device: cuda
GPU: NVIDIA A100-PCIE-40GB
Total memory: 39.50 GB

Extracting valid patches from spatial components...
  Processed 500/2563 components...
  Processed 1000/2563 components...
  Processed 1500/2563 components...
  Processed 2000/2563 components...
  Processed 2500/2563 components...

Extracted 2563 valid patches from 2563 components


üìä Patch Statistics:
  Number of patches: 2563
  Size range: 34 - 36 pixels
  Average size: 1293 pixels


Using 16bit Automatic Mixed Precision (AMP)
üí° Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
You are using a CUDA device ('NVIDIA A100-PCIE-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [1]
/burg-archive/home/lm3879/miniconda3/envs/py311v2/lib/python3.11/site-packages/pytorch_lightning/utilities/model_summary/model_summary.py:231: Precision 16-mixed is not supported by the model summary.  Estimated model size in MB will not be accurate. Using 32 bits

Dataset created: 2563 patches
Memory footprint: 0.01 GB
Using 2563 random patches for training
DataLoader ready with 81 batches

Starting training...
Effective batch size: 64


/burg-archive/home/lm3879/miniconda3/envs/py311v2/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:433: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=31` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=5` reached.


Training complete!
Peak GPU memory: 0.48 GB
Model saved to: denoiser_output/spatial_denoiser_state_dict.pth

Blindspot leakage test: 0.00e+00
✓ Blindspot property verified

✓ PMD spatial denoiser created
  Noise variance quantile: 0.7
  Padding: 12
  Device: cuda


In [4]:
# Temporal factors V of the baseline PMD run, moved to host memory for training.
v = pmd_no_denoise.v.cpu()

# Fit the total-variance temporal denoiser on V.
temporal_model, _ = masknmf.compression.denoising.train_total_variance_denoiser(
    v,
    max_epochs=5,
    batch_size=128,
    learning_rate=1e-4,
)

# Wrap the trained network for use as `temporal_denoiser` in pmd_decomposition.
# NOTE(review): 0.7 is passed positionally; presumably a noise-variance quantile
# analogous to the spatial denoiser's; confirm against PMDTemporalDenoiser.
temporal_denoiser = masknmf.compression.PMDTemporalDenoiser(temporal_model, 0.7)

Using 16bit Automatic Mixed Precision (AMP)
üí° Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
/burg-archive/home/lm3879/miniconda3/envs/py311v2/lib/python3.11/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:751: Checkpoint directory /burg-archive/home/lm3879/masknmf-toolbox/masknmf/compression/lightning_logs/version_5252547/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [1]

  | Name             | Type            | Params | Mode 
-------------------------------------------------------------
0 | temporal_network | TemporalNetwork | 278 K  | train
-------------------------------------------------------------
170 K     Trainable params
107 K     Non-trainable params
278 K     Total params
1.114     Total estimated

Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=5` reached.


In [5]:
# Re-run PMD with both trained denoisers plugged in. `device` carries the CPU
# fallback chosen in the config cell (cfg.device would ignore it).
pmd_with_denoise_S_T = masknmf.compression.pmd_decomposition(
    my_data,
    [cfg.block_size_dim1, cfg.block_size_dim2],
    my_data.shape[0],
    max_components=cfg.max_components,
    max_consecutive_failures=cfg.max_consecutive_failures,
    temporal_avg_factor=cfg.temporal_avg_factor,
    spatial_avg_factor=cfg.spatial_avg_factor,
    device=device,
    spatial_denoiser=spatial_denoiser,
    temporal_denoiser=temporal_denoiser,
    frame_batch_size=cfg.frame_batch_size,
    pixel_weighting=binary_mask
)

print(f"\n{'='*60}")
print("PMD Results:")
print(f"  Without denoisers:                rank = {pmd_no_denoise.pmd_rank}")
print(f"  With spatial & temporal denoiser: rank = {pmd_with_denoise_S_T.pmd_rank}")
print(f"{'='*60}")

[25-12-11 05:37:39]: Starting compression
[25-12-11 05:37:39]: sampled from the following regions: [0]
[25-12-11 05:37:39]: We are initializing on a total of 11700 frames
[25-12-11 05:37:45]: Loading data to estimate complete spatial basis
[25-12-11 05:37:45]: skipping the pruning step for frame cutoff
[25-12-11 05:37:45]: Finding spatiotemporal roughness thresholds


100%|██████████| 250/250 [00:05<00:00, 45.61it/s]

[25-12-11 05:37:51]: Running Blockwise Decompositions





[25-12-11 05:41:00]: Constructed U matrix. Rank of U is 3783
[25-12-11 05:41:00]: PMD Objected constructed

PMD Results:
  With spatial & temporal denoiser: rank = 3783


In [None]:
import fastplotlib as fpl

# Side-by-side viewer: motion-corrected movie, denoised PMD reconstruction, and
# the residual array built from the two.
resid_arr_Spatial_Temporal = masknmf.compression.PMDResidualArray(my_data, pmd_with_denoise_S_T)
iw = fpl.ImageWidget(
    data=[my_data, pmd_with_denoise_S_T, resid_arr_Spatial_Temporal],
    names=['motion corrected', 'spatial and temporal denoised', 'residual'],
)
iw.show()

In [None]:
import matplotlib.pyplot as plt
from scipy.ndimage import label  # NOTE(review): not used anywhere below in this cell

# Dense spatial bases of both PMD runs on the host, each (H*W, R), then viewed
# as per-component images of shape (H, W, R).
# NOTE(review): assumes U's rows are pixels in row-major (H, W) order.
u_dense_before = pmd_no_denoise.u.to_dense().cpu()
u_dense_after = pmd_with_denoise_S_T.u.to_dense().cpu()
H, W = my_data.shape[1], my_data.shape[2]
u_before = u_dense_before.reshape(H, W, -1)
u_after = u_dense_after.reshape(H, W, -1)

# --------------------- Adjustable Parameters ---------------------
topK_display = 12          # top-energy "before" components to display
topN_before = 400          # highest-energy "before" columns kept for matching
topM_after = 400           # highest-energy "after" columns kept for matching
chunk_cols = 100           # columns per block when building the correlation matrix
corr_threshold = 0.0       # matches with |corr| below this are skipped
cmap = 'gray'              # colormap for the Before/After panels
zoom_padding = 2           # extra pixels around the detected bounding box
components_per_image = 6   # rows per saved comparison-grid image
# -----------------------------------------------------

def ensure_numpy(x):
    """Convert ``x`` to a NumPy array.

    Torch tensors are detached from autograd and moved to host memory first;
    everything else goes through ``np.asarray``.
    """
    if not isinstance(x, torch.Tensor):
        return np.asarray(x)
    return x.detach().cpu().numpy()

def l2_norms_of_components(u):
    """Return the L2 (Frobenius) norm of each slice ``u[:, :, k]`` as a 1-D array."""
    norms = []
    for k in range(u.shape[2]):
        norms.append(np.linalg.norm(u[:, :, k]))
    return np.array(norms)

def detect_nonzero_bbox(img, threshold_quantile=0.01):
    """
    Detect the bounding box of the significant region of a 2-D image.

    Pixels whose magnitude is at least the ``threshold_quantile`` quantile of
    the non-zero magnitudes count as support, which trims the faintest values
    from the box.

    Parameters
    ----------
    img : np.ndarray
        2-D image.
    threshold_quantile : float
        Quantile in [0, 1] of the non-zero ``|img|`` values used as the cutoff.

    Returns
    -------
    tuple of int
        ``(row_min, row_max, col_min, col_max)`` as half-open bounds; the full
        image extent when ``img`` has no non-zero pixel.
    """
    img_abs = np.abs(img)
    nonzero = img_abs > 0
    if not np.any(nonzero):
        return 0, img.shape[0], 0, img.shape[1]

    # `>=` rather than `>`: with a strict comparison, an image whose non-zero
    # pixels all share one value gets threshold == that value, yielding an
    # empty mask and a spurious full-frame box. The threshold is > 0 here, so
    # zeros are still excluded, and the maximum pixel always survives.
    threshold = np.quantile(img_abs[nonzero], threshold_quantile)
    mask = img_abs >= threshold

    row_indices = np.flatnonzero(np.any(mask, axis=1))
    col_indices = np.flatnonzero(np.any(mask, axis=0))
    return (int(row_indices[0]), int(row_indices[-1]) + 1,
            int(col_indices[0]), int(col_indices[-1]) + 1)

def get_union_bbox(img1, img2, padding=2):
    """
    Return one padded crop box covering the significant regions of both images.

    Using the union guarantees the "before" and "after" components are cropped
    identically. The padded box is clipped to ``img1``'s bounds.
    Returns ``(row_min, row_max, col_min, col_max)`` as half-open bounds.
    """
    top_a, bottom_a, left_a, right_a = detect_nonzero_bbox(img1)
    top_b, bottom_b, left_b, right_b = detect_nonzero_bbox(img2)

    n_rows, n_cols = img1.shape
    top = max(0, min(top_a, top_b) - padding)
    bottom = min(n_rows, max(bottom_a, bottom_b) + padding)
    left = max(0, min(left_a, left_b) - padding)
    right = min(n_cols, max(right_a, right_b) + padding)
    return top, bottom, left, right

def prepare_top_subset(u_before, u_after, topN_before, topM_after):
    """Select the highest-energy components of each basis for matching.

    Parameters
    ----------
    u_before, u_after : np.ndarray
        Component stacks of shape ``(H, W, R)``; both are flattened using
        ``u_before``'s spatial size.
    topN_before, topM_after : int
        Maximum number of components kept from each stack. Values above the
        rank are clipped to it; values <= 0 select no components.

    Returns
    -------
    B, A : np.ndarray
        ``(H*W, nb)`` / ``(H*W, na)`` flattened components, highest energy first.
    idx_b, idx_a : np.ndarray
        Original component indices of the columns of ``B`` / ``A``.
    ener_b, ener_a : np.ndarray
        L2 norm of every component in each full stack.
    """
    H, W, Rb = u_before.shape
    _, _, Ra = u_after.shape
    M = H * W

    ener_b = l2_norms_of_components(u_before)
    ener_a = l2_norms_of_components(u_after)

    nb = max(0, min(topN_before, Rb))
    na = max(0, min(topM_after, Ra))

    # Sort descending, then take the first n. The `[-n:]` idiom would select
    # *every* component when n == 0, because `[-0:]` is the same as `[0:]`.
    idx_b = np.argsort(ener_b)[::-1][:nb]
    idx_a = np.argsort(ener_a)[::-1][:na]

    B = u_before.reshape(M, Rb)[:, idx_b]
    A = u_after.reshape(M, Ra)[:, idx_a]

    return B, A, idx_b, idx_a, ener_b, ener_a

def normalize_cols_float32(X):
    """Return a float32 copy of ``X`` with unit-L2-norm columns, and the norms.

    All-zero columns stay zero. The input is never modified: the previous
    ``astype(np.float32, copy=False)`` returned ``X`` itself for float32 input,
    so the in-place division overwrote the caller's array.

    Parameters
    ----------
    X : np.ndarray
        2-D array whose columns are normalized.

    Returns
    -------
    Xn : np.ndarray
        float32 array, same shape as ``X``, columns scaled to unit norm.
    norms : np.ndarray
        float32 column norms, computed before normalization.
    """
    Xf = np.array(X, dtype=np.float32, copy=True)
    norms = np.linalg.norm(Xf, axis=0).astype(np.float32)
    nz = norms > 0
    if np.any(nz):
        Xf[:, nz] /= norms[nz]
    return Xf, norms

def compute_corr_blockwise(Bn, An, chunk_cols=100):
    """Return ``|Bn.T @ An|`` as float32, filled ``chunk_cols`` columns at a time.

    With column-normalized inputs, entry (i, j) is the absolute cosine
    similarity between column i of ``Bn`` and column j of ``An``. Chunking
    bounds each intermediate product; the full (nb, na) result is still
    allocated up front.
    """
    Bt = Bn.T
    n_after = An.shape[1]
    result = np.empty((Bn.shape[1], n_after), dtype=np.float32)
    for lo in range(0, n_after, chunk_cols):
        hi = min(lo + chunk_cols, n_after)
        result[:, lo:hi] = np.dot(Bt, An[:, lo:hi])
    return np.abs(result)

def greedy_match_topK_for_befores(corr_abs, top_before_indices_in_subset=None, one_to_one=False):
    """Match "before" rows of an |correlation| matrix to "after" columns.

    Parameters
    ----------
    corr_abs : array-like, shape (nb, na)
        Absolute correlations; rows are before components, columns are after
        components.
    top_before_indices_in_subset : iterable of int, optional
        Row indices to match; all rows when None.
    one_to_one : bool, default False
        False: each requested row independently takes its best column
        (argmax), so several rows may share the same column.
        True: greedy one-to-one assignment. Pairs are taken in decreasing
        correlation and each row and column is used at most once. Rows left
        without a free column are omitted.

    Returns
    -------
    list of (row, col, value)
        Sorted by value, highest first. Empty when ``corr_abs`` has no columns.
    """
    corr_abs = np.asarray(corr_abs)
    nb, na = corr_abs.shape
    if top_before_indices_in_subset is None:
        rows_to_process = range(nb)
    else:
        rows_to_process = list(top_before_indices_in_subset)

    # argmax of an empty row would raise; there is nothing to match against.
    if na == 0:
        return []

    if not one_to_one:
        matches = []
        for r in rows_to_process:
            col = int(np.argmax(corr_abs[r, :]))
            matches.append((r, col, float(corr_abs[r, col])))
        return sorted(matches, key=lambda m: m[2], reverse=True)

    # Greedy one-to-one: scan all (row, col) pairs of the requested rows from
    # highest to lowest correlation, accepting a pair only if both are unused.
    rows = list(dict.fromkeys(int(r) for r in rows_to_process))
    if not rows:
        return []
    sub = corr_abs[rows, :]
    order = np.argsort(-sub, axis=None, kind="stable")
    used_rows, used_cols = set(), set()
    matches = []
    for flat in order:
        i, col = divmod(int(flat), na)
        r = rows[i]
        if r in used_rows or col in used_cols:
            continue
        used_rows.add(r)
        used_cols.add(col)
        matches.append((r, col, float(sub[i, col])))
        if len(used_rows) == len(rows) or len(used_cols) == na:
            break
    return matches

def plot_component_grid(matches_data, start_idx, end_idx, cmap='gray', save_path=None):
    """
    Plot a before/after/difference comparison grid for a range of matches.

    Each row shows: Before | After | Difference (After - Before). Before and
    After share one symmetric colour range so they are directly comparable;
    the difference panel gets its own symmetric range.

    Parameters
    ----------
    matches_data : list of tuples
        ``(before_cropped, after_cropped, diff, idx_before, idx_after, corr_val, bbox)``
        with ``bbox = (row_min, row_max, col_min, col_max)``.
    start_idx, end_idx : int
        Half-open range ``[start_idx, end_idx)`` of entries to draw.
        ``end_idx`` is clamped to ``len(matches_data)`` so no empty rows are drawn.
    cmap : str
        Colormap for the Before/After panels.
    save_path : str or None
        If given, the figure is also saved there at 200 dpi.
    """
    end_idx = min(end_idx, len(matches_data))
    num_show = end_idx - start_idx
    if num_show <= 0:
        return

    # squeeze=False always returns a 2-D axes array, even for a single row.
    fig, axes = plt.subplots(num_show, 3, figsize=(12, 3 * num_show), squeeze=False)

    for i, idx in enumerate(range(start_idx, end_idx)):
        b_crop, a_crop, diff, idx_b, idx_a, corr, bbox = matches_data[idx]
        row_min, row_max, col_min, col_max = bbox

        # Symmetric colour limits. `initial=0.0` keeps empty crops from raising;
        # `or 1.0` avoids a degenerate zero-width range for all-zero images.
        vmax = max(np.abs(b_crop).max(initial=0.0), np.abs(a_crop).max(initial=0.0)) or 1.0
        vmin = -vmax
        vmax_diff = np.abs(diff).max(initial=0.0) or 1.0

        # Before
        im0 = axes[i, 0].imshow(b_crop, cmap=cmap, vmin=vmin, vmax=vmax)
        axes[i, 0].axis('off')
        axes[i, 0].set_title(f'Before #{idx_b}\nBBox:[{row_min}:{row_max},{col_min}:{col_max}]',
                             fontsize=9)
        plt.colorbar(im0, ax=axes[i, 0], fraction=0.046, pad=0.04)

        # After
        im1 = axes[i, 1].imshow(a_crop, cmap=cmap, vmin=vmin, vmax=vmax)
        axes[i, 1].axis('off')
        axes[i, 1].set_title(f'After #{idx_a}\nCorr={corr:.3f}', fontsize=9)
        plt.colorbar(im1, ax=axes[i, 1], fraction=0.046, pad=0.04)

        # Difference
        im2 = axes[i, 2].imshow(diff, cmap='seismic', vmin=-vmax_diff, vmax=vmax_diff)
        axes[i, 2].axis('off')
        mse = float(np.mean(diff ** 2)) if diff.size else 0.0
        axes[i, 2].set_title(f'Diff (After-Before)\nMSE={mse:.2e}', fontsize=9)
        plt.colorbar(im2, ax=axes[i, 2], fraction=0.046, pad=0.04)

    plt.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=200, bbox_inches='tight')
        print(f"Saved: {save_path}")

    plt.show()
    # Release the figure; this is called in a loop over many grids.
    plt.close(fig)


# ---------------------------- Main Pipeline ----------------------------
# Loop variables below use descriptive names so they do not overwrite notebook
# globals (in particular `v`, the PMD temporal factors from the denoiser cell).
print("="*60)
print("STEP 1: PREPARING DATA AND MATCHING")
print("="*60)

u_before = ensure_numpy(u_before)
u_after  = ensure_numpy(u_after)

if u_before.ndim != 3 or u_after.ndim != 3:
    raise ValueError("u_before/u_after must have shape (H, W, R)")

H, W, Rb = u_before.shape
H2, W2, Ra = u_after.shape
# Explicit raise (not assert) so the check is not stripped under `python -O`.
if (H, W) != (H2, W2):
    raise ValueError("Spatial sizes must match")

print(f"Original ranks: before Rb={Rb}, after Ra={Ra}")

# Keep only the highest-energy components of each basis for matching.
B, A, idx_b, idx_a, ener_b, ener_a = prepare_top_subset(
    u_before, u_after, topN_before, topM_after
)
M = H * W
print(f"Selected subsets for matching: before nb={B.shape[1]}, after na={A.shape[1]} (M={M})")

# Unit-norm columns, so the dot products below are cosine similarities.
Bn, norms_b = normalize_cols_float32(B)
An, norms_a = normalize_cols_float32(A)

est_bytes = Bn.shape[1] * An.shape[1] * 4  # float32 result matrix
print(f"Estimated correlation matrix size: {est_bytes/1e6:.1f} MB. chunk_cols={chunk_cols}")

corr_abs = compute_corr_blockwise(Bn, An, chunk_cols=chunk_cols)
print("Correlation computed.")

# Match the top-K "before" components (columns of B are in descending energy).
topK_display_actual = min(topK_display, B.shape[1])
rows_to_match = list(range(topK_display_actual))
matches = greedy_match_topK_for_befores(corr_abs, top_before_indices_in_subset=rows_to_match)

# Map subset positions back to original component indices.
matches_mapped = []
for r_sub, c_sub, val in matches:
    orig_b = int(idx_b[r_sub])
    orig_a = int(idx_a[c_sub])
    matches_mapped.append((orig_b, orig_a, val))

print("\nMatching results:")
for before_id, after_id, corr_val in matches_mapped[:topK_display_actual]:
    print(f"  Before #{before_id:4d} -> After #{after_id:4d}   Corr={corr_val:.4f}")

# ---------------------------- Prepare Cropped Data ----------------------------
print("\n" + "="*60)
print("STEP 2: CROPPING COMPONENTS TO PATCHES")
print("="*60)

matches_data = []
for orig_b, orig_a, val in matches_mapped[:topK_display_actual]:
    if val < corr_threshold:
        continue

    before_img = u_before[:, :, orig_b]
    after_img = u_after[:, :, orig_a]

    # Components are defined up to sign; flip "after" to agree with "before".
    if np.sum(before_img.ravel() * after_img.ravel()) < 0:
        after_img = -after_img

    # Shared crop box so both panels show the same region.
    bbox = get_union_bbox(before_img, after_img, padding=zoom_padding)
    row_min, row_max, col_min, col_max = bbox

    b_crop = before_img[row_min:row_max, col_min:col_max]
    a_crop = after_img[row_min:row_max, col_min:col_max]
    diff = a_crop - b_crop

    matches_data.append((b_crop, a_crop, diff, orig_b, orig_a, val, bbox))
    print(f"  Component {len(matches_data)}: Before #{orig_b} -> After #{orig_a}, "
          f"Cropped size: {b_crop.shape}, BBox: [{row_min}:{row_max}, {col_min}:{col_max}]")

# ---------------------------- Create Long Images ----------------------------
print("\n" + "="*60)
print("STEP 3: CREATING COMPONENT COMPARISON GRID IMAGES")
print("="*60)

num_images = int(np.ceil(len(matches_data) / components_per_image))

saved_paths = []
for img_idx in range(num_images):
    start_idx = img_idx * components_per_image
    end_idx = min((img_idx + 1) * components_per_image, len(matches_data))

    print(f"\nCreating image {img_idx+1}/{num_images} (components {start_idx+1}-{end_idx})...")

    save_path = f'component_grid_{img_idx+1}_of_{num_images}.png'
    plot_component_grid(matches_data, start_idx, end_idx, cmap=cmap, save_path=save_path)
    saved_paths.append(save_path)


print("\nAll visualizations complete!")
print(f"\nGenerated {num_images} component grid image(s)")
print("Files:")
for path in saved_paths:
    print(f"  - {path}")
