In [None]:
# =========================
# Paths (EDIT ONLY HERE)
# =========================
RAW_H5_PATH   = r"C:\Users\dji\Downloads\20220509_182024_data_binned2.hdf5"    # <- your original/raw 4D-STEM file
NOISY_H5_PATH = r"C:\Users\dji\Downloads\anthracene_2_low_dose_bimodal.hdf5"          # <- your saved noisy datacube file

# ---- imports ----
# `np` and `ssim` are used by the metric function and run section of this
# cell; importing them here lets the cell run on a fresh kernel.
import numpy as np
from skimage.metrics import structural_similarity as ssim
import hyperspy.api as hs
import py4DSTEM

# ---- load raw and noisy into datacubes ----
# raw
s_raw = hs.load(RAW_H5_PATH, reader="HSPY")
dc_raw = py4DSTEM.DataCube(s_raw.data)

# noisy: assumed to be stored in the same HSPY format as the raw file.
# If this load fails, an h5py-based loader is needed (not implemented here).
s_noisy = hs.load(NOISY_H5_PATH, reader="HSPY")
dc_noisy = py4DSTEM.DataCube(s_noisy.data)

# optional: keep old variable names for compatibility
datacube = dc_raw
noisy_dc3 = dc_noisy

# ---- metric function ----
# Relies on module-level `np` (numpy) and `ssim`
# (skimage.metrics.structural_similarity) being imported before it is called.
def psnr_ssim_maps_per_scan(
    original_dc,
    noisy_dc,
    psnr_max_mode: str = "global_orig_max",
    ssim_range_mode: str = "global_data_range",
):
    """Compute per-scan-position PSNR and SSIM maps between two 4D datacubes.

    Each frame ``noisy_dc.data[i, j]`` is compared with the matching frame
    ``original_dc.data[i, j]``.

    Parameters
    ----------
    original_dc, noisy_dc
        Objects exposing a 4D ``.data`` array of shape
        ``(scan_i, scan_j, det_i, det_j)``; the two shapes must match.
    psnr_max_mode
        Peak value used in ``PSNR = 20 * log10(MAX / RMSE)``. One of
        ``"global_orig_max"`` (max of the whole original cube),
        ``"per_frame_max"`` (max of each original frame),
        ``"global_data_range"`` (max - min of the whole original cube) or
        ``"per_frame_data_range"`` (max - min of each original frame).
        A zero per-frame value falls back to its global counterpart; a zero
        global value falls back to 1.0.
    ssim_range_mode
        ``data_range`` passed to ``ssim``: ``"global_data_range"`` or
        ``"per_frame_data_range"`` (falls back to the global range for
        constant frames).

    Returns
    -------
    psnr_map, ssim_map : np.ndarray
        float32 arrays of shape ``(scan_i, scan_j)``. PSNR is ``inf`` where
        the noisy frame equals the original frame.

    Raises
    ------
    ValueError
        If a mode string is not recognised or the two shapes differ.
    """
    psnr_modes = ("global_orig_max", "per_frame_max", "global_data_range", "per_frame_data_range")
    ssim_modes = ("global_data_range", "per_frame_data_range")
    # Reject unknown modes explicitly: silently falling back to a global mode
    # would turn a typo into plausible-looking but wrong metrics.
    if psnr_max_mode not in psnr_modes:
        raise ValueError(f"Unknown psnr_max_mode {psnr_max_mode!r}; expected one of {psnr_modes}")
    if ssim_range_mode not in ssim_modes:
        raise ValueError(f"Unknown ssim_range_mode {ssim_range_mode!r}; expected one of {ssim_modes}")

    orig = original_dc.data
    noisy = noisy_dc.data

    if orig.shape != noisy.shape:
        raise ValueError(f"Shape mismatch: {orig.shape} vs {noisy.shape}")

    scan_i, scan_j, _, _ = orig.shape
    psnr_map = np.empty((scan_i, scan_j), dtype=np.float32)
    ssim_map = np.empty((scan_i, scan_j), dtype=np.float32)

    # Global statistics of the original cube; zeros are replaced by 1.0 so
    # they are always usable as a PSNR peak / SSIM data range.
    global_min = float(np.min(orig))
    global_max = float(np.max(orig))
    global_range = (global_max - global_min) or 1.0
    global_peak = global_max or 1.0

    for i in range(scan_i):
        for j in range(scan_j):
            o = np.asarray(orig[i, j], dtype=np.float32)
            n = np.asarray(noisy[i, j], dtype=np.float32)

            d = n - o
            # Sum squared error in float64 to limit float32 round-off.
            mse = float(np.sum(d * d, dtype=np.float64)) / d.size
            rmse = float(np.sqrt(mse))

            if psnr_max_mode == "per_frame_max":
                peak = float(np.max(o)) or global_peak
            elif psnr_max_mode == "per_frame_data_range":
                peak = float(np.max(o) - np.min(o)) or global_range
            elif psnr_max_mode == "global_data_range":
                peak = global_range
            else:  # "global_orig_max" (validated above)
                peak = global_peak

            psnr_map[i, j] = np.inf if rmse == 0 else float(20.0 * np.log10(peak / rmse))

            if ssim_range_mode == "per_frame_data_range":
                dr = float(np.max(o) - np.min(o)) or global_range
            else:
                dr = global_range

            ssim_map[i, j] = float(ssim(o, n, data_range=dr))

    return psnr_map, ssim_map


