In [None]:
# Notebook parameters (default values; an automated run may replace the values in this cell).
# Example relative paths:
# inpath = 'mm1234-1/i21-157116.nxs'
# outpath = 'mm1234-1/processed/i21-157116_output.nxs'
# dark_image_file = 'mm1234-1/i21-157111.nxs'

# NeXus scan file whose detector images are processed
inpath = '/dls/i21/data/2025/cm40641-4/i21-437170.nxs'
# Output path; the ASCII spectrum is written next to it with a '.dat' extension
outpath = '/dls/i21/data/2025/cm40641-4/processed/i21-437170_processed.nxs'
# NeXus file containing the dark (background) image
dark_image_file = '/dls/i21/data/2025/cm40641-4/i21-437149.nxs'
energy_resolution = "-0.006"  # eV/px; stored as a string and converted with float() where it is used

# I21 Image Processing


Image processing notebook for I21 data analysis.

The following steps are performed:
1. Load scan image and dark image
2. Process dark image and determine its scale factor and offset
3. Subtract dark image from scan image
4. Find elastic line and fit slope
5. Rotate image
6. Fit position of elastic line in rotated image
7. Rescale spectra for energy resolution

**Run**

To run this notebook automatically after a scan, add the scannable to the scan command:
```
 scan ds 1 1 1 andor 20 processor
```
The results will automatically appear in [ispyb](https://ispyb.diamond.ac.uk) and output files will be written to the *processed* folder.

**Note**

This notebook is a work in progress and may not yet give accurate results in all cases.

To suggest improvements, please contact [dan.porter@diamond.ac.uk](mailto:dan.porter@diamond.ac.uk)


In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import h5py
from scipy.ndimage import gaussian_filter, rotate
from lmfit import Model
from lmfit.models import GaussianModel, LinearModel
from mmg_toolbox.fitting import multipeakfit
from hdfmap import create_nexus_map

### 1. Load Images

In [None]:
# hdfmap NeXus maps for the dark file and the scan file (in that order)
dark_map, scan_map = (create_nexus_map(path) for path in (dark_image_file, inpath))

In [None]:
# 1. Load detector arrays and normalise them by count time
# Equivalent raw h5py access for i21 xcam files (kept for reference):
#   images = nxs['/entry1/xcam/data'][()]
#   count_time = nxs['/entry1/instrument/xcam/count_time'][()]


def normalise_by_count_time(images, count_time):
    """
    Divide an image stack by its count time(s).

    :param images: image stack indexed [frame, row, column]
    :param count_time: a single count time applied to every frame, or one count time per frame
    :returns: float array with the same shape as ``images``
    :raises ValueError: if there are several count times but not exactly one per frame
    """
    images = np.asarray(images, dtype=float)
    times = np.asarray(count_time, dtype=float).reshape(-1)
    if times.size == 1:
        # one count time for the whole stack: every frame is kept
        return images / times[0]
    if times.size != images.shape[0]:
        # zip() over frames and times would silently drop the unmatched frames here
        raise ValueError(f"{times.size} count times given for {images.shape[0]} frames")
    # one count time per frame, broadcast over the remaining image axes
    return images / times.reshape((-1,) + (1,) * (images.ndim - 1))


def load_normalised_images(nexus_map, label):
    """
    Read the detector images and count time from a NeXus file and normalise by count time.

    :param nexus_map: map returned by create_nexus_map for the file
    :param label: 'Dark' or 'Scan', used in the printed summary
    :returns: (images divided by count time, count_time as read from the file)
    """
    with nexus_map.load_hdf() as nxs:
        detector = nexus_map.get_image_path()
        images = nxs[detector][...]
        count_time = nexus_map.eval(nxs, 'count_time')
    print(f"{label} Detector: {detector}")
    print(f"{label} image: {images.shape}")
    print(f"{label.lower()} image intensity: max: {images.max()}, min: {images.min()}")
    print(f"count_time: {count_time}")
    return normalise_by_count_time(images, count_time), count_time


dark_map = create_nexus_map(dark_image_file)
scan_map = create_nexus_map(inpath)

dark_image, dark_count_time = load_normalised_images(dark_map, 'Dark')
print()
scan_image, scan_count_time = load_normalised_images(scan_map, 'Scan')

# Colour-scale limits from the first scan frame:
# lower limit = frame minimum, upper limit = minimum + half of the frame maximum.
first_scan = scan_image[0]
first_dark = dark_image[0]
cmin = np.min(first_scan)
cmax = cmin + np.max(first_scan) / 2
print(f"clim = {cmin}, {cmax}")

# Dark frame, scan frame and their row-averaged profiles side by side.
# The dark frame's colour scale is shifted up by the dark frame's own minimum.
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=[16, 4], dpi=80)
dark_floor = np.min(first_dark)
ax1.imshow(first_dark, vmin=cmin + dark_floor, vmax=cmax + dark_floor)
ax1.set_title('Dark Image')
ax2.imshow(first_scan, vmin=cmin, vmax=cmax)
ax2.set_title('Scan Image')
for frame, frame_name in ((first_dark, 'Dark Image'), (first_scan, 'Scan Image')):
    ax3.plot(frame.mean(axis=1), label=frame_name)
