In [1]:
import os
from pathlib import Path
from typing import Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from omegaconf import OmegaConf
import masknmf
from typing import *
from masknmf.compression.compression_strategies import CompressSpatialDenoiseStrategy, CompressSpatialTemporalDenoiseStrategy

No windowing system present. Using surfaceless platform
No config found!
No config found!
Max vertex attribute stride unknown. Assuming it is 2048
Max vertex attribute stride unknown. Assuming it is 2048
Max vertex attribute stride unknown. Assuming it is 2048


Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x01,\x00\x00\x007\x08\x06\x00\x00\x00\xb6\x1bw\x99\x…

Valid,Device,Type,Backend,Driver
✅ (default),Quadro RTX 8000,DiscreteGPU,Vulkan,565.57.01
✅,Quadro RTX 8000,DiscreteGPU,Vulkan,565.57.01
❌,Quadro RTX 8000/PCIe/SSE2,Unknown,OpenGL,3.3.0 NVIDIA 565.57.01


Max vertex attribute stride unknown. Assuming it is 2048
Max vertex attribute stride unknown. Assuming it is 2048


In [2]:
class MotionBinDataset:
    """Read-only, memory-mapped view of a suite2p ``data.bin`` movie.

    The (frames, height, width) shape is taken from the ``nframes``/``Ly``/``Lx``
    entries of the suite2p ops dictionary stored next to the binary file.

    Parameters
    ----------
    data_path : str
        Path to the raw binary movie (``data.bin``).
    metadata_path : str
        Path to the ops metadata: an ``.npy`` file holding the pickled ops dict,
        or a ``.zip`` (npz-style) archive with an ``ops`` entry. The extension
        is matched case-insensitively.
    dtype : numpy dtype, default np.int16
        Element type of the binary file.

    Raises
    ------
    ValueError
        If ``metadata_path`` has neither a ``.zip`` nor a ``.npy`` extension.
    """

    def __init__(self, data_path: str, metadata_path: str, dtype=np.int16):
        self.bin_path = Path(data_path)
        self.ops_path = Path(metadata_path)
        self._dtype = dtype
        self._shape = self._compute_shape()
        # Read-only memmap: frames are paged in lazily on indexing.
        self.data = np.memmap(self.bin_path, mode='r', dtype=self._dtype, shape=self._shape)

    def _load_ops(self) -> dict:
        """Load and return the suite2p ops dictionary from ``self.ops_path``."""
        ext = self.ops_path.suffix.lower()
        # NOTE: allow_pickle=True is needed because ops is a pickled dict;
        # only point this at metadata files you trust.
        if ext == ".zip":
            # np.load on an archive returns a lazily-read NpzFile holding an open
            # file handle; the context manager closes it deterministically.
            with np.load(self.ops_path, allow_pickle=True) as archive:
                return archive['ops'].item()
        if ext == ".npy":
            return np.load(self.ops_path, allow_pickle=True).item()
        raise ValueError("Metadata file must be .zip or .npy")

    def _compute_shape(self) -> Tuple[int, int, int]:
        """Return ``(nframes, Ly, Lx)`` read from the ops metadata."""
        ops = self._load_ops()
        return int(ops['nframes']), int(ops['Ly']), int(ops['Lx'])

    @property
    def shape(self) -> Tuple[int, int, int]:
        """Movie shape as ``(T, H, W)``."""
        return self._shape

    def __getitem__(self, item):
        """Index the movie and return an in-memory copy (detached from the memmap)."""
        return self.data[item].copy()

# Central configuration: input/output paths and hyperparameters reused below.
# NOTE(review): these are absolute cluster paths; adjust for other machines.
config = {
    'bin_file_path': '/burg-archive/home/lm3879/plane4/data.bin',
    'ops_file_path': '/burg-archive/home/lm3879/plane4/ops.npy',
    'out_path': '/burg-archive/home/lm3879/masknmf-toolbox/ibl_denoised_output/pmd_spatial_results.npz',
    'block_size_dim1': 32,
    'block_size_dim2': 32,
    'max_components': 20,
    'max_consecutive_failures': 1,
    'spatial_avg_factor': 1,
    'temporal_avg_factor': 1,
    'device': 'cuda',
    'frame_batch_size': 1024,
    'denoiser_max_epochs': 5,
    'denoiser_batch_size': 32,
    'denoiser_lr': 1e-4,
    'patch_h': 40,
    'patch_w': 40,
}

cfg = OmegaConf.create(config)
# Fall back to CPU when CUDA is unavailable, and write the choice back into
# cfg so later code that reads cfg.device agrees with `device`.
device = cfg.device if torch.cuda.is_available() else 'cpu'
cfg.device = device
print(f"Using device: {device}")