# ---- run: per-scan-position metrics, raw vs noisy ----
psnr_map, ssim_map = psnr_ssim_maps_per_scan(
    original_dc=datacube,
    noisy_dc=noisy_dc3,
    psnr_max_mode="global_orig_max",
    ssim_range_mode="global_data_range",
)

# The metric maps should span the scan grid (first two axes) of the inputs.
print(f"Raw shape: {datacube.data.shape}")
print(f"Noisy shape: {noisy_dc3.data.shape}")
print(f"PSNR map shape: {psnr_map.shape}")
print(f"SSIM map shape: {ssim_map.shape}")


In [None]:
import numpy as np
from skimage.metrics import structural_similarity as ssim

def psnr_ssim_maps_per_scan(
    original_dc,
    noisy_dc,
    psnr_max_mode: str = "global_orig_max",
    ssim_range_mode: str = "global_data_range",
):
    """Compute per-scan-position PSNR and SSIM maps between two 4D datacubes.

    Each frame ``noisy_dc.data[i, j]`` is compared with the matching frame
    ``original_dc.data[i, j]``.

    Parameters
    ----------
    original_dc, noisy_dc
        Objects exposing a 4D ``.data`` array of shape
        ``(scan_i, scan_j, det_i, det_j)``; the two shapes must match.
    psnr_max_mode
        Peak value used in ``PSNR = 20 * log10(MAX / RMSE)``. One of
        ``"global_orig_max"`` (max of the whole original cube),
        ``"per_frame_max"`` (max of each original frame),
        ``"global_data_range"`` (max - min of the whole original cube) or
        ``"per_frame_data_range"`` (max - min of each original frame).
        A zero per-frame value falls back to its global counterpart; a zero
        global value falls back to 1.0.
    ssim_range_mode
        ``data_range`` passed to ``ssim``: ``"global_data_range"`` or
        ``"per_frame_data_range"`` (falls back to the global range for
        constant frames).

    Returns
    -------
    psnr_map, ssim_map : np.ndarray
        float32 arrays of shape ``(scan_i, scan_j)``. PSNR is ``inf`` where
        the noisy frame equals the original frame.

    Raises
    ------
    ValueError
        If a mode string is not recognised or the two shapes differ.
    """
    psnr_modes = ("global_orig_max", "per_frame_max", "global_data_range", "per_frame_data_range")
    ssim_modes = ("global_data_range", "per_frame_data_range")
    # Reject unknown modes explicitly: silently falling back to a global mode
    # would turn a typo into plausible-looking but wrong metrics.
    if psnr_max_mode not in psnr_modes:
        raise ValueError(f"Unknown psnr_max_mode {psnr_max_mode!r}; expected one of {psnr_modes}")
    if ssim_range_mode not in ssim_modes:
        raise ValueError(f"Unknown ssim_range_mode {ssim_range_mode!r}; expected one of {ssim_modes}")

    orig = original_dc.data
    noisy = noisy_dc.data

    if orig.shape != noisy.shape:
        raise ValueError(f"Shape mismatch: {orig.shape} vs {noisy.shape}")

    scan_i, scan_j, _, _ = orig.shape
    psnr_map = np.empty((scan_i, scan_j), dtype=np.float32)
    ssim_map = np.empty((scan_i, scan_j), dtype=np.float32)

    # Global statistics of the original cube; zeros are replaced by 1.0 so
    # they are always usable as a PSNR peak / SSIM data range.
    global_min = float(np.min(orig))
    global_max = float(np.max(orig))
    global_range = (global_max - global_min) or 1.0
    global_peak = global_max or 1.0

    for i in range(scan_i):
        for j in range(scan_j):
            o = np.asarray(orig[i, j], dtype=np.float32)   # one 2D frame
            n = np.asarray(noisy[i, j], dtype=np.float32)

            d = n - o
            # Sum squared error in float64 to limit float32 round-off.
            mse = float(np.sum(d * d, dtype=np.float64)) / d.size
            rmse = float(np.sqrt(mse))

            # ---- PSNR peak choice ----
            if psnr_max_mode == "per_frame_max":
                peak = float(np.max(o)) or global_peak
            elif psnr_max_mode == "per_frame_data_range":
                peak = float(np.max(o) - np.min(o)) or global_range
            elif psnr_max_mode == "global_data_range":
                peak = global_range
            else:  # "global_orig_max" (validated above)
                peak = global_peak

            psnr_map[i, j] = np.inf if rmse == 0 else float(20.0 * np.log10(peak / rmse))

            # ---- SSIM data_range choice ----
            if ssim_range_mode == "per_frame_data_range":
                dr = float(np.max(o) - np.min(o)) or global_range
            else:
                dr = global_range

            ssim_map[i, j] = float(ssim(o, n, data_range=dr))

    return psnr_map, ssim_map