ax3.legend()
ax3.set_title('Average pixels')

### 2.1. Remove outliers from Dark image

In [None]:
# 2. Remove outliers

# Based on removeBlips2D in DAWN's SubtractFittedBackgroundOperation
# (called by findDarkDataScaleAndOffset < getImage):
# https://github.com/DawnScience/scisoft-core/blob/0dbc6fda20bdd51b3059d901b4e0789173807663/uk.ac.diamond.scisoft.analysis.processing/src/uk/ac/diamond/scisoft/analysis/processing/operations/backgroundsubtraction/SubtractFittedBackgroundOperation.java#L801


def _replace_blip_runs(flat, threshold, fallback):
    """
    Replace runs of consecutive outliers in a 1D array, in place.

    A value is an outlier when abs(value) >= threshold. Each run of consecutive
    outliers is replaced by the mean of the values just before and just after the
    run. A run touching either end of the array uses only its one existing
    neighbour; a run with no neighbours (every value is an outlier) is filled
    with ``fallback``.

    :param flat: 1D numpy array, modified in place
    :param threshold: outlier threshold applied to abs(value)
    :param fallback: fill value for a run without any non-outlier neighbour
    :returns: number of runs replaced
    """
    mask = np.abs(flat) >= threshold
    if not mask.any():
        return 0
    # +1 / -1 steps of the zero-padded mask mark the run starts / ends
    edges = np.flatnonzero(np.diff(np.concatenate(([0], mask.astype(np.int8), [0]))))
    starts, ends = edges[0::2], edges[1::2]  # each run is flat[start:end]
    has_prev = starts > 0
    has_next = ends < flat.size
    # Runs are maximal, so an existing neighbour is never an outlier itself;
    # neighbours are read before any value is replaced.
    prev_vals = flat[np.maximum(starts - 1, 0)].astype(float)
    next_vals = flat[np.minimum(ends, flat.size - 1)].astype(float)
    fill = np.where(
        has_prev & has_next,
        0.5 * (prev_vals + next_vals),
        np.where(has_prev, prev_vals, np.where(has_next, next_vals, fallback)),
    )
    # mask is True exactly on the runs, in order, so repeat each fill over its run length
    flat[mask] = np.repeat(fill, ends - starts)
    return len(starts)


