Peter Second day

RAM-safe way to load the datacube and compute the noise cube

In [None]:
# ============================================================================
# Noise models + memory-safe add_noise_to_datacube (drop-in replacement)
# ============================================================================

from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, List, Tuple

import numpy as np
import py4DSTEM


# ============================================================================
# Abstract base class for noise models
# ============================================================================
class NoiseModel(ABC):
    """Abstract interface shared by all noise models.

    A concrete model overrides :meth:`apply`; the class itself cannot be
    instantiated because ``apply`` is abstract.
    """

    @abstractmethod
    def apply(self, data: np.ndarray, rng: Optional[np.random.Generator] = None, **kwargs) -> np.ndarray:
        """Return ``data`` with this model's noise applied.

        Args:
            data: Input array to perturb.
            rng: Optional random generator to draw from.
            **kwargs: Model-specific parameters.

        The base implementation does nothing and returns ``None``.
        """
        return None


# ============================================================================
# Concrete noise model implementations
# ============================================================================
class PoissonNoise(NoiseModel):
    """Poisson (shot) noise whose rate is the scaled, non-negative signal."""

    def apply(self, data: np.ndarray, scale: float = 1.0,
              rng: Optional[np.random.Generator] = None) -> np.ndarray:
        """Draw Poisson counts with rate ``max(data, 0) * scale``.

        The sampled counts are divided by ``scale`` again so the result is in
        the same units as the input. When ``scale`` is 0 no rescaling is done
        (the rate, and therefore every sample, is zero).

        Args:
            data: Input array; converted to float32.
            scale: Multiplier applied to the signal before sampling.
            rng: Random generator; a fresh default one is created if omitted.

        Returns:
            Array of sampled counts rescaled by ``1 / scale``.
        """
        if rng is None:
            rng = np.random.default_rng()

        # Negative intensities are clamped: a Poisson rate cannot be negative.
        signal = np.maximum(np.asarray(data, dtype=np.float32), 0.0)

        factor = float(scale)
        counts = rng.poisson(lam=signal * factor).astype(np.float32, copy=False)

        if scale != 0:
            rescale = 1.0 / factor
        else:
            rescale = 1.0
        return counts * rescale


class GaussianNoise(NoiseModel):
    """Additive normally distributed noise."""

    def apply(self, data: np.ndarray, mean: float = 0.0, sigma: float = 1.0,
              rng: Optional[np.random.Generator] = None) -> np.ndarray:
        """Add one independent normal sample per element.

        Args:
            data: Input array; converted to float32.
            mean: Mean of the normal distribution.
            sigma: Standard deviation of the normal distribution.
            rng: Random generator; a fresh default one is created if omitted.

        Returns:
            ``data`` plus the sampled noise.
        """
        if rng is None:
            rng = np.random.default_rng()
        signal = np.asarray(data, dtype=np.float32)
        perturbation = rng.normal(loc=float(mean), scale=float(sigma), size=signal.shape)
        return signal + perturbation.astype(np.float32, copy=False)


class ReadoutNoise(NoiseModel):
    """Zero-mean additive Gaussian readout noise (in counts)."""

    def apply(self, data: np.ndarray, sigma: float = 5.0,
              rng: Optional[np.random.Generator] = None) -> np.ndarray:
        """Add a zero-mean normal sample with standard deviation ``sigma`` to each element.

        Args:
            data: Input array; converted to float32.
            sigma: Standard deviation of the noise.
            rng: Random generator; a fresh default one is created if omitted.

        Returns:
            ``data`` plus the sampled noise.
        """
        if rng is None:
            rng = np.random.default_rng()
        signal = np.asarray(data, dtype=np.float32)
        readout = rng.normal(loc=0.0, scale=float(sigma), size=signal.shape)
        return signal + readout.astype(np.float32, copy=False)


class DarkCurrentNoise(NoiseModel):
    """Additive Poisson-distributed dark-current counts."""

    def apply(self, data: np.ndarray, dark_current: float = 1.0,
              rng: Optional[np.random.Generator] = None) -> np.ndarray:
        """Add Poisson counts with a uniform rate to every element.

        Args:
            data: Input array; converted to float32.
            dark_current: Poisson rate; values below zero are treated as 0.
            rng: Random generator; a fresh default one is created if omitted.

        Returns:
            ``data`` plus the sampled dark counts.
        """
        if rng is None:
            rng = np.random.default_rng()
        signal = np.asarray(data, dtype=np.float32)
        rate = max(0.0, float(dark_current))
        dark_counts = rng.poisson(lam=rate, size=signal.shape)
        return signal + dark_counts.astype(np.float32, copy=False)