my_data = MotionBinDataset(cfg.bin_file_path, cfg.ops_file_path)
print(f"Loaded data with shape (T,H,W): {my_data.shape}")

# Pixel weighting for the baseline PMD run: 1 everywhere except a
# MASK_BORDER-pixel frame along the image edges, which is zeroed.
MASK_BORDER = 3
binary_mask = np.zeros((my_data.shape[1], my_data.shape[2]), dtype=np.float32)
binary_mask[MASK_BORDER:-MASK_BORDER, MASK_BORDER:-MASK_BORDER] = 1.0

Using device: cuda
Loaded data with shape (T,H,W): (11700, 500, 620)


In [5]:
# Temporal-denoiser PMD run. The epoch count comes from `cfg`; `device` carries
# the CPU fallback chosen in the config cell instead of a hard-coded 'cuda'.
compress_strategy = masknmf.CompressTemporalDenoiseStrategy(
    dataset=my_data,
    block_sizes=(cfg.block_size_dim1, cfg.block_size_dim2),
    noise_variance_quantile=0.7,
    num_epochs=cfg.denoiser_max_epochs,
)
pmd_result_temporal = compress_strategy.compress()
print(f"Final PMD rank: {pmd_result_temporal.pmd_rank}")

data = my_data.data  # raw memmapped movie, shown as the "Original" panel
pmd_result_temporal.to(device)
pmd_residual_temporal = masknmf.PMDResidualArray(my_data, pmd_result_temporal)

# NOTE(review): `fpl` (presumably `import fastplotlib as fpl`) is not imported
# anywhere in this notebook; add it to the imports cell or this raises NameError.
iw = fpl.ImageWidget(
    data=[data, pmd_result_temporal, pmd_residual_temporal],
    figure_shape=(1, 3),
    names=['Original', 'PMD (Temporal Denoised)', 'Residual'],
)

iw.show()

NameError: name 'masknmf' is not defined

In [None]:
# Spatial-denoiser PMD run. Hyperparameters are read from `cfg` so the notebook
# has one source of truth; `device` honours the CPU fallback.
block_sizes = [cfg.block_size_dim1, cfg.block_size_dim2]

compress_strategy = CompressSpatialDenoiseStrategy(
    my_data,
    block_sizes=block_sizes,
    max_components=cfg.max_components,
    spatial_denoiser_epochs=cfg.denoiser_max_epochs,
    spatial_denoiser_batch_size=cfg.denoiser_batch_size,
    spatial_denoiser_lr=cfg.denoiser_lr,
    noise_variance_quantile=0.7,
    patch_h=cfg.patch_h,
    patch_w=cfg.patch_w,
    device=device,
)

pmd_result_spatial = compress_strategy.compress()
print(f"Final PMD rank: {pmd_result_spatial.pmd_rank}")

data = my_data.data  # raw memmapped movie, shown as the "Original" panel
pmd_result_spatial.to(device)
pmd_residual_spatial = masknmf.PMDResidualArray(my_data, pmd_result_spatial)

# NOTE(review): `fpl` (presumably `import fastplotlib as fpl`) is not imported
# anywhere in this notebook; add it to the imports cell or this raises NameError.
iw = fpl.ImageWidget(
    data=[data, pmd_result_spatial, pmd_residual_spatial],
    figure_shape=(1, 3),
    names=['Original', 'PMD (Spatial Denoised)', 'Residual'],
)

iw.show()

In [None]:
# Spatial + temporal denoiser PMD run. `block_sizes` and `data` are defined here
# so this cell does not depend on the spatial-only cell having run first.
block_sizes = [cfg.block_size_dim1, cfg.block_size_dim2]
data = my_data.data  # raw memmapped movie, shown as the "Original" panel

compress_strategy_both = CompressSpatialTemporalDenoiseStrategy(
    my_data,  # same dataset wrapper the other strategies receive
    block_sizes=block_sizes,
    max_components=cfg.max_components,
    # Spatial denoiser params
    spatial_denoiser_epochs=cfg.denoiser_max_epochs,
    spatial_denoiser_batch_size=cfg.denoiser_batch_size,
    spatial_denoiser_lr=cfg.denoiser_lr,
    spatial_noise_variance_quantile=0.7,
    patch_h=cfg.patch_h,
    patch_w=cfg.patch_w,
    # Temporal denoiser params
    temporal_denoiser_epochs=cfg.denoiser_max_epochs,
    temporal_denoiser_batch_size=128,
    temporal_denoiser_lr=cfg.denoiser_lr,
    temporal_noise_variance_quantile=0.7,
    device=device,
)

