# In [1]:
"""
Noise Models for 4DSTEM, including Low-Dose Direct Detector

This module provides noise models that simulate typical noise types in electron microscopy,
including modern direct electron detectors. They can be applied directly to a py4DSTEM datacube:

- Poisson (shot noise)
- Readout Gaussian
- Dark current
- Salt and Pepper
- Correlated
- Drizzle near bright
- Very sparse signal (few electron counts per pattern)
- Homogeneous zero background
- Single electron counting statistics

"""

import numpy as np
from abc import ABC, abstractmethod
from typing import Optional, Dict, Any, Tuple
import py4DSTEM

# ============================================================================
# Abstract base class for noise models
# ============================================================================
class NoiseModel(ABC):
    """Abstract interface implemented by the noise models in this module"""

    @abstractmethod
    def apply(self, data: np.ndarray, **kwargs) -> np.ndarray:
        """
        Return a noisy version of ``data``

        Concrete subclasses must override this method; the base
        implementation does nothing and returns ``None``.

        Parameters
        ----------
        data : np.ndarray
            Array the noise is applied to
        **kwargs : model-specific settings defined by each subclass

        Returns
        -------
        np.ndarray
            The input with noise applied
        """

# ============================================================================
# Concrete noise model implementations. More models can be added if necessary.
# ============================================================================
class PoissonNoise(NoiseModel):
    """Shot noise following Poisson statistics"""

    def apply(self, data: np.ndarray, scale: float = 1.0) -> np.ndarray:
        """
        Apply Poisson noise

        Every pixel is replaced by a Poisson draw with mean ``data * scale``;
        the draw is divided by ``scale`` again so the result stays in the
        units of the input.

        Parameters
        ----------
        data : np.ndarray
            Input data (must be non-negative, NumPy rejects negative rates)
        scale : float
            Scaling factor for intensity before applying Poisson.
            Must be strictly positive.

        Returns
        -------
        np.ndarray
            Data with Poisson noise (floating point because of the rescaling)

        Raises
        ------
        ValueError
            If ``scale`` is not strictly positive
        """
        # scale == 0 would otherwise yield 0 / 0 = NaN for every pixel, and a
        # negative scale turns the rates negative; fail early with a clear
        # message instead. `not scale > 0` also rejects NaN.
        if not scale > 0:
            raise ValueError("scale must be a positive number")

        # Scale data, apply Poisson, scale back
        return np.random.poisson(data * scale) / scale


class GaussianNoise(NoiseModel):
    """Additive noise drawn from a normal distribution"""

    def apply(self, data: np.ndarray, mean: float = 0.0,
              sigma: float = 1.0) -> np.ndarray:
        """
        Add an independent normal sample to every element of ``data``

        Parameters
        ----------
        data : np.ndarray
            Input data
        mean : float
            Centre of the normal distribution the noise is drawn from
        sigma : float
            Standard deviation of that distribution

        Returns
        -------
        np.ndarray
            ``data`` plus the sampled noise
        """
        return data + np.random.normal(loc=mean, scale=sigma, size=data.shape)


class ReadoutNoise(NoiseModel):
    """Zero-mean Gaussian noise representing detector readout"""

    def apply(self, data: np.ndarray, sigma: float = 5.0) -> np.ndarray:
        """
        Add zero-mean Gaussian readout noise to every element of ``data``

        Parameters
        ----------
        data : np.ndarray
            Input data
        sigma : float
            Standard deviation of the readout noise, in counts

        Returns
        -------
        np.ndarray
            ``data`` plus the sampled readout noise
        """
        readout = np.random.normal(loc=0, scale=sigma, size=data.shape)
        return data + readout


class DarkCurrentNoise(NoiseModel):
    """Additive Poisson-distributed dark-current counts"""

    def apply(self, data: np.ndarray, dark_current: float = 1.0) -> np.ndarray:
        """
        Add an independent Poisson dark-current count to every pixel

        Parameters
        ----------
        data : np.ndarray
            Input data
        dark_current : float
            Mean of the Poisson distribution, in counts per pixel

        Returns
        -------
        np.ndarray
            ``data`` plus the sampled dark counts
        """
        return data + np.random.poisson(lam=dark_current, size=data.shape)