class SaltPepperNoise(NoiseModel):
    """Salt and pepper (impulse) noise"""

    def apply(self, data: np.ndarray, probability: float = 0.01,
              salt_value: Optional[float] = None,
              pepper_value: float = 0.0,
              rng: Optional[np.random.Generator] = None) -> np.ndarray:
        """Replace randomly chosen elements with a salt or a pepper value.

        Args:
            data: Input array; converted to float32. It is never modified.
            probability: Each of the salt and pepper masks selects an element
                with probability ``probability / 2``.
            salt_value: Value written to salt pixels; defaults to the maximum
                of ``data``.
            pepper_value: Value written to pepper pixels.
            rng: Random generator; a fresh default one is created if omitted.

        Returns:
            A float32 copy of ``data`` with the impulses applied. The copy is
            returned unchanged when ``probability <= 0`` or ``data`` is empty.
        """
        rng = np.random.default_rng() if rng is None else rng
        x = np.asarray(data, dtype=np.float32)

        noisy = x.copy()

        p = float(probability)
        # Nothing to corrupt. Checking this first also avoids np.max() on a
        # zero-size array, which raises ValueError.
        if p <= 0 or x.size == 0:
            return noisy

        if salt_value is None:
            salt_value = float(np.max(x))

        salt_mask = rng.random(x.shape) < (p / 2.0)
        noisy[salt_mask] = float(salt_value)

        # The two masks are drawn independently and pepper is written last,
        # so an element selected by both ends up with the pepper value.
        pepper_mask = rng.random(x.shape) < (p / 2.0)
        noisy[pepper_mask] = float(pepper_value)

        return noisy


class CorrelatedNoise(NoiseModel):
    """Additive noise made spatially correlated by Gaussian smoothing."""

    def apply(self, data: np.ndarray, sigma: float = 1.0, correlation_length: float = 5.0,
              rng: Optional[np.random.Generator] = None) -> np.ndarray:
        """Add Gaussian-filtered white noise to ``data``.

        Args:
            data: Input array; converted to float32.
            sigma: Standard deviation of the white noise before filtering.
            correlation_length: ``sigma`` passed to ``gaussian_filter``.
            rng: Random generator; a fresh default one is created if omitted.

        Returns:
            ``data`` plus the smoothed noise field.
        """
        from scipy.ndimage import gaussian_filter

        if rng is None:
            rng = np.random.default_rng()
        signal = np.asarray(data, dtype=np.float32)

        white_field = rng.normal(loc=0.0, scale=float(sigma), size=signal.shape)
        smooth_field = gaussian_filter(
            white_field.astype(np.float32, copy=False),
            sigma=float(correlation_length),
        )
        return signal + smooth_field