pmd_result_both = compress_strategy_both.compress()

print(f"Final PMD rank (both denoisers): {pmd_result_both.pmd_rank}")

# Visualize
pmd_result_both.to(device)
pmd_residual_both = masknmf.PMDResidualArray(my_data, pmd_result_both)
# NOTE(review): `fpl` (presumably `import fastplotlib as fpl`) is not imported
# anywhere in this notebook; add it to the imports cell or this raises NameError.
iw2 = fpl.ImageWidget(
    data=[data, pmd_result_both, pmd_residual_both],
    figure_shape=(1, 3),
    names=['Original', 'PMD (Spatial+Temporal Denoised)', 'Residual'],
)

iw2.show()

In [None]:
# 3x3 grid: one row per denoising variant (temporal, spatial, spatial+temporal);
# columns are original / PMD reconstruction / residual.
# The 'Original' labels differ only in surrounding spaces, presumably so every
# subplot name is unique.
# NOTE(review): `fpl` (presumably `import fastplotlib as fpl`) is not imported
# anywhere in this notebook; add it to the imports cell or this raises NameError.
iw = fpl.ImageWidget(
    data=[
        data, pmd_result_temporal, pmd_residual_temporal,
        data, pmd_result_spatial, pmd_residual_spatial,
        data, pmd_result_both, pmd_residual_both,
    ],
    figure_shape=(3, 3),
    names=['Original', 'PMD (Temporal Denoised)', 'Temporal Residual', 'Original ', 'PMD (Spatial Denoised)', 'Spatial Residual', ' Original ', 'PMD (Spatial + Temporal)', 'Residual'],
)

iw.show()

In [None]:
import matplotlib.pyplot as plt