def remove_blips_2d(data, plot=False, border=10):
    """
    Remove 'blips' (outlier pixels, e.g. cosmic rays) from the first frame of an image stack.

    The first frame, cropped by ``border`` pixels on every side, is histogrammed and
    a Gaussian is fitted to the histogram peak. Pixels with
    abs(value) >= centre + 2 * FWHM are outliers. Each run of consecutive outliers
    (in row-major order) is replaced by the mean of its neighbouring pixels.

    :param data: image stack indexed [frame, row, column]; it is NOT modified
    :param plot: if True, plot the histogram, the fit region and the Gaussian fit
    :param border: pixels excluded at each edge of the frame (default 10)
    :returns: a copy of ``data`` with blips removed from the cropped part of frame 0;
              other frames and the border pixels are unchanged
    """
    cleaned = 1 * data  # copy with the same dtype
    # Step 1: crop detector sides
    rows = slice(border, data.shape[1] - border)
    cols = slice(border, data.shape[2] - border)
    frame = cleaned[0, rows, cols]

    # Step 2: histogram of the cropped frame, roughly one bin per intensity unit
    nbins = max(min(int(frame.max()), 1024 * 1024), 1)  # np.histogram needs at least one bin
    counts, bin_edges = np.histogram(frame, bins=nbins)
    bin_centres = (bin_edges[:-1] + bin_edges[1:]) / 2

    # Step 2.5: fit region from one bin left of the histogram peak up to the middle of
    # the histogram range, for a reliable Gaussian fit of the intensity distribution.
    # Clamped so a peak in the first bin, or beyond the middle of the range, still
    # leaves a non-empty fit region.
    hist_max_idx = int(np.argmax(counts))
    fit_start = max(hist_max_idx - 1, 0)
    fit_stop = max(len(counts) // 2, fit_start + 3)
    fit_counts = counts[fit_start:fit_stop]
    fit_centres = bin_centres[fit_start:fit_stop]

    # Step 3: fit a Gaussian, with its centre constrained to the bins around the peak
    gmodel = GaussianModel()
    params = gmodel.guess(counts, x=bin_centres)
    upper_idx = min(hist_max_idx + 1, len(bin_centres) - 1)
    params['center'].set(min=bin_centres[fit_start] - 0.5, max=bin_centres[upper_idx])
    result = gmodel.fit(fit_counts, params, x=fit_centres)
    centre = result.best_values['center']
    fwhm = result.values['fwhm']
    print(f"Centre: {centre}\nFWHM: {fwhm}")

    if plot:
        plt.figure()
        plt.plot(bin_centres, counts, label='Dark image histogram')
        plt.plot(fit_centres, fit_counts, '--', label='Fit region')
        plt.plot(bin_centres, result.eval(x=bin_centres), 'r-', label='Gaussian fit')
        plt.xlim([0, 1000])
        plt.xlabel('Dark image intensity')
        plt.ylabel('Counts')
        plt.legend()

    # Step 4: Compute threshold
    thr = centre + 2 * fwhm
    print(f"Blip threshold: {thr}")

    # Step 5: replace blips, scanning the cropped frame in row-major order.
    # A run with no valid neighbour falls back to the fitted background centre.
    flat = frame.flatten()  # always a copy
    n_removed = _replace_blip_runs(flat, thr, fallback=centre)
    print(f"Outliers removed: {n_removed}")
    cleaned[0, rows, cols] = flat.reshape(frame.shape)
    return cleaned

# Dark-image statistics before and after blip removal (argmax is over the whole stack)
first_dark = dark_image[0]
peak_before = np.unravel_index(np.argmax(dark_image), dark_image.shape)
print(f"dark image intensity: max: {first_dark.max()}, min: {first_dark.min()}, argmax: {peak_before}")
dark_image_rm_blips = remove_blips_2d(dark_image)
peak_after = np.unravel_index(np.argmax(dark_image_rm_blips), dark_image.shape)
print('\nafter remove_blips_2d')
print(f'Dark image: {dark_image_rm_blips.shape}')
print(f"dark image intensity: max: {dark_image_rm_blips.max()}, min: {dark_image_rm_blips.min()}, argmax: {peak_after}")

# Compare the first dark frame before and after, on one shared colour scale
dark_floor = np.min(first_dark)
dark_clim = {'vmin': cmin + dark_floor, 'vmax': cmax + dark_floor}
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=[12, 4], dpi=80)
ax1.imshow(first_dark, **dark_clim)
ax1.set_title('Dark Image')
ax2.imshow(dark_image_rm_blips[0], **dark_clim)
ax2.set_title('Dark Image - After removing Cosmic rays')


### 2.2 Perform smoothing