# ---- run ----
# NOTE: this recomputes the same maps as the first cell's run section.
psnr_map, ssim_map = psnr_ssim_maps_per_scan(
    original_dc=datacube,
    noisy_dc=noisy_dc3,
    psnr_max_mode="global_orig_max",
    ssim_range_mode="global_data_range",
)

# Both maps cover the scan grid, e.g. (255, 255) for this dataset.
print(f"PSNR map shape: {psnr_map.shape}")
print(f"SSIM map shape: {ssim_map.shape}")

# optional save
# np.save("psnr_map.npy", psnr_map)
# np.save("ssim_map.npy", ssim_map)


In [None]:
# Import here: this cell runs before the PSNR plotting cell that also imports plt.
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(6, 6))
# SSIM is at most 1; the colour limits are deliberately narrowed to 0.75–0.85,
# presumably to match this dataset's SSIM spread — widen if values fall outside.
im = ax.imshow(ssim_map, vmin=0.75, vmax=0.85)
ax.set_title("SSIM per SAED (scan_i × scan_j)")
ax.set_xlabel("scan j")
ax.set_ylabel("scan i")
fig.colorbar(im, ax=ax, label="SSIM")
fig.tight_layout()
plt.show()

In [None]:
import matplotlib.pyplot as plt

# PSNR map over the scan grid; colour limits (dB) are a manual choice.
fig, ax = plt.subplots(figsize=(6, 6))
im = ax.imshow(psnr_map, vmin=30, vmax=37)
ax.set_title("PSNR per SAED (scan_i × scan_j)")
ax.set_xlabel("scan j")
ax.set_ylabel("scan i")
fig.colorbar(im, ax=ax, label="PSNR (dB)")
fig.tight_layout()
plt.show()


In [None]:
import os

# Short aliases for the (scan_i, scan_j) metric maps, used when saving.
psnr_255 = psnr_map
ssim_255 = ssim_map

In [None]:
# -------------------------
# Choose output folder
# -------------------------
# Imported locally so this cell does not depend on which earlier cells ran.
import os

import hyperspy.api as hs
import numpy as np

out_dir = r"C:\Users\dji\Downloads"
os.makedirs(out_dir, exist_ok=True)

# -------------------------
# 1) Save SSIM / PSNR maps
# -------------------------
# If your PSNR map variable is called psnr_map, rename it here:
psnr_255 = psnr_map  # <-- change this if your variable name differs

# Encode each map's actual scan-grid shape in its file name rather than a
# hard-coded "255x255" (identical names for 255x255 maps).
ssim_path = os.path.join(out_dir, "ssim_map_{}x{}.npy".format(*ssim_255.shape))
psnr_path = os.path.join(out_dir, "psnr_map_{}x{}.npy".format(*psnr_255.shape))

np.save(ssim_path, ssim_255.astype(np.float32))
np.save(psnr_path, psnr_255.astype(np.float32))

print("Saved metrics to:", out_dir)

# -------------------------
# 2) Save noisy datacube
# -------------------------
# Saved as a HyperSpy Signal2D: the first two axes of the 4D array
# (scan_i, scan_j) become navigation axes, the last two (det_i, det_j) the signal.
# NOTE(review): HyperSpy chooses the writer from the file extension; confirm
# that ".h5" maps to the HSPY format in the installed version (the loader
# at the top reads ".hdf5" files with reader="HSPY").
noisy_path = os.path.join(out_dir, "noisy_dc3.h5")
sig = hs.signals.Signal2D(noisy_dc3.data)
sig.save(noisy_path, overwrite=True)

print("Saved noisy datacube to:", noisy_path)