class SpatialComponentComparison:
    """Compare spatial components before and after denoising.

    Each PMD result's spatial factor ``pmd.u`` is densified and reshaped to
    ``(H, W, rank)``. The highest-energy "before" components are paired with
    their most similar "after" component, and each pair is plotted as
    before / after / difference, cropped to where either image is active.
    """

    def __init__(self,
                 pmd_before,
                 pmd_after,
                 data_shape,
                 top_k=12,
                 zoom_padding=2,
                 max_candidates=400):
        """
        Parameters
        ----------
        pmd_before : PMDArray
            PMD result before denoising (must expose ``.u.to_dense()``).
        pmd_after : PMDArray
            PMD result after denoising (must expose ``.u.to_dense()``).
        data_shape : tuple
            (T, H, W) shape of original data; only H and W are used.
        top_k : int
            Number of top-energy "before" components to match and display.
        zoom_padding : int
            Padding (in pixels) around detected regions.
        max_candidates : int
            Number of highest-energy components taken from each decomposition
            into the similarity matrix (bounds it to
            ``max_candidates x max_candidates``).
        """
        self.pmd_before = pmd_before
        self.pmd_after = pmd_after
        self.H, self.W = data_shape[1], data_shape[2]
        self.top_k = top_k
        self.zoom_padding = zoom_padding
        self.max_candidates = max_candidates

        # Dense spatial components, each shaped (H, W, rank)
        self.u_before = self._extract_components(pmd_before)
        self.u_after = self._extract_components(pmd_after)

    def _extract_components(self, pmd):
        """Densify ``pmd.u`` and return it as a NumPy array of shape (H, W, rank)."""
        u_dense = pmd.u.to_dense()
        if isinstance(u_dense, torch.Tensor):
            # .numpy() needs a CPU tensor without autograd history
            u_dense = u_dense.detach().cpu().numpy()
        else:
            u_dense = np.asarray(u_dense)
        # assumes u is (H*W, rank) with pixels in row-major (H, W) order — TODO confirm
        return u_dense.reshape(self.H, self.W, -1)

    def _compute_energy(self, u):
        """Return the L2 norm of every component ``u[:, :, i]`` as a 1-D array."""
        # Flatten the pixel axes so one vectorised norm replaces a per-component loop.
        return np.linalg.norm(u.reshape(u.shape[0] * u.shape[1], u.shape[2]), axis=0)

    def _detect_bbox(self, img, threshold_quantile=0.01):
        """Return the half-open box ``(row_min, row_max, col_min, col_max)`` of the active region.

        A pixel is active when its magnitude is at least the
        ``threshold_quantile`` quantile of the non-zero magnitudes. Returns the
        full image when nothing is active.
        """
        img_abs = np.abs(img)
        nonzero = img_abs[img_abs > 0]
        if nonzero.size == 0:
            return 0, img.shape[0], 0, img.shape[1]

        threshold = np.quantile(nonzero, threshold_quantile)
        # ">=" rather than ">": with ">" a component whose non-zero pixels all
        # share one magnitude (e.g. a single active pixel) produced an empty
        # mask and fell back to a full-frame crop.
        mask = img_abs >= threshold

        rows = np.flatnonzero(mask.any(axis=1))
        cols = np.flatnonzero(mask.any(axis=0))
        if rows.size == 0 or cols.size == 0:
            return 0, img.shape[0], 0, img.shape[1]

        return rows[0], rows[-1] + 1, cols[0], cols[-1] + 1

    def _get_union_bbox(self, img1, img2):
        """Return the union of both images' active boxes, padded and clipped to the frame."""
        bbox1 = self._detect_bbox(img1)
        bbox2 = self._detect_bbox(img2)

        row_min = max(0, min(bbox1[0], bbox2[0]) - self.zoom_padding)
        row_max = min(self.H, max(bbox1[1], bbox2[1]) + self.zoom_padding)
        col_min = max(0, min(bbox1[2], bbox2[2]) - self.zoom_padding)
        col_max = min(self.W, max(bbox1[3], bbox2[3]) + self.zoom_padding)

        return row_min, row_max, col_min, col_max

    def _match_components(self):
        """Pair high-energy "before" components with their most similar "after" component.

        Returns
        -------
        list of (int, int, float)
            ``(before_index, after_index, |cosine similarity|)`` for up to
            ``top_k`` before-components, ordered by decreasing before-energy.
            Matching is greedy, so several before-components may map to the
            same after-component. Empty if either decomposition has rank 0.
        """
        energy_before = self._compute_energy(self.u_before)
        energy_after = self._compute_energy(self.u_after)

        n_before = min(self.max_candidates, len(energy_before))
        n_after = min(self.max_candidates, len(energy_after))
        if n_before == 0 or n_after == 0:
            return []

        # Indices of the n highest-energy components, highest first.
        idx_before = np.argsort(energy_before)[::-1][:n_before]
        idx_after = np.argsort(energy_after)[::-1][:n_after]

        n_pixels = self.H * self.W
        B = self.u_before.reshape(n_pixels, -1)[:, idx_before].astype(np.float32, copy=False)
        A = self.u_after.reshape(n_pixels, -1)[:, idx_after].astype(np.float32, copy=False)

        # |cosine similarity| between every candidate pair. Dividing the small
        # (n_before, n_after) Gram matrix by the column norms equals normalising
        # B and A first, but avoids allocating normalised (n_pixels, n) copies.
        # The epsilon guards all-zero components; abs() makes the match
        # sign-invariant, consistent with the sign alignment in visualize().
        norm_b = np.linalg.norm(B, axis=0) + 1e-8
        norm_a = np.linalg.norm(A, axis=0) + 1e-8
        similarity = np.abs(B.T @ A) / np.outer(norm_b, norm_a)

        best_after = np.argmax(similarity, axis=1)
        return [
            (idx_before[i], idx_after[best_after[i]], similarity[i, best_after[i]])
            for i in range(min(self.top_k, n_before))
        ]

    def visualize(self, save_prefix='component_comparison', components_per_figure=6):
        """
        Create comparison visualizations and save them as PNG files.

        Matched pairs are split into figures of at most ``components_per_figure``
        rows, saved as ``{save_prefix}_{k}_of_{n}.png``.

        Parameters
        ----------
        save_prefix : str
            Prefix for saved figure files
        components_per_figure : int
            Number of component comparisons per figure
        """
        print("\n" + "="*60)
        print("Comparing Spatial Components: Before vs After Denoising")
        print("="*60)

        matches = self._match_components()

        print(f"\nRank before: {self.u_before.shape[2]}")
        print(f"Rank after: {self.u_after.shape[2]}")
        print(f"\nTop {len(matches)} component matches:")
        for i, (idx_b, idx_a, corr) in enumerate(matches):
            print(f"  {i+1}. Before #{idx_b} → After #{idx_a}  (corr={corr:.3f})")

        viz_data = [self._crop_pair(idx_b, idx_a, corr) for idx_b, idx_a, corr in matches]

        n_figures = int(np.ceil(len(viz_data) / components_per_figure))
        for fig_idx in range(n_figures):
            start = fig_idx * components_per_figure
            end = min(start + components_per_figure, len(viz_data))
            self._plot_comparison(viz_data[start:end],
                                  save_path=f'{save_prefix}_{fig_idx+1}_of_{n_figures}.png')

        print(f"\n✓ Created {n_figures} comparison figure(s)")
        print("="*60 + "\n")

    def _crop_pair(self, idx_b, idx_a, corr):
        """Sign-align one matched pair and crop both images to their joint active box."""
        before_img = self.u_before[:, :, idx_b]
        after_img = self.u_after[:, :, idx_a]

        # Flip "after" when it anti-correlates so the difference image is meaningful.
        if np.sum(before_img * after_img) < 0:
            after_img = -after_img

        bbox = self._get_union_bbox(before_img, after_img)
        row_min, row_max, col_min, col_max = bbox
        b_crop = before_img[row_min:row_max, col_min:col_max]
        a_crop = after_img[row_min:row_max, col_min:col_max]

        return {
            'before': b_crop,
            'after': a_crop,
            'diff': a_crop - b_crop,
            'idx_b': idx_b,
            'idx_a': idx_a,
            'corr': corr,
            'bbox': bbox,
        }

    def _plot_comparison(self, viz_data, save_path):
        """Plot a before / after / difference grid for a subset of pairs and save it to ``save_path``."""
        n = len(viz_data)
        # squeeze=False keeps `axes` 2-D even when n == 1.
        fig, axes = plt.subplots(n, 3, figsize=(12, 3*n), squeeze=False)

        for i, pair in enumerate(viz_data):
            b_crop = pair['before']
            a_crop = pair['after']
            diff = pair['diff']

            # Symmetric colour ranges; fall back to 1 when an image is all zeros
            # so vmin != vmax.
            vmax = max(np.abs(b_crop).max(), np.abs(a_crop).max())
            if not vmax > 0:
                vmax = 1.0
            vmax_diff = np.abs(diff).max() if diff.size > 0 else 0.0
            if not vmax_diff > 0:
                vmax_diff = 1.0
            mse = np.mean(diff**2) if diff.size > 0 else 0.0

            # Get bbox for title
            row_min, row_max, col_min, col_max = pair['bbox']

            # Before
            im0 = axes[i, 0].imshow(b_crop, cmap='gray', vmin=-vmax, vmax=vmax)
            axes[i, 0].axis('off')
            axes[i, 0].set_title(f"Before #{pair['idx_b']}\n[{row_min}:{row_max}, {col_min}:{col_max}]",
                                 fontsize=9)
            plt.colorbar(im0, ax=axes[i, 0], fraction=0.046, pad=0.04)

            # After
            im1 = axes[i, 1].imshow(a_crop, cmap='gray', vmin=-vmax, vmax=vmax)
            axes[i, 1].axis('off')
            axes[i, 1].set_title(f"After #{pair['idx_a']}\nCorr={pair['corr']:.3f}",
                                 fontsize=9)
            plt.colorbar(im1, ax=axes[i, 1], fraction=0.046, pad=0.04)

            # Difference
            im2 = axes[i, 2].imshow(diff, cmap='seismic', vmin=-vmax_diff, vmax=vmax_diff)
            axes[i, 2].axis('off')
            axes[i, 2].set_title(f"Difference\nMSE={mse:.2e}", fontsize=9)
            plt.colorbar(im2, ax=axes[i, 2], fraction=0.046, pad=0.04)

        plt.tight_layout()
        fig.savefig(save_path, dpi=200, bbox_inches='tight')
        print(f"  Saved: {save_path}")
        plt.close(fig)