In [None]:
# 2: Smooth the dark image
# In Java, this is a convolution of a 2D gaussian with the image.
# Scipy's gaussian_filter is a sequence of 1D convolutions,
# however the difference is minor.
smoothing_sigma = 5  # pixels; set according to needs
# Smooth only the two image axes (sigma 0 on any leading frame axis) so that the
# frames of a multi-frame dark image are not blurred into one another.
image_sigma = [0] * (dark_image_rm_blips.ndim - 2) + [smoothing_sigma] * 2
smoothed_dark = gaussian_filter(dark_image_rm_blips, sigma=image_sigma)
print(f"Smoothed dark image: {smoothed_dark.shape}")

fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=[16, 4], dpi=80)
ax1.imshow(dark_image[0], vmin=cmin+np.min(dark_image[0]), vmax=cmax+np.min(dark_image[0]))
ax1.set_title('Dark Image')
ax2.imshow(smoothed_dark[0], vmin=cmin+np.min(dark_image[0]), vmax=cmax+np.min(dark_image[0]))
ax2.set_title('Smoothed Dark Image')
ax3.plot(dark_image[0].sum(axis=1), label='Dark Image')
ax3.plot(smoothed_dark[0].sum(axis=1), label='Smoothed Dark')
ax3.legend()


### 2.3 Determine scale and offset of dark image

In [None]:
# Explore where the smoothed dark-image row profile drops: re-bin the profile,
# differentiate it, and fit a single peak to the magnitude of the differential.
step = 10  # re-binning step size (rows per bin)
y = smoothed_dark[0].mean(axis=1)[50:-50]  # row profile, 50 rows cropped at each end
n_full_bins = len(y) // step
y2 = y[:n_full_bins * step].reshape(n_full_bins, step).mean(axis=1)
x2 = step * np.arange(n_full_bins)
dy, dy2 = np.diff(y), np.diff(y2)

fig, ax = plt.subplots(1, 1)
ax.plot(y, '-')
ax.plot(x2, y2, '+')
ax2 = ax.twinx()
ax2.plot(dy)
ax2.plot(x2[:-1], dy2)
ax2.set_ylim(-1, 1)

fit_x = x2[:-1]
fit_y = np.abs(dy2)
params = dict(
    p1_center=fit_x[np.argmax(fit_y)],
    p1_fwhm=step,
    p1_height=fit_y.max(),
)
print(params)
result = multipeakfit(
    fit_x, fit_y,
    initial_parameters=params, npeaks=1,
    plot_result=True, print_result=True,
)

