# In [None]:
"""
Low-Dose Direct Detector Noise Models for 4DSTEM

This module provides noise models that simulate low-dose direct detector conditions:
- Very sparse signal (few electron counts per pattern)
- Poisson-dominated noise (shot noise)
- Homogeneous zero background
- Single electron counting statistics

These models are optimized for training deep learning denoisers on sparse data.
"""

import numpy as np
from abc import ABC, abstractmethod
from typing import Optional, Dict, Any, Tuple
import py4DSTEM


# ============================================================================
# Low-Dose Specific Noise Models
# ============================================================================

class LowDoseDirectDetector(ABC):
    """Abstract interface for low-dose direct detector noise models.

    Subclasses must override :meth:`apply`; the base class itself cannot be
    instantiated because ``apply`` is declared abstract.
    """

    @abstractmethod
    def apply(self, data: np.ndarray, **kwargs) -> np.ndarray:
        """Return a noisy version of ``data``.

        Parameters
        ----------
        data : np.ndarray
            Input array handed to the noise model.
        **kwargs
            Model-specific options defined by each subclass.

        Returns
        -------
        np.ndarray
            Declared return type for subclasses; this base implementation has
            no body and returns ``None`` if called directly.
        """


class SparsePoissonNoise(LowDoseDirectDetector):
    """
    Sparse Poisson noise for low-dose conditions

    Simulates electron counting with very low dose, resulting in:
    - Mostly zero pixels
    - Few bright pixels (electron hits)
    - Poisson statistics
    """

    # A pattern counts as sparse once at least this fraction of pixels is zero.
    _SPARSITY_THRESHOLD = 0.9
    # Hard cap on dose halvings so the sparsity loop always terminates.
    _MAX_HALVINGS = 64

    def apply(
        self,
        data: np.ndarray,
        dose_fraction: float = 0.01,
        ensure_sparse: bool = True
    ) -> np.ndarray:
        """
        Apply sparse Poisson noise for low-dose imaging

        Parameters
        ----------
        data : np.ndarray
            Input data (will be used as intensity map)
        dose_fraction : float
            Fraction of total dose to use (0.01 = 1% of electrons)
            Lower = sparser signal
        ensure_sparse : bool
            If True, the expected counts are halved and re-sampled until at
            least 90% of the pixels are zero (bounded number of halvings).
            The effective dose can therefore end up below ``dose_fraction``.

        Returns
        -------
        np.ndarray
            Sparse, low-dose data with Poisson noise, cast to ``data.dtype``

        Raises
        ------
        ValueError
            Raised by ``np.random.poisson`` if the expected counts contain
            negative or NaN values.
        """
        # Normalize input to get probability distribution
        # (the small epsilon avoids 0/0 for an all-zero pattern)
        data_norm = data.astype(float)
        data_norm = data_norm / (np.sum(data_norm) + 1e-10)

        # Calculate expected number of electrons for this pattern
        total_electrons = np.sum(data) * dose_fraction

        # Expected counts = normalized intensity × total electrons
        expected_counts = data_norm * total_electrons

        # Apply Poisson noise
        noisy = np.random.poisson(expected_counts)

        if ensure_sparse and noisy.size > 0:
            # A single halving is not enough for dense inputs, so keep
            # reducing the dose until the zero fraction reaches the threshold.
            halvings = 0
            while (
                np.sum(noisy == 0) / noisy.size < self._SPARSITY_THRESHOLD
                and halvings < self._MAX_HALVINGS
            ):
                expected_counts = expected_counts * 0.5
                noisy = np.random.poisson(expected_counts)
                halvings += 1

        return noisy.astype(data.dtype)