class SaltPepperNoise(NoiseModel):
    """Salt and pepper (impulse) noise"""

    def apply(self, data: np.ndarray, probability: float = 0.01,
              salt_value: Optional[float] = None,
              pepper_value: float = 0.0) -> np.ndarray:
        """
        Apply salt and pepper noise

        One uniform random number is drawn per pixel and split into two
        disjoint intervals, so a pixel becomes salt with probability
        ``probability / 2``, pepper with probability ``probability / 2`` and
        is left untouched otherwise. The fraction of affected pixels is
        therefore exactly ``probability`` in expectation, and pepper never
        overwrites salt.

        Parameters
        ----------
        data : np.ndarray
            Input data (not modified; a copy is returned)
        probability : float
            Probability of a pixel being affected (total for both salt and
            pepper). Must lie in [0, 1].
        salt_value : float, optional
            Value for 'salt' pixels. If None, uses max of data
        pepper_value : float
            Value for 'pepper' pixels

        Returns
        -------
        np.ndarray
            Data with salt and pepper noise, same dtype as ``data``

        Raises
        ------
        ValueError
            If ``probability`` is outside [0, 1]
        """
        if not 0.0 <= probability <= 1.0:
            raise ValueError("probability must be in [0, 1]")

        noisy = data.copy()

        # np.max raises on an empty array and there is nothing to corrupt.
        if data.size == 0:
            return noisy

        if salt_value is None:
            salt_value = np.max(data)

        # A single draw keeps the two populations mutually exclusive:
        # [0, p/2) -> salt, [p/2, p) -> pepper, [p, 1) -> unchanged.
        draw = np.random.random(data.shape)
        half = probability / 2
        noisy[draw < half] = salt_value
        noisy[(draw >= half) & (draw < probability)] = pepper_value

        return noisy


class CorrelatedNoise(NoiseModel):
    """Additive noise with spatial correlations (Gaussian-smoothed white noise)"""

    def apply(self, data: np.ndarray, sigma: float = 1.0,
              correlation_length: float = 5.0) -> np.ndarray:
        """
        Add Gaussian-filtered white noise to ``data``

        Parameters
        ----------
        data : np.ndarray
            Input data
        sigma : float
            Standard deviation of the white noise before it is smoothed
        correlation_length : float
            Width (``sigma`` of the Gaussian filter) used for smoothing, in pixels

        Returns
        -------
        np.ndarray
            ``data`` plus the smoothed noise field
        """
        from scipy.ndimage import gaussian_filter

        # White noise first, then low-pass it so neighbouring pixels correlate
        smooth_field = gaussian_filter(
            np.random.normal(loc=0, scale=sigma, size=data.shape),
            sigma=correlation_length,
        )
        return data + smooth_field