# Baseline PMD without any denoiser: the "before" side of the comparison.
pmd_no_denoise = masknmf.compression.pmd_decomposition(
    my_data,
    [cfg.block_size_dim1, cfg.block_size_dim2],
    my_data.shape[0],
    max_components=cfg.max_components,
    max_consecutive_failures=cfg.max_consecutive_failures,
    temporal_avg_factor=cfg.temporal_avg_factor,
    spatial_avg_factor=cfg.spatial_avg_factor,
    device=device,  # `device` includes the CPU fallback chosen in the config cell
    temporal_denoiser=None,
    frame_batch_size=cfg.frame_batch_size,
    pixel_weighting=binary_mask
)
print(f"PMD rank: {pmd_no_denoise.pmd_rank}")
# Extract spatial components
print("\n" + "="*60)
print("Extracting spatial components from PMD results...")
print("="*60)

# NOTE(review): pmd_result_spatial was computed without pixel_weighting=binary_mask
# and with noise_variance_quantile=0.7, so differences may not be due to the
# denoiser alone — confirm the comparison is like-for-like.
comparator = SpatialComponentComparison(
    pmd_before=pmd_no_denoise,
    pmd_after=pmd_result_spatial,
    data_shape=my_data.shape,
    top_k=12,
    zoom_padding=2
)

comparator.visualize(
    save_prefix='spatial_comparison',
    components_per_figure=6
)