class ElectronCountingNoise(LowDoseDirectDetector):
    """
    Electron counting detector noise with discrete electrons

    Simulates:
    - Individual electron events
    - Sparse distribution
    - Optional detector quantum efficiency
    """

    def apply(
        self,
        data: np.ndarray,
        electrons_per_pattern: float = 100,
        dqe: float = 1.0,
        spread_sigma: float = 0.0
    ) -> np.ndarray:
        """
        Apply electron counting noise

        Parameters
        ----------
        data : np.ndarray
            Input data (used as probability map for electron positions).
            If the total intensity is not positive, a uniform probability
            map is used instead.
        electrons_per_pattern : float
            Average number of electrons per pattern (Poisson mean, before
            the DQE thinning is applied)
        dqe : float
            Detective quantum efficiency (fraction of electrons detected)
            1.0 = perfect detector
        spread_sigma : float
            Gaussian spread of electron signal (pixels)
            0 = point detector, >0 = each hit is displaced by Gaussian
            jitter, rounded to the nearest pixel and clipped to the frame

        Returns
        -------
        np.ndarray
            Electron counting data with the shape and dtype of ``data``
        """
        # Normalize to probability distribution; dividing by the exact total
        # keeps the probabilities summing to 1, which np.random.choice needs.
        prob_map = data.astype(float)
        total_intensity = np.sum(prob_map)

        if total_intensity > 0:
            prob_map = prob_map / total_intensity
        else:
            # Blank pattern: every pixel is equally likely to be hit
            prob_map = np.ones_like(prob_map) / prob_map.size

        # Sample number of electrons (Poisson distributed)
        n_electrons = np.random.poisson(electrons_per_pattern)

        # Apply DQE (some electrons not detected)
        n_detected = np.random.binomial(n_electrons, dqe)

        # Initialize output
        output = np.zeros_like(data, dtype=float)

        if n_detected > 0:
            # Sample electron positions from the flattened probability map
            indices = np.random.choice(
                prob_map.size,
                size=n_detected,
                p=prob_map.ravel(),
                replace=True
            )

            # Convert to per-axis index arrays
            positions = np.unravel_index(indices, data.shape)

            if spread_sigma > 0:
                # Point spread: jitter every axis, round to the nearest
                # pixel (truncating would shift hits half a pixel towards
                # the origin) and clip to the valid range.
                positions = tuple(
                    np.clip(
                        np.rint(
                            axis_index
                            + np.random.normal(0, spread_sigma, n_detected)
                        ),
                        0,
                        axis_len - 1
                    ).astype(np.intp)
                    for axis_index, axis_len in zip(positions, data.shape)
                )

            # Unbuffered add so repeated hits on one pixel all get counted
            np.add.at(output, positions, 1)

        return output.astype(data.dtype)


class BimodalSparseNoise(LowDoseDirectDetector):
    """
    Bimodal sparse noise: large spike at zero + Gaussian distribution for signal

    Creates histogram with:
    - Large bin at 0 (background)
    - Gaussian distribution centered at signal_mean for non-zero pixels
      (values below min_signal are raised to min_signal)

    Intended for training denoisers on sparse data with distinct
    background and signal populations.
    """

    def apply(
        self,
        data: np.ndarray,
        sparsity_target: float = 0.98,
        signal_mean: float = 30.0,
        signal_sigma: float = 10.0,
        min_signal: float = 5.0
    ) -> np.ndarray:
        """
        Apply bimodal sparse noise

        Parameters
        ----------
        data : np.ndarray
            Input data (used as probability map for signal positions).
            If the total intensity is not positive, a uniform map is used.
        sparsity_target : float
            Target fraction of zero pixels (0.98 = 98% zeros)
        signal_mean : float
            Mean value for non-zero signal pixels
        signal_sigma : float
            Standard deviation for signal distribution
        min_signal : float
            Minimum value for signal pixels (clips low values)

        Returns
        -------
        np.ndarray
            Bimodal sparse data with the shape and dtype of ``data``.
            Positions are drawn with replacement and a repeated position
            keeps only one value, so the realised zero fraction can be
            higher than ``sparsity_target``.
        """
        # Normalize to probability distribution
        prob_map = data.astype(float)
        total_intensity = np.sum(prob_map)

        if total_intensity > 0:
            prob_map = prob_map / total_intensity
        else:
            prob_map = np.ones_like(prob_map) / prob_map.size

        # Calculate number of signal pixels. The product is rounded to six
        # decimals before truncation because 1 - sparsity_target carries
        # float error (e.g. 10 * (1 - 0.9) == 0.9999999999999998), which
        # would otherwise silently drop one pixel.
        total_pixels = data.size
        n_signal_pixels = int(round(total_pixels * (1 - sparsity_target), 6))

        # Initialize output with zeros
        output = np.zeros_like(data, dtype=float)

        if n_signal_pixels > 0:
            # Sample positions for signal pixels
            prob_flat = prob_map.flatten()
            indices = np.random.choice(
                len(prob_flat),
                size=n_signal_pixels,
                p=prob_flat,
                replace=True  # Allow repeating pixels
            )

            # Generate signal values from Gaussian distribution
            signal_values = np.random.normal(signal_mean, signal_sigma, n_signal_pixels)

            # Clip to minimum signal value
            signal_values = np.maximum(signal_values, min_signal)

            # Assign to positions (a repeated position is overwritten)
            positions = np.unravel_index(indices, data.shape)
            output[positions] = signal_values

        return output.astype(data.dtype)