In [None]:
# 4. Determine the scale and offset of the dark image at the drop-point
def findDarkDataScaleAndOffset(image, smoothed_dark, raw_dark=None,
                               drop_fwhm=100, user_width_param=5) -> tuple[float, float]:
    """
    Fit the scan image as a linear function of the smoothed dark image near its drop.

    The row profile of the first smoothed dark frame (50 rows cropped at each end)
    is re-binned in steps of 10 rows and its steepest drop is located. Rows within
    ``user_width_param * drop_fwhm`` of the drop are used to fit
    image = scale * smoothed_dark + offset pixel by pixel. Blips are removed from
    the image with remove_blips_2d before the fit.

    :param image: scan image stack indexed [frame, row, column]; frame 0 is used
    :param smoothed_dark: smoothed dark image stack indexed [frame, row, column]
    :param raw_dark: optional un-smoothed dark image stack, only used in the diagnostic plot
    :param drop_fwhm: estimated FWHM of the drop, in rows (default 100)
    :param user_width_param: half-width of the fit window in units of drop_fwhm (default 5)
    :returns: (scale, offset) from the linear fit
    :raises ValueError: if the fit window contains no rows
    """
    crop = 50  # rows excluded at each end of the profile
    profile = np.mean(smoothed_dark[0], axis=1)[crop:-crop]

    # Re-bin the profile; the most negative differential marks the drop position
    step = 10  # re-binning step size
    n_bins = len(profile) // step
    binned = profile[:n_bins * step].reshape(n_bins, step).mean(axis=1)
    drop_idx = int(np.argmin(np.diff(binned))) * step

    # drop_fwhm is a fixed estimate here; in Java it is the distance between the
    # cross-points where the differential is half its minimum. The fit window
    # extends user_width_param * drop_fwhm on each side of the drop (the Java
    # model's factor is 5 by default).
    half_width = int(user_width_param * drop_fwhm)
    # Convert to image rows and clamp: a negative start index would wrap around to
    # the end of the image and select the wrong (usually empty) set of rows.
    row_lo = max(drop_idx - half_width + crop, 0)
    row_hi = min(drop_idx + half_width + crop, smoothed_dark.shape[1])
    if row_hi <= row_lo:
        raise ValueError(f"Empty fit window: rows {row_lo}:{row_hi}")
    clippedBackground = smoothed_dark[0, row_lo:row_hi, :]
    clippedImage = image[0, row_lo:row_hi, :]  # == "in"
    offset = clippedImage.mean() - clippedBackground.mean()
    print('Guessed Offset = ', offset)

    # Diagnostic plot: dark profiles with the drop position and the differential
    fig, ax = plt.subplots()
    if raw_dark is not None:
        ax.plot(np.mean(raw_dark[0], axis=1)[crop:-crop], label='dark image profile')
    ax.plot(profile, label='smoothed dark profile')
    ax.axvline(drop_idx, c='k', label='drop')
    ax.set_title('dark image drop')
    ax.legend()
    ax_diff = ax.twinx()
    ax_diff.plot(np.diff(profile), c='c', label='diff')
    ax_diff.set_ylim([-5, 5])
    ax_diff.set_ylabel('diff', color='c')

    # Fit a line to the image intensity vs background intensity;
    # the image used here has blips removed (remove_blips_2d skips the edges)
    xvals = clippedBackground.ravel()
    yvals = remove_blips_2d(image)[0, row_lo:row_hi, :].ravel()

    model = LinearModel()
    result = model.fit(yvals, x=xvals, slope=1, intercept=offset)
    scale = result.params['slope'].value
    offset = result.params['intercept'].value
    print(f"scale: {scale:.3g}, offset: {offset:.3g}")

    fig, ax = plt.subplots()
    ax.plot(xvals, yvals, '.', label='Intensities')
    x_ends = np.array([np.min(xvals), np.max(xvals)])
    ax.plot(x_ends, result.eval(x=x_ends), '-r', label='fit')
    ax.set_xlabel('Background intensity')
    ax.set_ylabel('Image intensity')
    ax.set_title('Background scaling fit')
    ax.legend()
    return scale, offset


scaleOffset = findDarkDataScaleAndOffset(scan_image, smoothed_dark, raw_dark=dark_image)

### 3. Subtract scaled dark image

In [None]:
# 5. Subtract the scaled dark image from the scan image.
# Dark model = scale * smoothed_dark + offset, using the fitted (scale, offset).
scale, offset = scaleOffset
subtractCrop = scale * smoothed_dark + offset
subtracted = scan_image[0] - subtractCrop[0]

# Row profiles, ignoring 50 rows at each edge
profile_rows = slice(50, -50)
profile = scan_image[0].mean(axis=1)[profile_rows]
darkFit = subtractCrop[0].mean(axis=1)[profile_rows]
sub_profile = subtracted.mean(axis=1)[profile_rows]
scmin = 0
scmax = subtracted.max() / 10

fig, axes = plt.subplots(2, 2, figsize=[12, 8], dpi=100)
top_left, top_right = axes[0]
bottom_left, bottom_right = axes[1]
top_left.imshow(scan_image[0], vmin=cmin, vmax=cmax)
top_left.set_title('Scan Image')
top_right.imshow(subtracted, vmin=scmin, vmax=scmax)
top_right.set_title('Scan Image - Scaled Dark Image')
bottom_left.plot(profile, label='scan image')
bottom_left.plot(darkFit, label='dark image')
bottom_left.legend()
bottom_right.plot(sub_profile, label='scan image - scaled dark image')
bottom_right.legend()

### 4.1 Find elastic peak region

In [None]:
# 4. Find Region of Interest


