# In [None]:
import os
from pathlib import Path

import h5py
import matplotlib.pyplot as plt
import numpy as np
from scipy.io import savemat, whosmat
from scipy.ndimage import gaussian_filter, median_filter
from skimage.exposure import rescale_intensity
from skimage.transform import resize

# -------------------------------------------------
# Input location
# -------------------------------------------------
# The experiment stamp and scan name select one HDF5 file on disk; the scan
# name is also the dataset name inside the file.
expNum = "07292025"
scanName = "Scan_4"
folderPath = r"Users\Lab User\Desktop\experiment data"
filePath = fr"C:\{folderPath}\{expNum}\{scanName}.hdf5"
dataFolder = "RawData"
dataPath = f"/{dataFolder}/{scanName}"

# -------------------------------------------------
# Read the volume and crop the region of interest
# -------------------------------------------------
with h5py.File(filePath, "r") as h5File:
    # [()] reads the entire dataset into memory before the file is closed.
    data = h5File[dataPath][()]

# Swap the first two axes so the array is indexed (y, x, z).
data = data.transpose(1, 0, 2)
print(f"Original: {data.shape} [y,x,z]")

# ROI bounds are inclusive on both ends, hence the +1 on the slice stops.
yMin, yMax = 160, 1080
xMin, xMax = 66, 956
roi = (slice(yMin, yMax + 1), slice(xMin, xMax + 1), slice(None))
dataCropped = data[roi].astype(np.float64)
yDim, xDim, zDim = dataCropped.shape
print(f"Cropped: {dataCropped.shape}")

# -------------------------------------------------
# Find extremes & Median filter salt
# -------------------------------------------------
# dataCropped is modified in place by both passes below.
highPct, medPct = 99, 95

# Pass 1: voxels brighter than the highPct-th percentile take the value of the
# median-filtered volume (window of 3 along every axis) at the same position.
threshHigh = np.percentile(dataCropped, highPct)
filteredData = median_filter(dataCropped, size=3)
mask = dataCropped > threshHigh
dataCropped[mask] = filteredData[mask]

# Pass 2: recompute the threshold on the cleaned volume and flatten whatever
# still exceeds it to the medPct-th percentile value.
threshHigh = np.percentile(dataCropped, highPct)
# np.percentile with a scalar percentile already returns a single value, so
# it is used directly instead of being passed through np.mean.
meanMed = np.percentile(dataCropped, medPct)
dataCropped[dataCropped > threshHigh] = meanMed
print("Applied median filter")

# -------------------------------------------------
# Normalize
# -------------------------------------------------
# Stretch the whole volume to [0, 1]; the clip only guards against round-off.
dataNorm = rescale_intensity(dataCropped, out_range=(0, 1))
dataNorm = np.clip(dataNorm, 0, 1)

