# EXtra-foam azimuthal integration benchmark

In [None]:
import os.path as osp

import numpy as np
from pyFAI.azimuthalIntegrator import AzimuthalIntegrator as PyfaiAzimuthalIntegrator
from scipy.signal import find_peaks
import matplotlib.pyplot as plt

import extra_foam
print(extra_foam.__version__)

from extra_foam.algorithms import AzimuthalIntegrator, ConcentricRingsFinder
from extra_foam.algorithms import mask_image_data

In [None]:
def load_image(filepath, base_dir='~'):
    """Load a detector image from a ``.npy`` file, build its mask and show it.

    Parameters
    ----------
    filepath : str
        Path of the ``.npy`` file. A relative path is resolved against
        *base_dir*; an absolute path is used as-is (``os.path.join`` semantics).
    base_dir : str, optional
        Directory that relative *filepath* values are resolved against.
        ``~`` is expanded, so the default is the user's home directory, which
        matches the previously hard-coded behaviour.

    Returns
    -------
    img : numpy.ndarray
        The loaded image after it has been passed to ``mask_image_data``.
        NOTE(review): ``mask_image_data`` may modify ``img`` in place; confirm
        against the EXtra-foam docs.
    mask : numpy.ndarray of bool
        Boolean array with the same shape as ``img``, written by
        ``mask_image_data`` through its ``out`` argument.
    """
    img = np.load(osp.join(osp.expanduser(base_dir), filepath))

    # mask_image_data fills this boolean array via `out`.
    mask = np.zeros_like(img, dtype=bool)
    mask_image_data(img, out=mask)

    # Quick visual check of the loaded data.
    _, ax = plt.subplots(figsize=(12, 12))
    ax.imshow(img)

    return img, mask

# Select the dataset to benchmark (files are looked up in the home directory).
# img, mask = load_image("jf_ring.npy")
img, mask = load_image("jf_ring_6modules.npy")
# img, mask = load_image("lpd.npy")

In [None]:
dist = 1  # sample distance (pyFAI geometry lengths are in metres)
npt = 1024  # number of integration points
# Pixel size (y, x) in metres.
# NOTE(review): 0.75e-6 m is 0.75 um, far smaller than typical detector
# pixels (e.g. 75e-6 m); confirm the intended value. Both integrators receive
# the same value, so the pyFAI / EXtra-foam comparison itself is consistent.
pixel1, pixel2 = 0.75e-6, 0.75e-6
cy, cx = 530, 1125  # initial integration centre in pixels (y, x)
# PONI (y, x) in metres. It is derived from the pixel centre so that editing
# cy/cx above keeps the two in sync (previously the literals were duplicated).
poni1, poni2 = cy * pixel1, cx * pixel2

In [None]:
# %%timeit

# Reference (ground-truth) 1D azimuthal integration with pyFAI, using the
# geometry defined above and the mask returned by load_image.
pyfai_integrator = PyfaiAzimuthalIntegrator(
    dist=dist,
    poni1=poni1,
    poni2=poni2,
    pixel1=pixel1,
    pixel2=pixel2,
    wavelength=1e-10,
)
q_gt, I_gt = pyfai_integrator.integrate1d(
    img, npt, mask=mask, unit="q_A^-1"
)

In [None]:
# %%timeit

# EXtra-foam 1D azimuthal integration with the same geometry as the pyFAI run.
# NOTE(review): unlike the pyFAI call, no mask is passed here; confirm that
# EXtra-foam handles masked pixels equivalently, otherwise the comparison is
# not like-for-like.
integrator = AzimuthalIntegrator(
    dist=dist,
    poni1=poni1,
    poni2=poni2,
    pixel1=pixel1,
    pixel2=pixel2,
    wavelength=1e-10,
)
q, I = integrator.integrate1d(
    img, npt=npt
)

In [None]:
_, ax = plt.subplots(figsize=(12, 6))

# Overlay both integration results. The 1e-10 factor rescales EXtra-foam's q
# to match pyFAI's "q_A^-1" unit (assumes EXtra-foam reports q in 1/m; confirm).
for _x, _y, _fmt, _label in (
    (1e-10 * q, I, '-', 'EXtra-foam'),
    (q_gt, I_gt, '--', 'pyFAI'),
):
    ax.plot(_x, _y, _fmt, label=_label)

ax.set_xlabel("q (1/A)", fontsize=16)
ax.set_ylabel("I (arb.)", fontsize=16)
ax.legend(fontsize=16)

In [None]:
# %%timeit

# Tuning knobs: min_count is used by the ring finder here and in the next
# cell; prominence and distance are only consumed by find_peaks in the next
# cell.
min_count, prominence, distance = 500, 100, 10

# Pixel sizes are passed in (x, y) order here: pixel2 first, then pixel1.
finder = ConcentricRingsFinder(pixel2, pixel1)

# Refine the centre starting from the current (cx, cy). This rebinds the
# global cx/cy, so re-running the cell starts from the previous result.
cx, cy = finder.search(img, cx, cy, min_count=min_count)

In [None]:
# Integrate around the refined centre. The result is stored in q_ring rather
# than q so that EXtra-foam's q (used by the comparison plot) is not clobbered
# if that cell is re-run afterwards.
q_ring, s = finder.integrate(img, cx, cy, min_count=min_count)

# find_peaks returns (peak indices, peak properties); only indices are needed.
i_peaks, _ = find_peaks(s, distance=distance, prominence=prominence)

_, ax = plt.subplots(figsize=(12, 6))

ax.plot(q_ring, s, '-')
ax.plot(q_ring[i_peaks], s[i_peaks], 'x')
# NOTE(review): the axis label assumes finder.integrate returns q in 1/A;
# confirm.
ax.set_xlabel("q (1/A)", fontsize=16)
ax.set_ylabel("I (arb.)", fontsize=16)

print(f"Optimized cx = {cx}, cy = {cy}")