In [None]:
def psnr_ssim_maps_per_scan(
    original_dc,
    noisy_dc,
    psnr_max_mode: str = "global_orig_max",   # "global_orig_max" | "per_frame_max" | "global_data_range" | "per_frame_data_range"
    ssim_range_mode: str = "global_data_range" # "global_data_range" | "per_frame_data_range"
):
    """Compute per-scan-position PSNR and SSIM maps between two datacubes.

    Both inputs expose a ``.data`` array indexed as
    ``[scan_i, scan_j, det_i, det_j]``. Each detector frame of ``noisy_dc``
    is compared with the frame at the same scan position in ``original_dc``.
    Reference statistics (peak / data range) are always taken from the
    original data.

    Args:
        original_dc: Reference datacube (object with a 4D ``.data`` array).
        noisy_dc: Datacube to evaluate; its ``.data`` must match in shape.
        psnr_max_mode: Peak value used in ``20*log10(MAX / RMSE)``:
            - "global_orig_max": max of the whole original cube.
            - "per_frame_max": max of the original frame.
            - "global_data_range": max - min of the whole original cube.
            - "per_frame_data_range": max - min of the original frame.
            A non-positive per-frame value falls back to the matching global
            value. A non-positive global value falls back to 1.0.
        ssim_range_mode: ``data_range`` passed to ``ssim``:
            - "global_data_range": max - min of the whole original cube.
            - "per_frame_data_range": max - min of the original frame
              (constant frames fall back to the global range).

    Returns:
        (psnr_map, ssim_map): float32 arrays of shape (scan_i, scan_j).
        PSNR is +inf at positions where the two frames are identical.

    Raises:
        ValueError: on an unknown mode, non-4D data, or a shape mismatch.
    """
    psnr_modes = ("global_orig_max", "per_frame_max", "global_data_range", "per_frame_data_range")
    ssim_modes = ("global_data_range", "per_frame_data_range")
    # Validate up front: a typo used to silently fall through to the default branch.
    if psnr_max_mode not in psnr_modes:
        raise ValueError(f"Unknown psnr_max_mode {psnr_max_mode!r}; expected one of {psnr_modes}")
    if ssim_range_mode not in ssim_modes:
        raise ValueError(f"Unknown ssim_range_mode {ssim_range_mode!r}; expected one of {ssim_modes}")

    orig = original_dc.data
    noisy = noisy_dc.data

    if orig.shape != noisy.shape:
        raise ValueError(f"Shape mismatch: {orig.shape} vs {noisy.shape}")
    if len(orig.shape) != 4:
        raise ValueError(f"Expected 4D data (scan_i, scan_j, det_i, det_j), got shape {orig.shape}")

    scan_i, scan_j = orig.shape[:2]
    psnr_map = np.empty((scan_i, scan_j), dtype=np.float32)
    ssim_map = np.empty((scan_i, scan_j), dtype=np.float32)
    if psnr_map.size == 0:
        # No scan positions: nothing to compare, and np.min/np.max would raise on an empty cube.
        return psnr_map, ssim_map

    # Global reference stats, computed once over the whole original cube.
    global_min = float(np.min(orig))
    global_max = float(np.max(orig))
    global_range = global_max - global_min
    if global_range <= 0:
        global_range = 1.0
    # A non-positive peak would make MAX/rmse <= 0 and log10 return NaN.
    global_peak = global_max if global_max > 0 else 1.0

    need_frame_stats = (
        psnr_max_mode in ("per_frame_max", "per_frame_data_range")
        or ssim_range_mode == "per_frame_data_range"
    )
    frame_max = 0.0
    frame_range = 0.0

    for i in range(scan_i):
        for j in range(scan_j):
            o = np.asarray(orig[i, j], dtype=np.float32)   # one 2D detector frame
            n = np.asarray(noisy[i, j], dtype=np.float32)

            d = n - o
            # Accumulate the squared error in float64 to limit rounding over large frames.
            mse = float(np.sum(d * d, dtype=np.float64)) / d.size
            rmse = float(np.sqrt(mse))

            if need_frame_stats:
                frame_max = float(np.max(o))
                frame_range = frame_max - float(np.min(o))

            # ---- PSNR MAX choice ----
            if psnr_max_mode == "per_frame_max":
                peak = frame_max if frame_max > 0 else global_peak
            elif psnr_max_mode == "per_frame_data_range":
                peak = frame_range if frame_range > 0 else global_range
            elif psnr_max_mode == "global_data_range":
                peak = global_range
            else:  # "global_orig_max"
                peak = global_peak

            psnr_map[i, j] = np.inf if rmse == 0 else 20.0 * np.log10(peak / rmse)

            # ---- SSIM data_range choice ----
            if ssim_range_mode == "per_frame_data_range" and frame_range > 0:
                dr = frame_range
            else:
                dr = global_range

            ssim_map[i, j] = float(ssim(o, n, data_range=dr))

    return psnr_map, ssim_map

In [None]:
psnr_map, ssim_map = psnr_ssim_maps_per_scan(
    datacube, noisy_dc,
    psnr_max_mode="global_orig_max",
    ssim_range_mode="global_data_range",
)

# Identical frames give PSNR = +inf, which would make the mean print as inf and
# the std as nan. Summarize PSNR over finite values only and report the rest.
finite_psnr = psnr_map[np.isfinite(psnr_map)]
n_excluded = psnr_map.size - finite_psnr.size
print("Mean PSNR:", np.mean(finite_psnr), "±", np.std(finite_psnr))
if n_excluded:
    print(f"  ({n_excluded} of {psnr_map.size} scan positions have non-finite PSNR and were excluded)")
print("Mean SSIM:", np.mean(ssim_map), "±", np.std(ssim_map))