# Column-wise normalization
# meanPerCol holds one mean per index along axis 0 (averaged over axes 1 and
# 2). Each of those slices is scaled so its mean matches the central slice.
meanPerCol = dataNorm.mean(axis=(1, 2))
refMean = meanPerCol[yDim // 2]
# A slice whose mean is 0 is entirely 0 (dataNorm is non-negative). Dividing
# by it would give inf/NaN, and that NaN would survive the clip and poison the
# min/max used by the final rescale for the whole volume. Such a slice stays
# zero under any gain, so its factor is left at 1.
scaleFactors = np.ones_like(meanPerCol)
np.divide(refMean, meanPerCol, out=scaleFactors, where=meanPerCol > 0)
dataColNorm = dataNorm * scaleFactors[:, None, None]
dataColNorm = np.clip(dataColNorm, 0, 1)
dataReNorm = rescale_intensity(dataColNorm, out_range=(0, 1))

# -------------------------------------------------
# Sharpen image
# -------------------------------------------------
# Unsharp masking: add back a fraction of the detail (image minus its blurred
# copy). Subtracting that term instead would pull the image toward the blurred
# copy, i.e. soften it. The blur is taken from the same array that is being
# sharpened so the detail term is consistent.
strength = 0.2
blurred = gaussian_filter(dataReNorm, sigma=3)
dataSharp = dataReNorm + strength * (dataReNorm - blurred)
# Overshoot at edges can leave [0, 1]; clamp back into range.
dataSharp = np.clip(dataSharp, 0, 1)

# -------------------------------------------------
# Diagnostic plots for the preprocessing stages
# -------------------------------------------------
# Intensity histogram of the globally normalized volume.
histFig, histAx = plt.subplots()
histAx.hist(dataNorm.ravel(), bins=100)
histAx.set_xlabel("Voxel intensity")
histAx.set_ylabel("Number of voxels")
histAx.set_title("Histogram of preprocessed voxel values")
histAx.grid(True)
plt.show()

# Central z-slice of every intermediate volume, in processing order.
midz = zDim // 2
fig, axs = plt.subplots(2, 3, figsize=(12, 8))
panels = [
    ("Crop", dataCropped),
    ("Norm", dataNorm),
    ("ColNorm", dataColNorm),
    ("Renorm", dataReNorm),
    ("Sharp", dataSharp),
]
# axs.flat walks the grid row by row; zip stops after the five panels, so the
# sixth axes is left empty.
for panelAx, (label, volume) in zip(axs.flat, panels):
    panelAx.imshow(volume[:, :, midz], cmap="gray")
    panelAx.set_title(label)
plt.show()

# -------------------------------------------------
# Downsample for speed
# -------------------------------------------------
# Every axis is scaled by the same factor; the target shape is rounded to
# whole voxels.
scaleFactor = 0.5
fullShape = np.array(dataSharp.shape)
new_shape = np.round(fullShape * scaleFactor).astype(int)

# Linear interpolation (order=1); preserve_range keeps the values on the
# input's intensity scale.
dataRescale = resize(
    dataSharp,
    output_shape=new_shape,
    order=1,
    preserve_range=True,
    anti_aliasing=True,
)
Ny, Nx, Nz = dataRescale.shape
print(f"Rescaled: {dataRescale.shape}")

# -------------------------------------------------
# Particle tracking parameters
# -------------------------------------------------
# Lengths are multiplied by scaleFactor so they apply to the downsampled
# volume.
D = 100 * scaleFactor         # initial particle diameter
w = 2.5 * scaleFactor         # initial width parameter
Cutoff = 5                    # passed to findpeaks3D
MinSep = 75 * scaleFactor     # passed to findpeaks3D
eps = 1e-12                   # keeps 1 / chi3D finite where chi3D is 0

# Setup ideal particle grid
# ss is the (odd) edge length of the cubic template grid; halfSpan is the
# number of voxels on either side of its centre. The name `os` is deliberately
# not used for this offset: rebinding it would hide the imported os module
# from the rest of the script.
ss = int(2 * (D // 2 + 2 * (w // 2)) - 1)
halfSpan = (ss - 1) // 2
offsets = np.arange(-halfSpan, halfSpan + 1)
xx, yy, zz = np.meshgrid(offsets, offsets, offsets, indexing="ij")
r = np.sqrt(xx**2 + yy**2 + zz**2)  # distance of each grid voxel from the centre

# -------------------------------------------------
# Custom particle-tracking functions
# -------------------------------------------------
# ipf3D, chiimg3D_FFT, findpeaks3D, pgrid3D and cidp23D are not defined in
# this section and must already be in scope.
ipi = ipf3D(r, D, w)  # template evaluated on the radial grid
chi3D, Wip2 = chiimg3D_FFT(dataRescale, ipi)
# Peaks are searched in the reciprocal of chi3D, so small chi3D values become
# large ones.
Np, px, py, pz = findpeaks3D(1.0 / (chi3D + eps), np.ones_like(chi3D), Cutoff, MinSep)
print(f"Found {Np} initial peaks")
szIpi = ipi.shape

plt.figure()
plt.imshow(1.0 / (chi3D[:, :, Nz // 2] + eps), cmap="gray")
plt.title("Chi3D")
# NOTE(review): py is drawn on the horizontal axis and px on the vertical one;
# confirm this matches the coordinate convention of findpeaks3D.
plt.plot(py, px, "r.")
plt.show()

print("      Temp | Data")
print(f"Min: ({ipi.min():.2f}, {dataRescale.min():.2f})")
print(f"Max: ({ipi.max():.2f}, {dataRescale.max():.2f})")
print(f"Mean:({ipi.mean():.2f}, {dataRescale.mean():.2f})")


# -------------------------------------------------
# Sub-voxel chi-squared minimization
# -------------------------------------------------
def _modelResidual(px, py, pz):
    """Build the model volume for the given centres and its residual.

    Returns ``(cxyz, over, r_full, ci, di)`` where ``cxyz`` and ``over`` come
    from ``pgrid3D``, ``r_full`` is the radial distance derived from ``cxyz``
    with its first two axes swapped to match ``dataRescale``, ``ci`` is
    ``ipf3D`` evaluated on ``r_full`` and ``di = ci - dataRescale``.
    """
    cxyz, over = pgrid3D(px, py, pz, Nx, Ny, Nz, Np, halfSpan, 0)
    r_full = np.sqrt(cxyz.x**2 + cxyz.y**2 + cxyz.z**2)
    r_full = np.transpose(r_full, (1, 0, 2))  # match dataRescale
    ci = ipf3D(r_full, D, w)
    return cxyz, over, r_full, ci, ci - dataRescale


cxyz, over, r_full, ci, di = _modelResidual(px, py, pz)
chi2 = np.sum(di**2)

print(f"Initial Error={chi2:.2e}")
print("      Ci | Data")
print(f"Min: ({ci.min():.2f},{dataRescale.min():.2f})")
print(f"Max: ({ci.max():.2f},{dataRescale.max():.2f})")
print(f"Overall Delta: {np.linalg.norm(di):.2f}")

# -------------------------------------------------
# Optimize particle positions
# -------------------------------------------------
# Iterate until chi-squared changes by no more than mindelchi2 between two
# rounds, or maxnr rounds have been run.
nr, delchi2 = 0, 1e99
mindelchi2, maxnr = 1, 5

while abs(delchi2) > mindelchi2 and nr < maxnr:
    # Position updates computed by cidp23D from the current residual.
    dpx, dpy, dpz = cidp23D(cxyz, over, di, Np, D, w)
    px, py, pz = px + dpx, py + dpy, pz + dpz

    cxyz, over, r_full, ci, di = _modelResidual(px, py, pz)

    # Store the freshly computed sum directly rather than subtracting the
    # delta from the old value, which would accumulate round-off.
    newChi2 = np.sum(di**2)
    delchi2 = chi2 - newChi2
    chi2 = newChi2

    print(".", end="")
    nr += 1
print(f"\nFinal Chi-Squared={chi2:.2e}")

# -------------------------------------------------
# Residual map: maximum squared residual along z
# -------------------------------------------------
residualMap = (di**2).max(axis=2)
plt.figure()
plt.imshow(residualMap, cmap="gray")
plt.axis("equal")
plt.title(f"Residual chi-squared (max projection), Chi^2={chi2:6.2f}")
plt.show()

# -------------------------------------------------
# 3D scatter of the particle centres
# -------------------------------------------------
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1, projection="3d")
ax.scatter(px, py, pz, c="r", s=50, marker="o")

# Axis limits run from 1 to the size of the downsampled volume.
ax.set_xlim([1, Nx])
ax.set_ylim([1, Ny])
ax.set_zlim([1, Nz])

ax.set_xlabel("X")
ax.set_ylabel("Y")
ax.set_zlabel("Z")
ax.set_title("Detected particle centers")
plt.show()

# -------------------------------------------------
# Map results back to full resolution
# -------------------------------------------------
# Peak coordinates become the columns (py, px, pz) and are divided by
# scaleFactor to undo the downsampling.
peakColumns = (py.ravel(), px.ravel(), pz.ravel())
peaks = np.column_stack(peakColumns) / scaleFactor

# The downsampled volume is resized by the inverse of the same factor.
upShape = np.round(np.array(dataRescale.shape) / scaleFactor).astype(int)
dataFinal = resize(
    dataRescale,
    output_shape=upShape,
    order=1,
    preserve_range=True,
    anti_aliasing=True,
)

# Quick sanity output.
print(dataFinal.shape)
print(dataFinal[:10, :10, :10])

print(f"Type: {type(peaks)}")
print(f"Type: {type(dataFinal)}")

print("Stuff:", dataFinal[:10, 40, 1])
print("Stuff:", dataFinal.size)

# -------------------------------------------------
# Save prediction data
# -------------------------------------------------
# Output names follow the scan being processed instead of a fixed scan number.
peakLoc = f"./DataOutput/Peaks_{scanName}.mat"
volLoc = f"./DataOutput/Volume_{scanName}.h5"

# pathlib handles the directory so this section does not rely on the name
# `os` still referring to the os module at this point in the script.
# exist_ok makes the call safe when the folder is already there.
Path("DataOutput").mkdir(parents=True, exist_ok=True)

# No explicit delete is needed: savemat rewrites its target and h5py mode "w"
# truncates an existing file.
savemat(peakLoc, {"peaks": peaks})
with h5py.File(volLoc, "w") as f:
    f.create_dataset("/dataFinal", data=dataFinal)

print("Saved predictions")

# whos equivalent (variable info)
print("dataFinal:", dataFinal.shape, dataFinal.dtype)
print("peaks:", peaks.shape, peaks.dtype)

# Show contents of saved files
# Reuse the paths the files were written to rather than repeating the
# literals, so the check always inspects what was just saved.
print(whosmat(peakLoc))
with h5py.File(volLoc, "r") as f:
    print(list(f.keys()))
