# Demo 2: Benchmarking

In this demo, we benchmark the accuracy and runtime of different propagators on the CPU (and on the GPU, if one is available) to gain insight into their computational behavior.
We start by setting up the imports.

In [None]:
import math
from time import time
import warnings

from matplotlib import pyplot as plt
import numpy as np
import torch
from torch.special import bessel_j1


from src.psf_generator import *
from psf_generator.utils.misc import convert_tensor_to_array

## Accuracy benchmark

We compare the accuracy of the two scalar propagators, `ScalarCartesianPropagator` and `ScalarSphericalPropagator`, against a reference: the analytical expression of an Airy disk, $F_{\text{AD}}$,

$$F_{\text{AD}}(\rho) = \frac{2J_1(\rho)}{\rho},$$

where $J_1$ is the Bessel function of the first kind of order one.

As the accuracy metric, we use the $L_2$ error $\delta$ between this expression and the output $E$ of a propagator, normalized to unit peak magnitude:

$$\delta = \|E - F_{\text{AD}}\|_2.$$
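
As a minimal sketch of this metric, the snippet below computes $\delta$ for two small arrays with NumPy. The arrays `field` and `reference` are made-up placeholders, not propagator output. For a 2D array, `np.linalg.norm` with default arguments returns the same root-sum-of-squares value as the explicit formula.

```python
import numpy as np

# Placeholder data standing in for a propagator output E and the reference F_AD.
field = np.array([[1.0, 0.5], [0.5, 0.0]])
reference = np.array([[1.0, 0.0], [0.0, 0.0]])

# L2 error written out explicitly: square root of the summed squared magnitudes.
delta_explicit = np.sqrt(np.sum(np.abs(field - reference) ** 2))

# Equivalent call: the default norm of a 2D array is the Frobenius norm.
delta_norm = np.linalg.norm(field - reference)

print(delta_explicit, delta_norm)
```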

We evaluate this error for a range of pupil-plane sizes (in pixels) and plot the result.

To start, we define this ground truth in a function.
Note two things:
- to avoid the numerical issue at small $\rho$ (the expression becomes $0/0$ at $\rho = 0$), we replace it with its second-order Taylor approximation $f(\rho) = 1 - \frac{\rho^2}{8}$ when $\rho \leq 10^{-6}$
- to ensure the scaling is correct, we multiply the radial coordinate by the factor $\frac{k\,\text{NA}}{n}$ with $k = \frac{4\pi n}{3\lambda}$, as in the code below, where NA is the numerical aperture, $\lambda$ is the wavelength, and $n$ is the refractive index of the immersion medium (set to 1.0 in the code)

In [None]:
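def airy_disk(fov, n_pix_psf, wavelength, na):
    """Return the analytical Airy disk sampled on an n_pix_psf x n_pix_psf grid."""
    # Use the approximation 1 - x^2 / 8 where 2 * J1(x) / x is ill-conditioned.
    airy_disk_function = lambda x: torch.where(x > 1e-6, 2 * bessel_j1(x) / x, 1 - x ** 2 / 8)
    # Radial coordinate on a square grid spanning the field of view.
    x = torch.linspace(- fov / 2, fov / 2, n_pix_psf)
    xx, yy = torch.meshgrid(x, x, indexing='ij')
    rr = torch.sqrt(xx ** 2 + yy ** 2)
    refractive_index = 1.0
    # Scale the radial coordinate before evaluating the Airy disk.
    k = 4/3 * refractive_index * math.pi / wavelength
    airy_disk_analytic = convert_tensor_to_array(airy_disk_function(k * rr * na / refractive_index))
    return airy_disk_analytic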
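
The small-argument branch can be sanity-checked without PyTorch. The sketch below is illustrative only: `bessel_j1_series` is a hypothetical helper that sums the standard power series of $J_1$, and it is not part of `psf_generator`. It shows that $2J_1(\rho)/\rho$ and $1 - \rho^2/8$ differ by roughly $\rho^4/192$, the next term of the expansion, so the difference is negligible below the $10^{-6}$ threshold.

```python
import math


def bessel_j1_series(x, terms=20):
    """Illustrative power series for J_1(x); adequate for small and moderate |x|."""
    return sum(
        (-1) ** m / (math.factorial(m) * math.factorial(m + 1)) * (x / 2) ** (2 * m + 1)
        for m in range(terms)
    )


def airy_exact(x):
    # Direct evaluation of 2 * J1(x) / x (undefined at x = 0).
    return 2 * bessel_j1_series(x) / x


def airy_small(x):
    # Second-order Taylor approximation used for small arguments.
    return 1 - x ** 2 / 8


for x in (1e-3, 1e-2, 1e-1):
    print(x, abs(airy_exact(x) - airy_small(x)))
```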

Next, we specify some parameters and define the pupil sizes (in pixels) to benchmark as $[2^3+1, 2^4+1, \ldots, 2^{10}+1]$.
Feel free to modify them based on the specifications of your system.

In [None]:
kwargs = {
    'n_pix_psf': 201,
    'wavelength': 632,
    'na': 1.3,
    'fov': 3000,
}
list_of_pixels = [int(math.pow(2, exponent) + 1) for exponent in range(3, 11)]
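
For reference, the same list can be built with integer arithmetic, which avoids the float round trip through `math.pow`. This is only an equivalent sketch of the line above, with the resulting values written out.

```python
# Pupil sizes 2**k + 1 for k = 3, ..., 10 (range excludes the upper bound 11).
list_of_pixels = [2 ** exponent + 1 for exponent in range(3, 11)]
print(list_of_pixels)
# [9, 17, 33, 65, 129, 257, 513, 1025]
```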

Now we are ready to launch the benchmark. Run the next cell.

In [None]:
propagator_types = [
        ScalarCartesianPropagator,
        ScalarSphericalPropagator
    ]

airy_disk_analytic = airy_disk(**kwargs)

results = []
for propagator_type in propagator_types:
    propagator_name = propagator_type.get_name()
    accuracy_list = []
    for n_pix in list_of_pixels:
        # The two propagators take different keyword arguments.
        if 'cartesian' in propagator_name:
            propagator = propagator_type(n_pix_pupil=n_pix, sz_correction=False, **kwargs)
        elif 'spherical' in propagator_name:
            propagator = propagator_type(n_pix_pupil=n_pix, cos_factor=True, **kwargs)
        else:
            raise ValueError(f'unknown propagator name: {propagator_name}')

        psf = convert_tensor_to_array(propagator.compute_focus_field())
        # Normalize to unit peak magnitude before comparing with the reference.
        psf /= np.max(np.abs(psf))
        # L2 error against the analytical Airy disk.
        accuracy = np.sqrt(np.sum(np.abs(psf - airy_disk_analytic) ** 2))
        accuracy_list.append((n_pix, accuracy))
    results.append(accuracy_list)

Next, we plot the results.

In [None]:
figure, ax = plt.subplots(1, 1, figsize=(6, 6))
colors = ['red', 'blue']
labels = ['Cartesian', 'Spherical']
for result, label, color in zip(results, labels, colors):
    x, y = zip(*result)
    ax.loglog(x, y, label=label, ls='solid', marker='.', markersize=6, lw=1, color=color)
ax.set_xscale("log", base=2)
ax.set_yscale("log", base=10)
ax.set_xlabel('Pupil size', fontsize=12)
ax.set_ylabel('Error', fontsize=12)
ax.legend(fontsize=12)
plt.grid(color='gray', ls='dotted', lw=1)
plt.show()

## Speed benchmark

To check how a propagator scales and to compare it with the others, we benchmark the runtime needed to generate a single 2D image of the PSF at the focal plane.
We measure the runtime as a function of the number of pixels in the pupil plane.
You can modify `device` to select the CPU or the GPU.

In [None]:
# define propagator types
propagator_types = [
    ScalarCartesianPropagator,
    ScalarSphericalPropagator,
    VectorialCartesianPropagator,
    VectorialSphericalPropagator,
]
# test parameters
list_of_pixels = [int(math.pow(2, exponent) + 1) for exponent in range(3, 11)]
# average the time over many repetitions
number_of_repetitions = 10
# define devices
device = "cpu" # "cuda:0"

In [None]:
if 'cuda' in device:
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    else:
        device = "cpu"
        warnings.warn('GPU not available, benchmarking on CPU instead.')

results = []
for propagator_type in propagator_types:
    average_runtime_list = []
    for n_pix in list_of_pixels:
        print(device, propagator_type.__name__, n_pix)
        runtime_list = []
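        for _ in range(number_of_repetitions):
            start_time = time()
            propagator = propagator_type(n_pix_pupil=n_pix, device=device)
            propagator.compute_focus_field()
            if 'cuda' in device:
                # CUDA kernels run asynchronously; wait for them before stopping the clock.
                torch.cuda.synchronize()
            runtime = time() - start_time
            runtime_list.append(runtime)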
        average_runtime_list.append((n_pix, sum(runtime_list) / number_of_repetitions))
    results.append(average_runtime_list)
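
The timing loop above can also be factored into a small helper. The following is a sketch under stated assumptions, not part of `psf_generator`: `average_runtime` is a hypothetical name, it uses `perf_counter` (a monotonic, high-resolution clock) instead of `time`, and it adds untimed warm-up calls. The optional `sync` argument is where one would pass `torch.cuda.synchronize` when the workload runs on a GPU.

```python
from time import perf_counter


def average_runtime(fn, repetitions=10, warmup=1, sync=None):
    """Return the mean wall-clock time of fn() in seconds.

    `sync` is an optional callable run before each clock read, for example
    torch.cuda.synchronize when fn launches asynchronous GPU work.
    """
    for _ in range(warmup):
        fn()  # untimed warm-up calls absorb one-off setup costs
    total = 0.0
    for _ in range(repetitions):
        if sync is not None:
            sync()
        start = perf_counter()
        fn()
        if sync is not None:
            sync()
        total += perf_counter() - start
    return total / repetitions


# Usage with a stand-in workload (hypothetical; replace with a propagator call).
mean_seconds = average_runtime(lambda: sum(i * i for i in range(10_000)))
print(f"{mean_seconds:.6f} s")
```

Whether to time the propagator construction together with `compute_focus_field()`, as the cell above does, or only the latter depends on what you want to measure.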

Next, we plot the results.

In [None]:
figure, ax = plt.subplots(1, 1, figsize=(6, 6))
colors = ['red', 'blue', 'red', 'blue']
labels = [propagator_type.__name__ for propagator_type in propagator_types]
for result, label, color in zip(results, labels, colors):
    x, y = zip(*result)
    # Scalar propagators are drawn dotted, vectorial ones solid.
    ls = 'dotted' if 'Scalar' in label else 'solid'
    # The first (smallest) pupil size is left out of the plot.
    ax.loglog(x[1:], y[1:], label=label.replace("_", " "), ls=ls, marker='.', markersize=6, lw=1, color=color)
# Axis settings only need to be applied once, after all curves are drawn.
ax.set_xscale("log", base=2)
ax.set_yscale("log", base=10)
ax.set_title(f'Runtime on {device}', fontsize=12)
ax.set_xlabel('Pupil size', fontsize=12)
ax.set_ylabel('Time (s)', fontsize=12)
ax.legend(fontsize=12)

plt.grid(color='gray', ls='dotted', lw=1)
plt.show()
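
To put a number on the scaling seen in the log-log plot, one option is to fit a straight line to the logarithms of the pupil sizes and runtimes; its slope estimates the exponent $p$ in $t \propto N^p$. The sketch below uses synthetic timings for illustration only, and `scaling_exponent` is a hypothetical helper name. With real data, the sizes and runtimes would come from `x, y = zip(*result)` for each entry of `results`.

```python
import numpy as np


def scaling_exponent(sizes, runtimes):
    """Least-squares slope of log(runtime) against log(size)."""
    slope, _intercept = np.polyfit(np.log(sizes), np.log(runtimes), 1)
    return slope


# Synthetic timings following t = 1e-7 * n**2, for illustration only.
sizes = np.array([9, 17, 33, 65, 129, 257, 513, 1025], dtype=float)
runtimes = 1e-7 * sizes ** 2
print(scaling_exponent(sizes, runtimes))
```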