In [14]:
import os
import h5py
import numpy as np
import math

from scipy.ndimage import median_filter, gaussian_filter, zoom
from skimage import exposure
from scipy.signal import fftconvolve
from skimage.feature import peak_local_max
import matplotlib.pyplot as plt
from helpers import sech, normalize, ipf3D, convolution3D_FFTdomain, chiimg3D_FFT, findpeaks3D, pgrid3D, cidp23D

In [15]:
filePath = r"C:\Users\Lab User\Desktop\experiment data\07292025\Scan_4.hdf5"
dataPath = f'/RawData/Scan_4'

with h5py.File(filePath,'r') as f: 
    data = f[dataPath][()]
data = np.transpose(data, (1,2,0))  # (y,x,z)

# Crop ROI
yMin, yMax, xMin, xMax = 160, 1080, 66, 956
dataCropped = data[yMin:yMax, xMin:xMax, :].astype(float)

cropX, cropY, cropZ = np.shape(dataCropped)

In [16]:
highPct, medPct = 99, 95

# Apply median filter and threshold
threshHigh = np.percentile(dataCropped, highPct)
filteredData = median_filter(dataCropped, size=(3, 3, 3))
dataCropped[dataCropped > threshHigh] = filteredData[dataCropped > threshHigh]

# Update threshold and replace values above it with the median
threshHigh = np.percentile(dataCropped, highPct)
dataCropped[dataCropped > threshHigh] = np.mean(np.percentile(dataCropped, medPct))

# Normalize data to range [0, 1] and clip it
dataNorm = np.clip(exposure.rescale_intensity(dataCropped, in_range='image', out_range=(0, 1)), 0, 1)

# Column-wise normalization
meanPerCol = np.mean(dataNorm, axis=(0, 2))  # Mean across 2nd and 3rd dimensions (rows and depth)
scaleFactors = meanPerCol[int(cropZ/2)] / meanPerCol  # Normalize with the middle column
dataColNorm = np.clip(dataNorm * scaleFactors[:, np.newaxis, np.newaxis], 0, 1)

# Re-normalize after column normalization
dataReNorm = exposure.rescale_intensity(dataColNorm, in_range='image', out_range=(0, 1))

# Sharpen image
strength = 0.2
blurred = gaussian_filter(dataColNorm, sigma=3)
dataSharp = np.clip(dataReNorm - strength * (dataReNorm - blurred), 0, 1)

ValueError: operands could not be broadcast together with shapes (920,890,852) (890,1,1) 

In [None]:
# Downsample
scaleFactor = 0.5
dataRescale = zoom(dataSharp, scaleFactor, order=1)

In [None]:
# -----------------------------
# Plot intermediate steps
# -----------------------------
z_mid = dataNorm.shape[2] // 2

plt.figure(figsize=(12,4))
plt.subplot(1,3,1)
plt.imshow(dataCropped[:,:,z_mid], cmap='gray')
plt.title('Cropped')
plt.subplot(1,3,2)
plt.imshow(dataNorm[:,:,z_mid], cmap='gray')
plt.title('Normalized + Background Corrected')
plt.subplot(1,3,3)
plt.imshow(dataSharp[:,:,z_mid], cmap='gray')
plt.title('Sharpened')
plt.show()

plt.figure()
plt.hist(dataNorm.ravel(), bins=100)
plt.xlabel('Voxel intensity')
plt.ylabel('Number of voxels')
plt.title('Histogram of normalized voxel values')
plt.show()

In [None]:
# -----------------------------
# Particle detection
# -----------------------------
D, w = 100*scaleFactor, 2.5*scaleFactor
Cutoff, MinSep = 5, 75*scaleFactor
ss = int(2*(D/2 + 4*w/2)-1)
os = (ss-1)//2
xx, yy, zz = np.meshgrid(np.arange(-os,os+1), np.arange(-os,os+1), np.arange(-os,os+1), indexing='ij')
r = np.sqrt(xx**2 + yy**2 + zz**2)
ipi = ipf3D(r, D, w)
chi3D, _ = chiimg3D_FFT(dataRescale, ipi)
Np, px, py, pz = findpeaks3D(1./(chi3D+1e-12), Cutoff=Cutoff, MinSep=MinSep)

In [None]:
# Sub-voxel refinement
cxyz, over = pgrid3D(px, py, pz, dataRescale.shape[1], dataRescale.shape[0], dataRescale.shape[2], Np, os, 0)
r_full = np.sqrt(cxyz.x**2 + cxyz.y**2 + cxyz.z**2)
ci = ipf3D(r_full, D, w)
di = ci - dataRescale
chi2 = np.sum(di**2)
nr, delchi2, mindelchi2, maxnr = 0, 1e99, 1, 5
while abs(delchi2)>mindelchi2 and nr<maxnr:
    dpx, dpy, dpz = cidp23D(cxyz, over, di, Np, D, w)
    px += dpx; py += dpy; pz += dpz
    cxyz, over = pgrid3D(px, py, pz, dataRescale.shape[1], dataRescale.shape[0], dataRescale.shape[2], Np, os, 0)
    r_full = np.sqrt(cxyz.x**2 + cxyz.y**2 + cxyz.z**2)
    ci = ipf3D(r_full, D, w)
    di = ci - dataRescale
    delchi2 = chi2 - np.sum(di**2)
    chi2 -= delchi2
    nr += 1

In [None]:
# -----------------------------
# Visualization
# -----------------------------
# Max projection of residual
plt.figure()
plt.imshow(np.max(di**2, axis=2), cmap='gray', origin='lower')
plt.colorbar()
plt.title(f'Residual Chi^2 (max projection), Chi2={chi2:.2f}')
plt.show()

In [None]:
# 3D scatter of particle centers
from mpl_toolkits.mplot3d import Axes3D
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(px, py, pz, c='r', s=50)
ax.set_xlim(0, dataRescale.shape[1])
ax.set_ylim(0, dataRescale.shape[0])
ax.set_zlim(0, dataRescale.shape[2])
ax.set_xlabel('X'); ax.set_ylabel('Y'); ax.set_zlabel('Z')
ax.set_title('Detected particle centers')
plt.show()