class LowDoseSparseModel(LowDoseDirectDetector):
    """
    Combined low-dose model with extreme sparsity

    Creates data that is:
    - Mostly zeros (at most a ``1 - sparsity_target`` fraction of pixels hit)
    - Sparse integer electron counts (each sampled electron adds one count)
    - Homogeneous zero background, unless readout noise is requested
    """

    def apply(
        self,
        data: np.ndarray,
        sparsity_target: float = 0.98,
        mean_electrons: float = 50,
        add_readout: bool = False,
        readout_sigma: float = 1.0
    ) -> np.ndarray:
        """
        Apply extreme low-dose sparse noise model

        Parameters
        ----------
        data : np.ndarray
            Input data, used as probability map for electron positions.
            If the total intensity is not positive, a uniform map is used.
        sparsity_target : float
            Target fraction of zero pixels (0.98 = 98% zeros)
        mean_electrons : float
            Mean number of electrons per pattern (Poisson mean); the number
            actually placed is capped by the sparsity target
        add_readout : bool
            Whether to add Gaussian readout noise to every pixel
        readout_sigma : float
            Readout noise standard deviation (if added)

        Returns
        -------
        np.ndarray
            Sparse low-dose data with the shape and dtype of ``data``
        """
        # Normalize to probability
        prob_map = data.astype(float)
        total_intensity = np.sum(prob_map)

        if total_intensity > 0:
            prob_map = prob_map / total_intensity
        else:
            # If input is all zeros, use uniform probability
            prob_map = np.ones_like(prob_map) / prob_map.size

        # Sample number of electrons
        n_electrons = np.random.poisson(mean_electrons)

        # Calculate how many pixels to activate to achieve sparsity. Round to
        # six decimals before truncating: 1 - sparsity_target carries float
        # error (10 * (1 - 0.9) == 0.9999999999999998) that would otherwise
        # drop one pixel.
        total_pixels = data.size
        target_active_pixels = int(round(total_pixels * (1 - sparsity_target), 6))

        # Make sure we don't activate more pixels than we have electrons
        n_activate = min(n_electrons, target_active_pixels)

        output = np.zeros_like(data, dtype=float)

        if n_activate > 0:
            indices = np.random.choice(
                prob_map.size,
                size=n_activate,
                p=prob_map.ravel(),
                replace=True
            )

            # Unbuffered add so repeated hits on one pixel accumulate
            np.add.at(output, np.unravel_index(indices, data.shape), 1)

        # Optionally add readout noise
        if add_readout:
            readout = np.random.normal(0, readout_sigma, data.shape)
            output = np.maximum(output + readout, 0)  # Clip negatives

        return output.astype(data.dtype)


# ============================================================================
# Dose Scaling Functions
# ============================================================================