def _peak_window(result, length, region):
    """
    Index window (start, stop) covering the first fitted peak: centre -/+ FWHM.

    The bounds are clamped to [0, length] so that a wide peak, or one near the edge,
    cannot give a negative start index (which numpy slicing would wrap around).

    :param result: multipeakfit result
    :param length: size of the axis being windowed
    :param region: name used in error messages, e.g. 'x-region'
    :raises Exception: if no peak was fitted or the clamped window is empty
    """
    if not hasattr(result, 'p1_center'):
        raise Exception(f'No Peaks found in {region}')
    start = max(int(result.p1_center - result.p1_fwhm), 0)
    stop = min(int(result.p1_center + result.p1_fwhm), length)
    if stop <= start:
        raise Exception(f'Empty ROI in {region}: {start}:{stop}')
    return start, stop


# 4.1 remove edges
edge = 10  # detector border excluded from the ROI search (pixels)
roi = subtracted[edge:-edge, edge:-edge]

# 4.2 fit the column-averaged intensity to determine the x-region
result_x = multipeakfit(np.arange(roi.shape[1]), roi.mean(axis=0), npeaks=1, print_result=False, plot_result=True)
x_roimin, x_roimax = _peak_window(result_x, roi.shape[1], 'x-region')
x_peak = result_x.p1_center  # column in the edge-cropped image (add `edge` for full-image columns)

roi = roi[:, x_roimin:x_roimax]
xval, yval = np.arange(roi.shape[0]), roi.mean(axis=1)

plt.figure()
plt.imshow(roi, vmin=scmin, vmax=scmax)
plt.plot(yval, xval, 'w-')  # row profile overlaid on the image
ax = plt.gca()

# 4.3 fit the row-averaged intensity to determine the y-region
result_y = multipeakfit(xval, yval, print_result=False, plot_result=True)
y_roimin, y_roimax = _peak_window(result_y, roi.shape[0], 'y-region')
ax.axhline(result_y.p1_center, c='r')

roi = roi[y_roimin:y_roimax, :]
xval, yval = np.arange(roi.shape[0]), roi.mean(axis=1)

plt.figure()
plt.imshow(roi, vmin=scmin, vmax=scmax)
plt.plot(yval, xval, 'w-')

# offsets that convert ROI indices back to indices in the full `subtracted` image
x_roi_offset = edge + x_roimin
y_roi_offset = edge + y_roimin
print(f"ROI offsets: {x_roi_offset}, {y_roi_offset}")

### 4.2 Fit straight line against max pixels in region

In [None]:
# 4. Determine the elastic-line position across x: for each bin of `step` columns,
# take the row with the largest summed intensity.
step = 100  # column bin width (pixels); also used by later cells
argmaxes = []
for x_start in range(0, roi.shape[1] - step + 1, step):  # complete bins only
    column_profile = roi[:, x_start:x_start + step].sum(axis=1)
    # [bin centre column, row of maximum, maximum summed intensity]
    argmaxes.append([x_start + step / 2, np.argmax(column_profile), column_profile.max()])
if len(argmaxes) < 2:
    raise ValueError(
        f'ROI is {roi.shape[1]} px wide: at least 2 bins of {step} px are needed to fit the line slope'
    )
argmaxes = np.array(argmaxes)

# Fit a straight line through the max-pixel locations, weighted by their intensity
model = LinearModel()
pars = model.guess(argmaxes[:, 1], argmaxes[:, 0])
linefit = model.fit(argmaxes[:, 1], x=argmaxes[:, 0], params=pars, weights=argmaxes[:, 2])

plt.figure()
plt.imshow(roi, vmin=scmin, vmax=scmax)
plt.plot(argmaxes[:, 0], argmaxes[:, 1], 'y+')
plt.plot(argmaxes[:, 0], linefit.best_fit, 'w-')

In [None]:
# Check the fitted line on the full background-subtracted image.
# x_pos: x-peak column in full-image coordinates (the ROI search excluded a 10 px border)
x_pos = int(result_x.p1_center) + 10
y_pos = linefit.eval(x=x_pos - x_roi_offset) + y_roi_offset  # fitted line row at x_pos