class DrizzleNearBrightPoissonNoise(NoiseModel):
    """
    Sprinkle Poisson counts onto pixels near the brightest pixels of a 2D image
    (correlated salt-like noise).
    (Kept compatible with your original implementation.)
    """

    def apply(
        self,
        data: np.ndarray,
        bright_fraction: float = 0.01,
        radius_px: int = 5,
        square_side: int = 10,
        drizzles_per_seed: int = 3,
        lam_fraction: float = 0.05,
        lam_min: float = 1.0,
        exclude_center: bool = True,
        rng: Optional[np.random.Generator] = None,
    ) -> np.ndarray:
        """Add Poisson counts at random offsets around each bright "seed" pixel.

        Args:
            data: 2D input image; converted to float32.
            bright_fraction: Pixels at or above the ``1 - bright_fraction``
                quantile are used as seeds. Must lie strictly between 0 and 1.
            radius_px: Candidate offsets are limited to this Euclidean radius.
            square_side: Side length of the square offset window that the
                radius test is applied to.
            drizzles_per_seed: Number of distinct offsets drawn per seed
                (capped at the number of candidate offsets).
            lam_fraction: Poisson rate as a fraction of the seed intensity.
            lam_min: Lower bound on the Poisson rate.
            exclude_center: If true, the zero offset (the seed itself) is not
                a candidate.
            rng: Random generator; a fresh default one is created if omitted.

        Returns:
            A float32 copy of the image with the drizzled counts added.

        Raises:
            ValueError: If ``bright_fraction`` is not in the open interval (0, 1).
        """
        if rng is None:
            rng = np.random.default_rng()

        image = np.asarray(data, dtype=np.float32, order="C")
        n_rows, n_cols = image.shape

        if not (0.0 < bright_fraction < 1.0):
            raise ValueError("bright_fraction must be in (0,1)")

        threshold = np.quantile(image, 1.0 - bright_fraction)
        seed_coords = np.argwhere(image >= threshold)
        if seed_coords.size == 0:
            return image.copy()

        # Square window of integer offsets, flattened to parallel 1D arrays.
        first = -(square_side // 2)
        offsets = np.arange(first, first + square_side, dtype=int)
        grid_y, grid_x = np.meshgrid(offsets, offsets, indexing="ij")
        off_y = grid_y.ravel()
        off_x = grid_x.ravel()

        # Keep only offsets inside the disc, optionally dropping the seed itself.
        keep = (off_y * off_y + off_x * off_x) <= (radius_px * radius_px)
        if exclude_center:
            keep &= (off_y != 0) | (off_x != 0)

        off_y, off_x = off_y[keep], off_x[keep]
        n_offsets = off_y.size
        if n_offsets == 0:
            return image.copy()

        result = image.copy()

        for row, col in seed_coords:
            intensity = float(image[row, col])
            rate = max(float(lam_min), float(lam_fraction) * intensity)

            n_pick = min(int(drizzles_per_seed), n_offsets)
            chosen = rng.choice(n_offsets, size=n_pick, replace=False)

            target_y = row + off_y[chosen]
            target_x = col + off_x[chosen]

            # Discard targets that fall outside the image.
            inside = (target_y >= 0) & (target_y < n_rows) & (target_x >= 0) & (target_x < n_cols)
            if not inside.any():
                continue
            target_y = target_y[inside]
            target_x = target_x[inside]

            counts = rng.poisson(lam=rate, size=target_y.size)
            result[target_y, target_x] += counts.astype(np.float32, copy=False)

        return result


# ============================================================================
# Main function to add noise to datacube (memory-safe)
# ============================================================================
def _cast_frame(frame, dtype) -> np.ndarray:
    """Convert one processed frame to ``dtype``; integer targets are rounded and range-clipped."""
    if np.issubdtype(dtype, np.integer):
        limits = np.iinfo(dtype)
        # Round to nearest instead of letting astype() truncate toward zero,
        # which would bias every count downwards. float64 represents the
        # bounds of integer types up to 32 bits exactly.
        # NOTE(review): 64-bit integer maxima are not exactly representable
        # in float64, so values at that extreme may still overflow.
        rounded = np.rint(np.asarray(frame, dtype=np.float64))
        frame = np.clip(rounded, limits.min, limits.max)
    return np.asarray(frame).astype(dtype, copy=False)


def add_noise_to_datacube(
    datacube: py4DSTEM.DataCube,
    noise_models: List[Tuple[NoiseModel, Dict[str, Any]]],
    seed: Optional[int] = None,
    clip_negative: bool = True,
    preserve_dtype: bool = True,
) -> py4DSTEM.DataCube:
    """
    Memory-safe: process each diffraction pattern (i,j) frame-by-frame.
    Avoids copying the whole 4D cube to float64 (which caused your 32 GiB MemoryError).

    Args:
        datacube: Source datacube; ``datacube.data`` must be 4D and is
            indexed as ``data[i, j]`` to obtain one 2D frame.
        noise_models: ``(model, params)`` pairs applied in order to every
            frame as ``model.apply(frame, rng=rng, **params)``.
        seed: Seed for the single random generator shared by all models.
        clip_negative: If true, negative values are set to 0 after the
            models have been applied.
        preserve_dtype: If true, the output has the dtype of the input;
            integer dtypes are rounded to the nearest integer and clipped to
            the dtype's representable range. Otherwise the output is float32.

    Returns:
        A new ``py4DSTEM.DataCube`` holding the noisy data.
    """
    rng = np.random.default_rng(seed)

    data = datacube.data
    original_dtype = data.dtype
    # Unpacking four values also rejects data that is not 4D.
    scan_i, scan_j, _, _ = data.shape

    print(f"Adding noise to datacube of shape {data.shape}")

    out_dtype = original_dtype if preserve_dtype else np.float32
    noisy_data = np.empty_like(data, dtype=out_dtype)

    for i in range(scan_i):
        for j in range(scan_j):
            frame = np.asarray(data[i, j], dtype=np.float32)  # small 2D buffer only

            for noise_model, params in noise_models:
                frame = noise_model.apply(frame, rng=rng, **params)

            if clip_negative:
                frame = np.maximum(frame, 0.0)

            # Signed integer outputs keep negative values when
            # clip_negative is False; the cast only clips to the dtype range.
            noisy_data[i, j] = _cast_frame(frame, out_dtype)

    noisy_datacube = py4DSTEM.DataCube(data=noisy_data)

    if hasattr(datacube, "calibration"):
        noisy_datacube.calibration = datacube.calibration

    # NOTE(review): assumes `metadata` supports item assignment with a plain
    # list — confirm against the installed py4DSTEM version.
    if hasattr(noisy_datacube, "metadata"):
        noisy_datacube.metadata["noise_applied"] = [
            {"model": model.__class__.__name__, "parameters": params}
            for model, params in noise_models
        ]

    print("Noise addition complete!")
    return noisy_datacube


# ============================================================================
# Convenience function for common noise combinations
# ============================================================================
def add_realistic_detector_noise(
    datacube: py4DSTEM.DataCube,
    dose_scale: float = 100,
    readout_sigma: float = 5.0,
    dark_current: float = 1.0,
    seed: Optional[int] = None,
) -> py4DSTEM.DataCube:
    """Apply Poisson, dark-current and readout noise, in that order.

    Args:
        datacube: Datacube to add noise to.
        dose_scale: ``scale`` passed to ``PoissonNoise``.
        readout_sigma: ``sigma`` passed to ``ReadoutNoise``.
        dark_current: ``dark_current`` passed to ``DarkCurrentNoise``.
        seed: Seed forwarded to ``add_noise_to_datacube``.

    Returns:
        The datacube returned by ``add_noise_to_datacube``.
    """
    shot = (PoissonNoise(), dict(scale=dose_scale))
    dark = (DarkCurrentNoise(), dict(dark_current=dark_current))
    readout = (ReadoutNoise(), dict(sigma=readout_sigma))
    return add_noise_to_datacube(datacube, [shot, dark, readout], seed=seed)


In [None]:
# Example 3: Realistic detector noise
# (no seed is passed, so each run draws different noise)
noisy_dc3 = add_realistic_detector_noise(
    datacube=datacube,
    dose_scale=80,
    readout_sigma=3.5,
    dark_current=0.5,
)


Noise Metrics: SSIM & PSNR

In [None]:
import numpy as np
from skimage.metrics import structural_similarity as ssim

def psnr_ssim_maps_per_scan(
    original_dc,
    noisy_dc,
    psnr_max_mode: str = "global_orig_max",   # "global_orig_max" | "per_frame_max" | "global_data_range" | "per_frame_data_range"
    ssim_range_mode: str = "global_data_range" # "global_data_range" | "per_frame_data_range"
):
    """Compute PSNR and SSIM between two 4D datacubes, one value per scan position.

    Frames are compared one at a time as float32, so the full cubes are never
    copied.

    Args:
        original_dc: Reference object whose ``.data`` is a 4D array.
        noisy_dc: Object whose ``.data`` has the same shape as the reference.
        psnr_max_mode: Peak value used in ``20 * log10(peak / rmse)``:
            ``"global_orig_max"`` (maximum of the reference cube),
            ``"per_frame_max"`` (maximum of the reference frame),
            ``"global_data_range"`` (max - min of the reference cube) or
            ``"per_frame_data_range"`` (max - min of the reference frame).
            A per-frame value of 0 falls back to the global one, and a global
            value of 0 falls back to 1.0.
        ssim_range_mode: ``data_range`` passed to ``ssim``:
            ``"global_data_range"`` or ``"per_frame_data_range"`` (with the
            same fallbacks).

    Returns:
        ``(psnr_map, ssim_map)``, two float32 arrays of shape
        ``(scan_i, scan_j)``. PSNR is ``inf`` where the frames are identical.

    Raises:
        ValueError: If either mode string is not one of the listed options,
            or if the two data arrays differ in shape.
    """
    psnr_modes = ("global_orig_max", "per_frame_max", "global_data_range", "per_frame_data_range")
    ssim_modes = ("global_data_range", "per_frame_data_range")
    # Reject typos explicitly; otherwise they silently select the default mode.
    if psnr_max_mode not in psnr_modes:
        raise ValueError(f"Unknown psnr_max_mode {psnr_max_mode!r}; expected one of {psnr_modes}")
    if ssim_range_mode not in ssim_modes:
        raise ValueError(f"Unknown ssim_range_mode {ssim_range_mode!r}; expected one of {ssim_modes}")

    orig = original_dc.data
    noisy = noisy_dc.data

    if orig.shape != noisy.shape:
        raise ValueError(f"Shape mismatch: {orig.shape} vs {noisy.shape}")

    scan_i, scan_j, _, _ = orig.shape
    psnr_map = np.empty((scan_i, scan_j), dtype=np.float32)
    ssim_map = np.empty((scan_i, scan_j), dtype=np.float32)

    # global stats (no big copy); both fall back to 1.0 to avoid a zero peak/range
    global_min = float(np.min(orig))
    global_max = float(np.max(orig))
    global_range = global_max - global_min
    if global_range == 0:
        global_range = 1.0
    global_peak = global_max if global_max != 0 else 1.0

    for i in range(scan_i):
        for j in range(scan_j):
            o = np.asarray(orig[i, j], dtype=np.float32)   # 2D only
            n = np.asarray(noisy[i, j], dtype=np.float32)

            d = n - o
            # Accumulate the squared error in float64 for accuracy.
            mse = float(np.sum(d * d, dtype=np.float64)) / d.size
            rmse = float(np.sqrt(mse))

            # ---- PSNR peak choice ----
            if psnr_max_mode == "per_frame_max":
                peak = float(np.max(o))
                if peak == 0:
                    peak = global_peak
            elif psnr_max_mode == "global_data_range":
                peak = global_range
            elif psnr_max_mode == "per_frame_data_range":
                peak = float(np.max(o) - np.min(o))
                if peak == 0:
                    peak = global_range
            else:  # "global_orig_max"
                peak = global_peak

            psnr_map[i, j] = np.inf if rmse == 0 else (20.0 * np.log10(peak / rmse))

            # ---- SSIM data_range choice ----
            if ssim_range_mode == "per_frame_data_range":
                dr = float(np.max(o) - np.min(o))
                if dr == 0:
                    dr = global_range
            else:  # "global_data_range"
                dr = global_range

            ssim_map[i, j] = float(ssim(o, n, data_range=dr))

    return psnr_map, ssim_map

# ---- run ----
# Compare the clean datacube with the "realistic detector noise" result.
psnr_map, ssim_map = psnr_ssim_maps_per_scan(
    original_dc=datacube,
    noisy_dc=noisy_dc3,
    psnr_max_mode="global_orig_max",
    ssim_range_mode="global_data_range",
)

# Both maps have shape (scan_i, scan_j).
print("PSNR map shape:", psnr_map.shape)  # expect (255,255)
print("SSIM map shape:", ssim_map.shape)

# optional save
# np.save("psnr_map.npy", psnr_map)
# np.save("ssim_map.npy", ssim_map)


In [None]:
# Show the SSIM map as an image, one pixel per scan position.
# NOTE(review): `plt` is not imported in the cells shown here — presumably
# matplotlib.pyplot from an earlier cell; confirm.
plt.figure(figsize=(6, 6))
# Display window fixed to 0.7-0.8; values outside it saturate the colormap.
im = plt.imshow(ssim_map, vmin=0.7, vmax=0.8)  # SSIM typically in [0,1]
plt.title("SSIM per SAED (scan_i × scan_j)")
plt.xlabel("scan j")
plt.ylabel("scan i")
plt.colorbar(im, label="SSIM")
plt.tight_layout()
plt.show()

In [None]:
# Show the PSNR map as an image, one pixel per scan position.
plt.figure(figsize=(6, 6))
# Display window fixed to 39-42; values outside it saturate the colormap.
im = plt.imshow(psnr_map, vmin=39, vmax=42)  # PSNR is in dB (20 * log10 ratio)
plt.title("PSNR per SAED (scan_i × scan_j)")
plt.xlabel("scan j")
plt.ylabel("scan i")
# The label previously read "SSIM" (copied from the SSIM plot).
plt.colorbar(im, label="PSNR (dB)")
plt.tight_layout()
plt.show()