class DrizzleNearBrightPoissonNoise(NoiseModel):
    """
    Sprinkle Poisson-distributed counts around the brightest pixels.

    Each of the brightest p% pixels acts as a seed: K randomly chosen pixels
    that lie within a radius (<=radius_px) AND inside a square window
    (square_side x square_side) around the seed receive extra Poisson counts.

    This creates "correlated" salt-like noise near bright features.
    """

    def apply(
        self,
        data: np.ndarray,
        bright_fraction: float = 0.01,
        radius_px: int = 5,
        square_side: int = 10,
        drizzles_per_seed: int = 3,
        lam_fraction: float = 0.05,
        lam_min: float = 1.0,
        exclude_center: bool = True,
        rng: Optional[np.random.Generator] = None,
    ) -> np.ndarray:
        """
        Add drizzle noise around the brightest pixels of a 2D image

        Parameters
        ----------
        data : np.ndarray
            2D input image; it is converted to float32 and not modified
        bright_fraction : float
            Fraction of pixels treated as seeds (0.01 = brightest 1%).
            Seeds are all pixels at or above the ``1 - bright_fraction`` quantile.
        radius_px : int
            Maximum Euclidean distance (pixels) between seed and drizzle target
        square_side : int
            Side of the square window of candidate offsets; for an even side
            the window is asymmetric (side 10 -> offsets -5..4)
        drizzles_per_seed : int
            Number of distinct target pixels per seed
        lam_fraction : float
            Poisson mean per target is ``lam_fraction * seed intensity``
        lam_min : float
            Lower bound on that Poisson mean
        exclude_center : bool
            If True the seed pixel itself is never a target
        rng : np.random.Generator, optional
            Random generator to use; a fresh default generator if None

        Returns
        -------
        np.ndarray
            float32 copy of the image with the drizzle counts added

        Raises
        ------
        ValueError
            If ``bright_fraction`` is not strictly between 0 and 1
        """
        generator = np.random.default_rng() if rng is None else rng

        img = np.asarray(data, dtype=np.float32, order="C")
        n_rows, n_cols = img.shape

        # --- seeds: every pixel at or above the (1 - bright_fraction) quantile ---
        fraction_ok = 0.0 < bright_fraction < 1.0
        if not fraction_ok:
            raise ValueError("bright_fraction must be in (0,1)")

        threshold = np.quantile(img, 1.0 - bright_fraction)
        seed_coords = np.argwhere(img >= threshold)  # rows of (y, x)
        if seed_coords.size == 0:
            return img.copy()

        # --- candidate offsets: square window intersected with the disc ---
        first = -(square_side // 2)
        axis_offsets = np.arange(first, first + square_side, dtype=int)
        off_y, off_x = (
            grid.ravel()
            for grid in np.meshgrid(axis_offsets, axis_offsets, indexing="ij")
        )

        keep = (off_y * off_y + off_x * off_x) <= (radius_px * radius_px)
        if exclude_center:
            keep &= (off_y != 0) | (off_x != 0)

        off_y, off_x = off_y[keep], off_x[keep]
        n_offsets = off_y.size
        if n_offsets == 0:
            return img.copy()

        result = img.copy()
        # Number of targets per seed cannot exceed the available offsets
        per_seed = min(drizzles_per_seed, n_offsets)

        # --- one pass over the seeds ---
        for row, col in seed_coords:
            rate = max(lam_min, lam_fraction * float(img[row, col]))

            # distinct offsets, so a target is hit at most once per seed
            chosen = generator.choice(n_offsets, size=per_seed, replace=False)
            target_y = row + off_y[chosen]
            target_x = col + off_x[chosen]

            # discard targets that fall outside the image
            inside = (
                (target_y >= 0) & (target_y < n_rows)
                & (target_x >= 0) & (target_x < n_cols)
            )
            if not inside.any():
                continue
            target_y, target_x = target_y[inside], target_x[inside]

            counts = generator.poisson(rate, size=target_y.size)
            result[target_y, target_x] += counts.astype(np.float32)

        return result

# ============================================================================
# Main function to add noise to datacube
# ============================================================================

def add_noise_to_datacube(
    datacube: py4DSTEM.DataCube,
    noise_models: list[tuple[NoiseModel, Dict[str, Any]]],
    seed: Optional[int] = None,
    clip_negative: bool = True,
    preserve_dtype: bool = True
) -> py4DSTEM.DataCube:
    """
    Add noise to Q-space (diffraction patterns) in a 4DSTEM datacube

    Every noise model is applied, in list order, to each diffraction
    pattern ``data[i, j, :, :]`` of a floating point copy of the data; the
    input datacube is not modified.

    Parameters
    ----------
    datacube : py4DSTEM.DataCube
        Input datacube; ``datacube.data`` must be 4D
    noise_models : list of tuples
        List of (NoiseModel instance, parameters dict) to apply sequentially
    seed : int, optional
        Random seed for reproducibility. Seeds NumPy's global generator
        (``np.random.seed``); models that draw from their own
        ``np.random.Generator`` are not affected by it.
    clip_negative : bool
        Whether to clip negative values to zero after adding noise
    preserve_dtype : bool
        If True, the result is cast back to the dtype of the input data.
        For integer dtypes the values are rounded to the nearest integer
        and clipped to the representable range first. If False, the result
        stays floating point.

    Returns
    -------
    py4DSTEM.DataCube
        New datacube with noise added

    Examples
    --------
    >>> # Single noise source
    >>> noisy = add_noise_to_datacube(
    ...     datacube,
    ...     [(PoissonNoise(), {'scale': 100})]
    ... )

    >>> # Multiple noise sources
    >>> noisy = add_noise_to_datacube(
    ...     datacube,
    ...     [
    ...         (PoissonNoise(), {'scale': 100}),
    ...         (ReadoutNoise(), {'sigma': 5}),
    ...         (DarkCurrentNoise(), {'dark_current': 2})
    ...     ],
    ...     seed=42
    ... )
    """
    if seed is not None:
        np.random.seed(seed)

    # Remember the input dtype so it can be restored at the end
    original_dtype = datacube.data.dtype

    # One floating point working copy (np.array copies by default)
    noisy_data = np.array(datacube.data, dtype=float)

    # Unpacking four axes also rejects data that is not 4D
    scan_i, scan_j, _, _ = noisy_data.shape

    print(f"Adding noise to datacube of shape {noisy_data.shape}")

    # Apply each noise model sequentially to each diffraction pattern
    for noise_model, params in noise_models:
        print(f"Applying {noise_model.__class__.__name__} with params {params}...")

        for i, j in np.ndindex(scan_i, scan_j):
            noisy_data[i, j, :, :] = noise_model.apply(
                noisy_data[i, j, :, :], **params
            )

    # Clip negative values if requested
    if clip_negative:
        noisy_data = np.maximum(noisy_data, 0)

    # Convert back to original dtype to save space
    if preserve_dtype:
        if np.issubdtype(original_dtype, np.integer):
            limits = np.iinfo(original_dtype)
            # Round instead of truncating (truncation biases counts low) and
            # clip to the dtype's own range, so signed dtypes keep negative
            # values when clip_negative=False.
            noisy_data = np.clip(np.rint(noisy_data), limits.min, limits.max)
        noisy_data = noisy_data.astype(original_dtype)

    # Create new datacube
    noisy_datacube = py4DSTEM.DataCube(data=noisy_data)

    # Reuse the calibration if it exists (same object, not a copy)
    if hasattr(datacube, 'calibration'):
        noisy_datacube.calibration = datacube.calibration

    # Store noise information in metadata
    # NOTE(review): assumes `metadata` supports plain item assignment in the
    # py4DSTEM version in use - confirm.
    noisy_datacube.metadata['noise_applied'] = [
        {
            'model': model.__class__.__name__,
            'parameters': params
        }
        for model, params in noise_models
    ]

    print("Noise addition complete!")
    return noisy_datacube


# ============================================================================
# Convenience function for common noise combinations
# ============================================================================

def add_realistic_detector_noise(
    datacube: py4DSTEM.DataCube,
    dose_scale: float = 100,
    readout_sigma: float = 5.0,
    dark_current: float = 1.0,
    seed: Optional[int] = None
) -> py4DSTEM.DataCube:
    """
    Apply shot noise, dark current and readout noise, in that order

    Thin wrapper around ``add_noise_to_datacube`` with a fixed three-stage
    noise chain.

    Parameters
    ----------
    datacube : py4DSTEM.DataCube
        Input datacube
    dose_scale : float
        ``scale`` passed to the Poisson stage (higher = more signal,
        less relative noise)
    readout_sigma : float
        Standard deviation of the readout stage, in counts
    dark_current : float
        Mean dark current of the dark-current stage, in counts per pixel
    seed : int, optional
        Random seed forwarded to ``add_noise_to_datacube``

    Returns
    -------
    py4DSTEM.DataCube
        Datacube with the three noise sources applied
    """
    shot_stage = (PoissonNoise(), {'scale': dose_scale})
    dark_stage = (DarkCurrentNoise(), {'dark_current': dark_current})
    readout_stage = (ReadoutNoise(), {'sigma': readout_sigma})

    return add_noise_to_datacube(
        datacube,
        [shot_stage, dark_stage, readout_stage],
        seed=seed,
    )



# ============================================================================
# Low-Dose Specific Noise Models
# ============================================================================

class LowDoseDirectDetector(ABC):
    """Abstract interface for the low-dose direct detector models"""

    @abstractmethod
    def apply(self, data: np.ndarray, **kwargs) -> np.ndarray:
        """
        Return the low-dose version of ``data``

        Must be overridden by subclasses; the base implementation does
        nothing and returns ``None``.
        """


class SparsePoissonNoise(LowDoseDirectDetector):
    """
    Poisson resampling of a pattern at a reduced dose

    Simulates electron counting with very low dose, resulting in:
    - Mostly zero pixels
    - Few bright pixels (electron hits)
    - Poisson statistics
    """

    def apply(
        self,
        data: np.ndarray,
        dose_fraction: float = 0.01,
        ensure_sparse: bool = True
    ) -> np.ndarray:
        """
        Resample ``data`` with Poisson statistics at a fraction of its dose

        Parameters
        ----------
        data : np.ndarray
            Input data (used as the intensity map)
        dose_fraction : float
            Fraction of total dose to use (0.01 = 1% of electrons)
            Lower = sparser signal
        ensure_sparse : bool
            If True and fewer than 90% of the sampled pixels are zero, the
            pattern is re-drawn once at half the expected counts. This is a
            single retry; the 90% level is not guaranteed afterwards.

        Returns
        -------
        np.ndarray
            Poisson-sampled low-dose data, cast to the dtype of ``data``
        """
        # Per-pixel share of the total intensity (epsilon avoids 0 / 0)
        as_float = data.astype(float)
        weights = as_float / (np.sum(as_float) + 1e-10)

        # Electron budget for this pattern
        budget = np.sum(data) * dose_fraction

        # Expected counts = normalized intensity x total electrons
        rates = weights * budget
        counts = np.random.poisson(rates)

        # Too dense: draw again with the rates halved
        if ensure_sparse and np.sum(counts == 0) / counts.size < 0.9:
            counts = np.random.poisson(rates * 0.5)

        return counts.astype(data.dtype)


class ElectronCountingNoise(LowDoseDirectDetector):
    """
    Electron counting detector noise with discrete electrons

    Simulates:
    - Individual electron events
    - Sparse distribution
    - Optional detector quantum efficiency
    """

    def apply(
        self,
        data: np.ndarray,
        electrons_per_pattern: float = 100,
        dqe: float = 1.0,
        spread_sigma: float = 0.0
    ) -> np.ndarray:
        """
        Apply electron counting noise

        The number of incident electrons is Poisson distributed, each one is
        detected with probability ``dqe``, and every detected electron is
        placed on a pixel drawn with probability proportional to ``data``.

        Parameters
        ----------
        data : np.ndarray
            Input data (used as probability map for electron positions).
            If its sum is not positive (e.g. an all-zero pattern) a uniform
            probability map is used instead.
        electrons_per_pattern : float
            Average number of electrons hitting the detector per pattern
            (before the ``dqe`` thinning)
        dqe : float
            Detective quantum efficiency (fraction of electrons detected)
            1.0 = perfect detector, 0.7 = typical
        spread_sigma : float
            Gaussian spread of electron signal (pixels)
            0 = point detector, >0 = some blur. Requires 2D ``data``.

        Returns
        -------
        np.ndarray
            Electron counting data, cast to the dtype of ``data``
        """
        # Nothing to sample on an empty pattern
        if data.size == 0:
            return data.copy()

        # Normalize to probability distribution. Dividing by the exact sum
        # (no epsilon) keeps the probabilities summing to 1 even for very
        # faint patterns, which np.random.choice requires.
        prob_map = data.astype(float)
        total_intensity = np.sum(prob_map)
        if total_intensity > 0:
            prob_map = prob_map / total_intensity
        else:
            # If input is all zeros, use uniform probability
            prob_map = np.ones_like(prob_map) / prob_map.size

        # Sample number of electrons (Poisson distributed)
        n_electrons = np.random.poisson(electrons_per_pattern)

        # Apply DQE (some electrons not detected)
        n_detected = np.random.binomial(n_electrons, dqe)

        # Initialize output
        output = np.zeros_like(data, dtype=float)

        if n_detected > 0:
            # Sample electron positions as flat pixel indices
            indices = np.random.choice(
                prob_map.size,
                size=n_detected,
                p=prob_map.ravel(),
                replace=True
            )

            if spread_sigma > 0:
                # Point spread: jitter each hit, round to the nearest pixel
                # (truncating would shift hits towards the origin) and keep
                # it inside the detector.
                rows, cols = np.unravel_index(indices, data.shape)
                rows = np.rint(rows + np.random.normal(0, spread_sigma, n_detected))
                cols = np.rint(cols + np.random.normal(0, spread_sigma, n_detected))
                rows = np.clip(rows, 0, data.shape[0] - 1).astype(int)
                cols = np.clip(cols, 0, data.shape[1] - 1).astype(int)
                positions = (rows, cols)
            else:
                # Point detector - electrons land exactly on the sampled pixel
                positions = np.unravel_index(indices, data.shape)

            # np.add.at accumulates repeated positions (plain `+=` would not)
            np.add.at(output, positions, 1)

        return output.astype(data.dtype)


class BimodalSparseNoise(LowDoseDirectDetector):
    """
    Sparse pattern with a zero background and Gaussian-valued signal pixels

    Creates histogram with:
    - Large bin at 0 (background, 95-99% of pixels)
    - Gaussian distribution centered at signal_mean for non-zero pixels
    - Clear separation between background and signal

    This is ideal for training denoisers on sparse data with distinct
    background and signal populations.
    """

    def apply(
        self,
        data: np.ndarray,
        sparsity_target: float = 0.98,
        signal_mean: float = 30.0,
        signal_sigma: float = 10.0,
        min_signal: float = 5.0
    ) -> np.ndarray:
        """
        Build a bimodal sparse pattern from ``data``

        Parameters
        ----------
        data : np.ndarray
            Input data (used as probability map for signal positions);
            a uniform map is used when its sum is not positive
        sparsity_target : float
            Target fraction of zero pixels (0.98 = 98% zeros). Positions are
            drawn with replacement, so repeated draws can leave more zeros
            than the target.
        signal_mean : float
            Mean value for non-zero signal pixels
        signal_sigma : float
            Standard deviation for signal distribution
        min_signal : float
            Minimum value for signal pixels (lower draws are raised to it)

        Returns
        -------
        np.ndarray
            Bimodal sparse data, cast to the dtype of ``data``
        """
        # Probability of each pixel receiving signal
        weights = data.astype(float)
        mass = np.sum(weights)
        weights = weights / mass if mass > 0 else np.ones_like(weights) / weights.size

        # Number of draws that corresponds to the requested sparsity
        n_hits = int(data.size * (1 - sparsity_target))

        result = np.zeros_like(data, dtype=float)
        if n_hits <= 0:
            return result.astype(data.dtype)

        # Where the signal goes (pixels may repeat) ...
        flat_idx = np.random.choice(
            weights.size,
            size=n_hits,
            p=weights.flatten(),
            replace=True
        )

        # ... and how strong it is, floored at min_signal
        amplitudes = np.maximum(
            np.random.normal(signal_mean, signal_sigma, n_hits),
            min_signal
        )

        result[np.unravel_index(flat_idx, data.shape)] = amplitudes
        return result.astype(data.dtype)


class LowDoseSparseModel(LowDoseDirectDetector):
    """
    Extremely sparse electron-count model

    Creates data that is:
    - >95% zeros
    - Sparse electron counts
    - Normally distributed signal in non-zero pixels
    - Homogeneous zero background
    """

    def apply(
        self,
        data: np.ndarray,
        sparsity_target: float = 0.98,
        mean_electrons: float = 50,
        add_readout: bool = False,
        readout_sigma: float = 1.0
    ) -> np.ndarray:
        """
        Place a capped number of single-electron hits on an empty pattern

        Parameters
        ----------
        data : np.ndarray
            2D input data, used as probability map for the hit positions;
            a uniform map is used when its sum is not positive
        sparsity_target : float
            Target fraction of zero pixels (0.98 = 98% zeros); limits the
            number of hits to ``int(data.size * (1 - sparsity_target))``
        mean_electrons : float
            Mean of the Poisson-distributed number of electrons per pattern
        add_readout : bool
            Whether to add readout noise (usually False for counting detectors)
        readout_sigma : float
            Readout noise standard deviation (if added)

        Returns
        -------
        np.ndarray
            Sparse low-dose data, cast to the dtype of ``data``
        """
        # Probability of each pixel being hit
        weights = data.astype(float)
        mass = np.sum(weights)
        weights = weights / mass if mass > 0 else np.ones_like(weights) / weights.size

        # Electrons available for this pattern
        drawn = np.random.poisson(mean_electrons)

        # Hits are limited both by the electrons and by the sparsity budget
        pixel_budget = int(data.size * (1 - sparsity_target))
        n_hits = min(drawn, pixel_budget)

        counts = np.zeros_like(data, dtype=float)

        if n_hits > 0:
            flat_idx = np.random.choice(
                weights.size,
                size=n_hits,
                p=weights.flatten(),
                replace=True
            )
            rows, cols = np.unravel_index(flat_idx, data.shape)
            # one count per hit; repeated pixels accumulate
            np.add.at(counts, (rows, cols), 1)

        # Optional Gaussian readout noise, negatives clipped to zero
        if add_readout:
            counts = np.maximum(
                counts + np.random.normal(0, readout_sigma, data.shape),
                0
            )

        return counts.astype(data.dtype)


# ============================================================================
# Dose Scaling Functions
# ============================================================================

def reduce_dose(
    datacube: py4DSTEM.DataCube,
    dose_fraction: float = 0.01,
    method: str = 'sparse_poisson',
    signal_mean: float = 30.0,
    signal_sigma: float = 10.0,
    min_signal: float = 5.0,
    seed: Optional[int] = None
) -> py4DSTEM.DataCube:
    """
    Reduce dose of datacube to simulate low-dose conditions

    The selected low-dose model is applied independently to every
    diffraction pattern ``data[i, j, :, :]``; the input datacube is not
    modified.

    Parameters
    ----------
    datacube : py4DSTEM.DataCube
        Input datacube with normal dose; ``datacube.data`` must be 4D
    dose_fraction : float
        Fraction of original dose (0.01 = 1%). For 'bimodal' it is used as
        the fraction of non-zero pixels instead.
    method : str
        Method to use:
        - 'sparse_poisson': Sparse Poisson sampling
        - 'electron_counting': Discrete electron counting
        - 'extreme_sparse': Extreme sparsity (>95% zeros)
        - 'bimodal': Spike at zero + Gaussian signal distribution (RECOMMENDED)
    signal_mean : float
        Mean value for signal pixels (only for 'bimodal' method)
    signal_sigma : float
        Std dev for signal pixels (only for 'bimodal' method)
    min_signal : float
        Minimum value for signal pixels (only for 'bimodal' method)
    seed : int, optional
        Random seed for reproducibility (seeds NumPy's global generator
        via ``np.random.seed``)

    Returns
    -------
    py4DSTEM.DataCube
        Low-dose datacube

    Raises
    ------
    ValueError
        If ``method`` is not one of the names listed above

    Examples
    --------
    >>> # Bimodal distribution (spike at 0 + Gaussian signal)
    >>> low_dose = reduce_dose(datacube, dose_fraction=0.02,
    ...                         method='bimodal', signal_mean=30)

    >>> # 1% dose with sparse Poisson
    >>> low_dose = reduce_dose(datacube, dose_fraction=0.01)

    >>> # Discrete electron counting (electrons per pattern derived from
    >>> # the mean intensity, the dose fraction and the detector size)
    >>> low_dose = reduce_dose(datacube, dose_fraction=0.01,
    ...                         method='electron_counting')

    >>> # Extremely sparse (98% zeros), reproducible
    >>> low_dose = reduce_dose(datacube, dose_fraction=0.005,
    ...                         method='extreme_sparse', seed=0)
    """
    if seed is not None:
        np.random.seed(seed)

    data = datacube.data
    scan_i, scan_j, det_i, det_j = data.shape

    print(f"Reducing dose to {dose_fraction*100:.2f}% using {method}")

    # Mean intensity is needed by two methods and by the metadata
    original_mean = float(np.mean(data))

    # Select model
    if method == 'bimodal':
        # Bimodal: spike at zero + Gaussian signal
        # dose_fraction = fraction of non-zero pixels
        model = BimodalSparseNoise()
        params = {
            'sparsity_target': 1 - dose_fraction,
            'signal_mean': signal_mean,
            'signal_sigma': signal_sigma,
            'min_signal': min_signal
        }

    elif method == 'sparse_poisson':
        model = SparsePoissonNoise()
        params = {'dose_fraction': dose_fraction, 'ensure_sparse': True}

    elif method == 'electron_counting':
        # Electrons per pattern = mean counts per pattern x dose fraction
        electrons = original_mean * dose_fraction * det_i * det_j
        model = ElectronCountingNoise()
        params = {'electrons_per_pattern': electrons, 'dqe': 0.95}

    elif method == 'extreme_sparse':
        # Same estimate as above, reduced by a further hard-coded factor 0.1
        electrons = original_mean * dose_fraction * det_i * det_j * 0.1
        model = LowDoseSparseModel()
        params = {
            'sparsity_target': 0.98,
            'mean_electrons': electrons,
            'add_readout': False
        }

    else:
        raise ValueError(f"Unknown method: {method}")

    # Apply to all patterns
    low_dose_data = np.zeros_like(data)

    total = scan_i * scan_j
    report_every = max(1, total // 10)  # roughly ten progress lines
    for idx, (i, j) in enumerate(np.ndindex(scan_i, scan_j)):
        low_dose_data[i, j, :, :] = model.apply(data[i, j, :, :], **params)

        if (idx + 1) % report_every == 0:
            print(f"  Progress: {idx+1}/{total}")

    # Create new datacube
    low_dose_cube = py4DSTEM.DataCube(data=low_dose_data)

    # Reuse the calibration if it exists (same object, not a copy)
    if hasattr(datacube, 'calibration'):
        low_dose_cube.calibration = datacube.calibration

    # Store metadata
    # NOTE(review): assumes `metadata` supports plain item assignment in the
    # py4DSTEM version in use - confirm.
    sparsity = float(np.sum(low_dose_data == 0) / low_dose_data.size)
    low_dose_cube.metadata['dose_reduction'] = {
        'method': method,
        'dose_fraction': dose_fraction,
        'original_mean': original_mean,
        'reduced_mean': float(np.mean(low_dose_data)),
        'sparsity': sparsity
    }

    print(f"Done! Sparsity: {sparsity*100:.1f}% zeros")

    return low_dose_cube