line_x = argmaxes[:, 0] + x_roi_offset
line_y = linefit.best_fit + y_roi_offset
column_mean = subtracted[:, x_pos - step//2: x_pos + step//2].mean(axis=1)

fig, axes = plt.subplots(2, 2, figsize=[12, 8], dpi=100)
fig.suptitle('Check Elastic Line Fit')
for image_ax in axes[0]:
    image_ax.imshow(subtracted, vmin=scmin, vmax=scmax)
    image_ax.plot(line_x, line_y, 'w-', lw=0.5)
axes[0, 1].set_xlim(x_pos - 500, x_pos + 100)
axes[0, 1].set_ylim(y_pos + 100, y_pos - 100)

for profile_ax in axes[1]:
    profile_ax.plot(column_mean)
    profile_ax.axvline(y_pos, c='r', label='bin')
axes[1, 1].set_xlim(y_pos - 100, y_pos + 100)

In [None]:
# Fit Results
# intercept: fitted line row in the full image, at image column x_roi_offset (ROI column 0)
slope = linefit.params['slope'].value
intercept = linefit.params['intercept'].value + y_roi_offset


def _format_with_error(value, stderr):
    """Format 'value +/- stderr'; lmfit's stderr can be None when uncertainties were not estimated."""
    if stderr is None:
        return f"{value:.3g} +/- (not estimated)"
    return f"{value:.3g} +/- {stderr:.2g}"


print("Line Fit Results:")
print(f"Slope: {_format_with_error(slope, linefit.params['slope'].stderr)}")
print(f"Intercept: {_format_with_error(intercept, linefit.params['intercept'].stderr)}")

### 5. Rotate image

In [None]:
# 5. Rotate the detector image by the slope angle so the elastic line becomes horizontal
angle = np.arctan(linefit.params['slope'].value)
print(f"Slope angle = {angle:.3g} rad = {np.rad2deg(angle):.3g} deg")
rotated = rotate(subtracted, angle=np.rad2deg(angle))  # spline interpolation; output enlarged to fit

# Row of the elastic line in the rotated image.
# `intercept` is the fitted line row at ROI column 0, i.e. at image column x_roi_offset,
# so (x_roi_offset, intercept) lies on the line. Every point on the line maps to the
# same row after rotating by arctan(slope), so rotating this point gives the line row.
# scipy.ndimage.rotate rotates about the array centre at (shape - 1) / 2.
x, y = (
    x_roi_offset - (subtracted.shape[1] - 1) / 2,
    intercept - (subtracted.shape[0] - 1) / 2,
)
# rotate the point (same convention as scipy.ndimage.rotate)
rx, ry = (x * np.cos(angle) + y * np.sin(angle), -x * np.sin(angle) + y * np.cos(angle))
# add the rotated image centre
rotated_intercept = ry + (rotated.shape[0] - 1) / 2
print('Rotated intercept:', rotated_intercept)


fig, axes = plt.subplots(2, 2, figsize=[12, 8], dpi=100)
fig.suptitle('Check Image Rotation')
axes[0, 0].imshow(subtracted, vmin=scmin, vmax=scmax)
axes[0, 0].axhline(intercept, c='r', lw=0.5)
axes[0, 1].imshow(rotated, vmin=scmin, vmax=scmax)
axes[0, 1].axhline(rotated_intercept, c='r', lw=0.5)

# Profiles over the same column indices (approximate for the wider rotated image)
axes[1, 0].plot(subtracted[:, x_pos - step//2: x_pos + step//2].mean(axis=1), label='Image - bkg')
axes[1, 0].plot(rotated[:, x_pos - step//2: x_pos + step//2].mean(axis=1), label='Rotated image')
axes[1, 0].axvline(intercept, c='r', label='Elastic line')
axes[1, 0].axvline(rotated_intercept, c='r', label='Elastic line (rotated)')

axes[1, 1].plot(subtracted[:, x_pos - step//2: x_pos + step//2].mean(axis=1), label='Image - bkg')
axes[1, 1].plot(rotated[:, x_pos - step//2: x_pos + step//2].mean(axis=1), label='Rotated image')
axes[1, 1].axvline(intercept, c='r', label='Elastic line')
axes[1, 1].axvline(rotated_intercept, c='r', label='Elastic line (rotated)')
axes[1, 1].set_xlim(y_pos - 100, y_pos + 100)

### 6. Find Elastic line

In [None]:
# 6. Fit the elastic line position in the rotated image.
# After rotation the elastic line is horizontal at row `rotated_intercept`, so the
# fit window is centred there and keeps the height of the y-region from section 4.
half_height = max((y_roimax - y_roimin) // 2, 1)
centre_row = int(round(rotated_intercept))
row_lo = max(centre_row - half_height, 0)
row_hi = min(centre_row + half_height, rotated.shape[0])
# rotate() enlarges the image, approximately centring the original columns in the
# wider output; x_roi_offset converts the ROI columns back to full-image columns.
col_shift = (rotated.shape[1] - subtracted.shape[1]) // 2
col_lo = max(x_roi_offset + col_shift, 0)
col_hi = min(col_lo + (x_roimax - x_roimin), rotated.shape[1])
if row_hi <= row_lo or col_hi <= col_lo:
    raise Exception(f'Empty rotated ROI: rows {row_lo}:{row_hi}, columns {col_lo}:{col_hi}')
print(intercept, rotated_intercept)
print(f"Rotated ROI: rows {row_lo}:{row_hi}, columns {col_lo}:{col_hi}")
rotated_roi = rotated[row_lo:row_hi, col_lo:col_hi]
yval = np.mean(rotated_roi, axis=1)
xval = np.arange(len(yval))

plt.imshow(rotated_roi, vmin=scmin, vmax=scmax)

result_r = multipeakfit(xval, yval, print_result=False, plot_result=True)
if not hasattr(result_r, 'p1_center'):
    raise Exception('No Peaks found in y-region')
# add back exactly the integer row offset that was used to cut rotated_roi
elastic_intercept = result_r.p1_center + row_lo

fig, axes = plt.subplots(1, 2, figsize=[12, 4], dpi=100)
fig.suptitle('Check Fit Elastic line')
axes[0].imshow(rotated, vmin=scmin, vmax=scmax)
axes[0].axhline(elastic_intercept, c='r', lw=0.5)
axes[1].plot(rotated.mean(axis=1), label='Rotated image')
axes[1].axvline(elastic_intercept, c='r', label='Elastic line')
axes[1].legend()

### 7. Scale by energy resolution

In [None]:
# 7. Convert rotated-image rows to energy and sum each row into a spectrum.
# Row `elastic_intercept` is zero energy; each row step is `energy_resolution` eV.
ev_per_pixel = float(energy_resolution)
pixel_rows = np.arange(rotated.shape[0])
energy = (pixel_rows - elastic_intercept) * ev_per_pixel
spectra = np.sum(rotated, axis=1)

fig, ax = plt.subplots()
ax.plot(energy, spectra)
ax.set_xlabel('energy [eV]')
ax.set_ylabel('Intensity')
ax.set_title(os.path.basename(inpath))
fig.savefig('/tmp/result.png')

### 8. Write output files

In [None]:
# Write nexus file
# TODO: NeXus output is not implemented yet.

# Write ascii file: two columns (energy, intensity), one line per rotated-image row.
# np.savetxt writes each row of a 2D array as one line, so the two 1D arrays are
# stacked as columns; passing the tuple directly would write two very long lines.
dat_path = os.path.splitext(outpath)[0] + '.dat'
header = f"{inpath}\nProcessed by 'i21_Image_Processing.ipynb'\n energy [eV], intensity"
table = np.column_stack((energy, spectra))
np.savetxt('/tmp/result.dat', table, header=header)
os.makedirs(os.path.dirname(dat_path) or '.', exist_ok=True)
np.savetxt(dat_path, table, header=header)
if os.path.isfile(dat_path):
    # path held in a variable: reusing the f-string's own quotes inside it is a
    # SyntaxError before Python 3.12
    print(f'Saved file: {dat_path}')