def reduce_dose(
    datacube: py4DSTEM.DataCube,
    dose_fraction: float = 0.01,
    method: str = 'sparse_poisson',
    signal_mean: float = 30.0,
    signal_sigma: float = 10.0
) -> py4DSTEM.DataCube:
    """
    Build a low-dose copy of a datacube by applying a noise model per pattern

    Parameters
    ----------
    datacube : py4DSTEM.DataCube
        Input datacube; ``datacube.data`` is unpacked as a 4D array
        (scan_i, scan_j, det_i, det_j)
    dose_fraction : float
        Fraction of original dose (0.01 = 1%). For 'bimodal' it is instead
        the fraction of non-zero pixels (sparsity = 1 - dose_fraction).
    method : str
        Noise model to use:
        - 'sparse_poisson': SparsePoissonNoise
        - 'electron_counting': ElectronCountingNoise (dqe fixed at 0.95)
        - 'extreme_sparse': LowDoseSparseModel (sparsity target 0.98)
        - 'bimodal': BimodalSparseNoise (min_signal fixed at 5.0)
    signal_mean : float
        Mean value for signal pixels (only for 'bimodal' method)
    signal_sigma : float
        Std dev for signal pixels (only for 'bimodal' method)

    Returns
    -------
    py4DSTEM.DataCube
        New datacube holding the noisy patterns, with a 'dose_reduction'
        entry written to its metadata and the calibration copied over when
        the input has one

    Raises
    ------
    ValueError
        If ``method`` is not one of the names listed above

    Examples
    --------
    >>> low_dose = reduce_dose(datacube, dose_fraction=0.01)
    >>> low_dose = reduce_dose(datacube, dose_fraction=0.02,
    ...                         method='bimodal', signal_mean=30)
    """
    cube_data = datacube.data
    scan_i, scan_j, det_i, det_j = cube_data.shape

    print(f"Reducing dose to {dose_fraction*100:.2f}% using {method}")

    # Pick the noise model and the keyword arguments every pattern receives
    if method == 'sparse_poisson':
        noise_model = SparsePoissonNoise()
        model_kwargs = dict(dose_fraction=dose_fraction, ensure_sparse=True)

    elif method == 'bimodal':
        # Here dose_fraction is read as the fraction of non-zero pixels
        noise_model = BimodalSparseNoise()
        model_kwargs = dict(
            sparsity_target=1 - dose_fraction,
            signal_mean=signal_mean,
            signal_sigma=signal_sigma,
            min_signal=5.0,
        )

    elif method == 'electron_counting':
        # Electron budget per pattern derived from the cube-wide mean intensity
        mean_level = np.mean(cube_data)
        budget = mean_level * dose_fraction * det_i * det_j
        noise_model = ElectronCountingNoise()
        model_kwargs = dict(electrons_per_pattern=budget, dqe=0.95)

    elif method == 'extreme_sparse':
        # Same budget as electron counting, scaled down by a further 0.1
        mean_level = np.mean(cube_data)
        budget = mean_level * dose_fraction * det_i * det_j * 0.1
        noise_model = LowDoseSparseModel()
        model_kwargs = dict(
            sparsity_target=0.98,
            mean_electrons=budget,
            add_readout=False,
        )

    else:
        raise ValueError(f"Unknown method: {method}")

    # Run the model over every scan position
    noisy_data = np.zeros_like(cube_data)
    n_patterns = scan_i * scan_j

    for done, (row, col) in enumerate(np.ndindex(scan_i, scan_j), start=1):
        pattern = cube_data[row, col, :, :]
        noisy_data[row, col, :, :] = noise_model.apply(pattern, **model_kwargs)

        # Report roughly every tenth of the scan
        if done % max(1, n_patterns // 10) == 0:
            print(f"  Progress: {done}/{n_patterns}")

    # Wrap the result and carry over calibration when present
    low_dose_cube = py4DSTEM.DataCube(data=noisy_data)

    if hasattr(datacube, 'calibration'):
        low_dose_cube.calibration = datacube.calibration

    # Record how the cube was produced
    low_dose_cube.metadata['dose_reduction'] = {
        'method': method,
        'dose_fraction': dose_fraction,
        'original_mean': float(np.mean(cube_data)),
        'reduced_mean': float(np.mean(noisy_data)),
        'sparsity': float(np.sum(noisy_data == 0) / noisy_data.size)
    }

    print(f"Done! Sparsity: {low_dose_cube.metadata['dose_reduction']['sparsity']*100:.1f}% zeros")

    return low_dose_cube


# ============================================================================
# Analysis Functions
# ============================================================================

def analyze_sparsity(datacube: py4DSTEM.DataCube) -> Dict[str, float]:
    """
    Compute and print sparsity statistics for a datacube

    Parameters
    ----------
    datacube : py4DSTEM.DataCube
        Input datacube; ``datacube.data`` is converted to float and summed
        over axes 2 and 3 to get the per-pattern totals

    Returns
    -------
    dict
        Keys: 'sparsity', 'mean_intensity', 'median_intensity',
        'mean_nonzero', 'fraction_nonzero', 'max_intensity',
        'electrons_per_pattern'
    """
    values = datacube.data.astype(float)
    positive = values > 0
    n_values = values.size

    stats = {
        'sparsity': np.sum(values == 0) / n_values,
        'mean_intensity': np.mean(values),
        'median_intensity': np.median(values),
        # Mean over strictly positive pixels; 0 when there are none
        'mean_nonzero': np.mean(values[positive]) if np.any(positive) else 0,
        'fraction_nonzero': np.sum(positive) / n_values,
        'max_intensity': np.max(values),
        'electrons_per_pattern': np.mean(np.sum(values, axis=(2, 3)))
    }

    rule = "="*60
    report = [
        rule,
        "SPARSITY ANALYSIS",
        rule,
        f"Sparsity (zeros): {stats['sparsity']*100:.2f}%",
        f"Fraction nonzero: {stats['fraction_nonzero']*100:.2f}%",
        f"Mean intensity (all): {stats['mean_intensity']:.3f}",
        f"Mean intensity (nonzero): {stats['mean_nonzero']:.3f}",
        f"Median intensity: {stats['median_intensity']:.3f}",
        f"Max intensity: {stats['max_intensity']:.1f}",
        f"Avg electrons/pattern: {stats['electrons_per_pattern']:.1f}",
        rule,
    ]
    for line in report:
        print(line)

    return stats


def visualize_histogram(
    datacube: py4DSTEM.DataCube,
    log_scale: bool = True,
    bins: int = 100,
    title: str = "Intensity Histogram"
):
    """
    Plot intensity histograms of a datacube and print summary statistics

    The left panel shows all pixels (zeros included); the right panel shows
    only strictly positive pixels with their mean marked.

    Parameters
    ----------
    datacube : py4DSTEM.DataCube
        Input datacube
    log_scale : bool
        Use log scale for the y-axis of the left panel
    bins : int
        Number of histogram bins
    title : str
        Title prefix for the left panel

    Returns
    -------
    matplotlib figure holding both panels
    """
    import matplotlib.pyplot as plt

    values = datacube.data.flatten().astype(float)
    n_total = len(values)

    # Zeros and strictly positive pixels are tallied separately
    n_zeros = np.sum(values == 0)
    positive_mask = values > 0
    n_nonzeros = np.sum(positive_mask)
    nonzero_data = values[positive_mask]
    has_signal = len(nonzero_data) > 0

    fig, (ax_all, ax_signal) = plt.subplots(1, 2, figsize=(14, 5))

    # Left: every pixel
    ax_all.hist(values, bins=bins, color='blue', alpha=0.7, edgecolor='black')
    if log_scale:
        ax_all.set_yscale('log')
    ax_all.set_xlabel('Intensity')
    ax_all.set_ylabel('Count (log scale)' if log_scale else 'Count')
    ax_all.set_title(f'{title}\n{n_zeros:,} zeros ({n_zeros/n_total*100:.1f}%)')
    ax_all.grid(True, alpha=0.3)

    # Right: positive pixels only, or a placeholder when there are none
    if has_signal:
        signal_mean = np.mean(nonzero_data)
        signal_std = np.std(nonzero_data)

        ax_signal.hist(nonzero_data, bins=bins, color='green', alpha=0.7, edgecolor='black')
        ax_signal.set_xlabel('Intensity')
        ax_signal.set_ylabel('Count')
        ax_signal.set_title(f'Non-Zero Values Only\n'
                            f'n={n_nonzeros:,}, mean={signal_mean:.1f}, '
                            f'std={signal_std:.1f}')
        ax_signal.grid(True, alpha=0.3)

        # Mark the mean of the positive pixels
        ax_signal.axvline(signal_mean, color='red', linestyle='--',
                          linewidth=2, label=f'Mean: {signal_mean:.1f}')
        ax_signal.legend()
    else:
        ax_signal.text(0.5, 0.5, 'All zeros', ha='center', va='center',
                       transform=ax_signal.transAxes, fontsize=16)

    plt.tight_layout()

    # Text summary
    rule = "="*60
    print("\n" + rule)
    print("HISTOGRAM STATISTICS")
    print(rule)
    print(f"Total pixels: {n_total:,}")
    print(f"Zero pixels: {n_zeros:,} ({n_zeros/n_total*100:.2f}%)")
    print(f"Non-zero pixels: {n_nonzeros:,} ({n_nonzeros/n_total*100:.2f}%)")
    if has_signal:
        print(f"\nNon-zero statistics:")
        print(f"  Mean: {np.mean(nonzero_data):.2f}")
        print(f"  Std: {np.std(nonzero_data):.2f}")
        print(f"  Min: {np.min(nonzero_data):.2f}")
        print(f"  Max: {np.max(nonzero_data):.2f}")
        print(f"  Median: {np.median(nonzero_data):.2f}")
    print(rule)

    return fig


def compare_histograms(
    original: py4DSTEM.DataCube,
    low_dose: py4DSTEM.DataCube,
    log_scale: bool = True
):
    """
    Plot a 2x2 grid of histograms comparing two datacubes

    Top row: original; bottom row: low-dose. Left column: all pixels;
    right column: strictly positive pixels only (left empty when there
    are none).

    Parameters
    ----------
    original : py4DSTEM.DataCube
        Original datacube
    low_dose : py4DSTEM.DataCube
        Low-dose datacube
    log_scale : bool
        Use log scale for the y-axis of the left-column panels

    Returns
    -------
    matplotlib figure holding the four panels
    """
    import matplotlib.pyplot as plt

    def _draw_all_pixels(ax, values, colour, heading):
        # Histogram of every pixel, titled with the percentage of zeros
        ax.hist(values, bins=100, color=colour, alpha=0.7, edgecolor='black')
        if log_scale:
            ax.set_yscale('log')
        ax.set_xlabel('Intensity')
        ax.set_ylabel('Count (log)' if log_scale else 'Count')
        ax.set_title(f'{heading}\n{np.sum(values==0)/len(values)*100:.1f}% zeros')
        ax.grid(True, alpha=0.3)

    def _draw_positive(ax, values, colour):
        # Shared histogram + axis labels for the positive-only panels
        ax.hist(values, bins=100, color=colour, alpha=0.7, edgecolor='black')
        ax.set_xlabel('Intensity')
        ax.set_ylabel('Count')

    orig_values = original.data.flatten().astype(float)
    low_values = low_dose.data.flatten().astype(float)

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    # Top row: original cube
    _draw_all_pixels(axes[0, 0], orig_values, 'blue', 'Original')

    orig_positive = orig_values[orig_values > 0]
    if len(orig_positive) > 0:
        panel = axes[0, 1]
        _draw_positive(panel, orig_positive, 'blue')
        panel.set_title(f'Original (Non-Zero)\nmean={np.mean(orig_positive):.1f}')
        panel.grid(True, alpha=0.3)

    # Bottom row: low-dose cube
    _draw_all_pixels(axes[1, 0], low_values, 'green', 'Low-Dose')

    low_positive = low_values[low_values > 0]
    if len(low_positive) > 0:
        panel = axes[1, 1]
        low_mean = np.mean(low_positive)
        _draw_positive(panel, low_positive, 'green')
        panel.set_title(f'Low-Dose (Non-Zero)\nmean={low_mean:.1f}, '
                        f'std={np.std(low_positive):.1f}')
        # Only the low-dose panel gets a mean marker and legend
        panel.axvline(low_mean, color='red', linestyle='--',
                      linewidth=2, label='Mean')
        panel.legend()
        panel.grid(True, alpha=0.3)

    plt.tight_layout()
    return fig


def visualize_dose_comparison(
    original: py4DSTEM.DataCube,
    low_dose: py4DSTEM.DataCube,
    scan_pos: Optional[Tuple[int, int]] = None,
    log_scale: bool = True
):
    """
    Visualize original vs low-dose patterns

    Shows three panels for one scan position: the original pattern, the
    low-dose pattern and their difference, then prints per-pattern sums and
    non-zero counts.

    Parameters
    ----------
    original : py4DSTEM.DataCube
        Original datacube
    low_dose : py4DSTEM.DataCube
        Low-dose datacube
    scan_pos : tuple, optional
        (i, j) position to visualize; defaults to the centre of the scan
    log_scale : bool
        Display log10(pattern + 1) instead of linear intensities

    Returns
    -------
    matplotlib figure holding the three panels
    """
    import matplotlib.pyplot as plt

    def _robust_limit(values, floor):
        # 99th-percentile colour limit. For very sparse frames that
        # percentile equals the background level, which would collapse the
        # colour range and render a blank image, so fall back to the maximum
        # and finally to None (matplotlib autoscaling) for constant frames.
        limit = np.percentile(values, 99)
        if limit > floor:
            return limit
        peak = np.max(values)
        return peak if peak > floor else None

    scan_i, scan_j = original.data.shape[:2]

    if scan_pos is None:
        scan_pos = (scan_i // 2, scan_j // 2)

    i, j = scan_pos

    orig_dp = original.data[i, j, :, :].astype(float)
    low_dp = low_dose.data[i, j, :, :].astype(float)

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    # Original
    if log_scale:
        im0 = axes[0].imshow(np.log10(orig_dp + 1), cmap='gray')
        axes[0].set_title(f'Original (log scale)\nScan ({i}, {j})')
    else:
        im0 = axes[0].imshow(orig_dp, cmap='gray',
                             vmax=_robust_limit(orig_dp, np.min(orig_dp)))
        axes[0].set_title(f'Original\nScan ({i}, {j})')
    plt.colorbar(im0, ax=axes[0])

    # Low dose
    if log_scale:
        im1 = axes[1].imshow(np.log10(low_dp + 1), cmap='gray')
        axes[1].set_title('Low Dose (log scale)')
    else:
        im1 = axes[1].imshow(low_dp, cmap='gray',
                             vmax=_robust_limit(low_dp, np.min(low_dp)))
        axes[1].set_title('Low Dose')
    plt.colorbar(im1, ax=axes[1])

    # Difference, on a colour range symmetric about zero
    diff = orig_dp - low_dp
    diff_limit = _robust_limit(np.abs(diff), 0.0)
    im2 = axes[2].imshow(diff, cmap='RdBu_r',
                         vmin=None if diff_limit is None else -diff_limit,
                         vmax=diff_limit)
    axes[2].set_title('Difference (Original - Low Dose)')
    plt.colorbar(im2, ax=axes[2])

    plt.tight_layout()

    # Print statistics
    print(f"\nPattern statistics at position ({i}, {j}):")
    print(f"Original: sum={np.sum(orig_dp):.0f}, nonzero={np.sum(orig_dp>0)} pixels")
    print(f"Low dose: sum={np.sum(low_dp):.0f}, nonzero={np.sum(low_dp>0)} pixels")
    print(f"Sparsity: {np.sum(low_dp==0)/low_dp.size*100:.1f}% zeros")

    return fig


# ============================================================================
# Quick workflow function
# ============================================================================

def create_training_pair(
    datacube: py4DSTEM.DataCube,
    dose_fraction: float = 0.02,
    method: str = 'bimodal',
    signal_mean: float = 30.0,
    signal_sigma: float = 10.0,
    seed: Optional[int] = None
) -> Tuple[py4DSTEM.DataCube, py4DSTEM.DataCube]:
    """
    Create a (clean, noisy) datacube pair for denoiser training

    The noisy cube comes from ``reduce_dose``; sparsity statistics of both
    cubes are printed via ``analyze_sparsity``.

    Parameters
    ----------
    datacube : py4DSTEM.DataCube
        Clean datacube (returned unchanged as the target)
    dose_fraction : float
        For 'bimodal': fraction of non-zero pixels (0.02 = 2% signal, 98% zeros)
        For other methods: fraction of the original dose
    method : str
        Noise method passed to ``reduce_dose`` (default: 'bimodal')
    signal_mean : float
        Mean value for signal pixels (bimodal method only)
    signal_sigma : float
        Std dev for signal pixels (bimodal method only)
    seed : int, optional
        If given, seeds NumPy's global random state via ``np.random.seed``

    Returns
    -------
    clean : py4DSTEM.DataCube
        The input datacube itself
    noisy : py4DSTEM.DataCube
        Low-dose noisy datacube

    Examples
    --------
    >>> clean, noisy = create_training_pair(datacube, dose_fraction=0.02,
    ...                                      method='bimodal', signal_mean=30)
    >>> clean, noisy = create_training_pair(datacube, dose_fraction=0.01,
    ...                                      method='extreme_sparse')
    """
    if seed is not None:
        # Affects the global NumPy RNG used by the noise models
        np.random.seed(seed)

    print("Creating training pair for denoising...")
    print(f"Target: Clean datacube (original)")

    input_summary = (
        f"Input: Bimodal sparse datacube ({(1-dose_fraction)*100:.1f}% zeros, "
        f"signal mean={signal_mean})"
        if method == 'bimodal'
        else f"Input: Low-dose datacube ({dose_fraction*100:.2f}% dose)"
    )
    print(input_summary)

    noisy_cube = reduce_dose(
        datacube,
        dose_fraction=dose_fraction,
        method=method,
        signal_mean=signal_mean,
        signal_sigma=signal_sigma
    )

    print("\nAnalyzing clean datacube:")
    stats_clean = analyze_sparsity(datacube)

    print("\nAnalyzing noisy datacube:")
    stats_noisy = analyze_sparsity(noisy_cube)

    # Ratio of mean intensities before and after the dose reduction
    print(f"\nDose reduction: {stats_clean['mean_intensity']/stats_noisy['mean_intensity']:.1f}x")

    return datacube, noisy_cube