## Performance comparison for deflection angle calculations

To change the grid size, or to change the backend from CPU to GPU, first restart the kernel, then modify the cell below.

In [None]:
import jax

# Backend to benchmark on. Change this single value (e.g. to "gpu") after
# restarting the kernel; the sanity check below follows it automatically.
platform_name = "cpu"

jax.config.update("jax_platform_name", platform_name)
# Fail early if JAX is running on a different backend than the one requested,
# since the timings would otherwise be attributed to the wrong device.
assert jax.default_backend() == platform_name, (
    f"Requested JAX backend {platform_name!r} but got {jax.default_backend()!r}"
)

# Number of grid points per axis; the benchmark evaluates num_pix**2 points.
num_pix = 60

Run the tests. If many tests have already been run in this session, restarting the kernel may be necessary to free up memory.

In [None]:
from jax import numpy as jnp, block_until_ready

jax.config.update("jax_enable_x64", True)
import numpy as np
import time

from jaxtronomy.LensModel.profile_list_base import (
    lens_class,
    _JAXXED_MODELS as JAXXED_DEFLECTOR_PROFILES,
)
from lenstronomy.LensModel.profile_list_base import lens_class as lens_class_lenstronomy

# Profiles whose ellipticity components e1/e2 are overridden in the benchmark loop.
EPL_LIST = [
    "EPL",
    "EPL_MULTIPOLE_M1M3M4",
    "EPL_MULTIPOLE_M1M3M4_ELL",
    "EPL_MULTIPOLE_M3M4",
    "EPL_MULTIPOLE_M3M4_ELL",
]
# Profiles whose multipole order m is overridden in the benchmark loop.
MULTIPOLE_LIST = ["MULTIPOLE", "MULTIPOLE_ELL"]

# Flattened num_pix x num_pix evaluation grid with coordinates in [95, 105].
# The +100 shift is presumably there to keep the grid away from the profile
# centers -- TODO confirm.
# JAX arrays, used as inputs for jaxtronomy.
grid_axis_jax = jnp.linspace(-5.0, 5.0, num_pix) + 100
x_jax = jnp.tile(grid_axis_jax, num_pix)
y_jax = jnp.repeat(grid_axis_jax, num_pix)

# NumPy arrays built the same way, used as inputs for lenstronomy.
grid_axis = np.linspace(-5.0, 5.0, num_pix) + 100
x = np.tile(grid_axis, num_pix)
y = np.repeat(grid_axis, num_pix)

# Number of timed calls per library. The same count is used for every timing in
# this cell so that execution times and speed-up ratios are directly comparable.
num_iterations = 1000

for deflector_profile in JAXXED_DEFLECTOR_PROFILES:
    # Skip these profiles
    if deflector_profile in ["LOS", "LOS_MINIMAL"]:
        continue

    profile_lenstronomy = lens_class_lenstronomy(deflector_profile)
    profile = lens_class(deflector_profile)

    # Start from the profile's default upper limits. Work on a copy: assigning
    # into profile.upper_limit_default directly would modify the profile's own
    # default dictionary in place.
    kwargs_lens = dict(profile.upper_limit_default)
    if deflector_profile in EPL_LIST:
        kwargs_lens["e1"] = 0.01
        kwargs_lens["e2"] = 0.01
    if deflector_profile in MULTIPOLE_LIST:
        kwargs_lens["m"] = 3
        if deflector_profile == "MULTIPOLE_ELL":
            kwargs_lens["q"] = 0.5

    # Compile code/warmup (one untimed call per library)
    f_x, f_y = profile_lenstronomy.derivatives(x, y, **kwargs_lens)
    f_x, f_y = profile.derivatives(x_jax, y_jax, **kwargs_lens)
    block_until_ready(f_x)
    block_until_ready(f_y)

    # Now time runtime after compilation/warmup.
    # block_until_ready makes each iteration wait for the JAX result, so the
    # measurement covers the computation and not only its dispatch.
    start_time = time.perf_counter()
    for _ in range(num_iterations):
        f_x, f_y = profile.derivatives(x_jax, y_jax, **kwargs_lens)
        block_until_ready(f_x)
        block_until_ready(f_y)

    middle_time = time.perf_counter()

    for _ in range(num_iterations):
        f_x, f_y = profile_lenstronomy.derivatives(x, y, **kwargs_lens)

    end_time = time.perf_counter()

    jax_execution_time = middle_time - start_time
    lenstronomy_execution_time = end_time - middle_time
    print(
        f"jaxtronomy execution time for {deflector_profile}: {jax_execution_time} seconds"
    )
    print(
        f"lenstronomy execution time for {deflector_profile}: {lenstronomy_execution_time} seconds"
    )
    print(
        f"jaxtronomy is {lenstronomy_execution_time / jax_execution_time:.1f}x faster\n"
    )

    # Additional performance comparison with EPL_NUMBA
    if deflector_profile == "EPL":
        profile_lenstronomy = lens_class_lenstronomy("EPL_NUMBA")
        # Untimed warmup call for the EPL_NUMBA profile
        f_x, f_y = profile_lenstronomy.derivatives(x, y, **kwargs_lens)
        start_time = time.perf_counter()
        # Must use the same iteration count as the jaxtronomy timing above,
        # otherwise the ratio printed below is off by the ratio of the counts.
        for _ in range(num_iterations):
            f_x, f_y = profile_lenstronomy.derivatives(x, y, **kwargs_lens)
        end_time = time.perf_counter()
        numba_execution_time = end_time - start_time

        print(f"jaxtronomy execution time for EPL: {jax_execution_time} seconds")
        print(
            f"lenstronomy execution time for EPL_NUMBA: {numba_execution_time} seconds"
        )
        print(
            f"jaxtronomy is {numba_execution_time / jax_execution_time:.1f}x faster\n"
        )

## Performance comparison for flux calculations

To change the grid size, or to change the backend from CPU to GPU, first restart the kernel, then modify the cell below.

In [None]:
import jax

# Backend to benchmark on. Change this single value (e.g. to "gpu") after
# restarting the kernel; the sanity check below follows it automatically.
platform_name = "cpu"

jax.config.update("jax_platform_name", platform_name)
# Fail early if JAX is running on a different backend than the one requested,
# since the timings would otherwise be attributed to the wrong device.
assert jax.default_backend() == platform_name, (
    f"Requested JAX backend {platform_name!r} but got {jax.default_backend()!r}"
)

# Number of grid points per axis; the benchmark evaluates num_pix**2 points.
num_pix = 60

Run the tests. If many tests have already been run in this session, restarting the kernel may be necessary to free up memory.

In [None]:
import copy
from jax import numpy as jnp, block_until_ready

jax.config.update("jax_enable_x64", True)
import numpy as np
import time

from jaxtronomy.LightModel.light_model import LightModel
from lenstronomy.LightModel.light_model import LightModel as LightModel_ref
from jaxtronomy.LightModel.light_model_base import (
    _JAXXED_MODELS as JAXXED_SOURCE_PROFILES,
)

# Flattened num_pix x num_pix evaluation grid with coordinates in [95, 105].
# JAX arrays, used as inputs for jaxtronomy.
grid_axis_jax = jnp.linspace(-5.0, 5.0, num_pix) + 100
x_jax = jnp.tile(grid_axis_jax, num_pix)
y_jax = jnp.repeat(grid_axis_jax, num_pix)

# NumPy arrays built the same way, used as inputs for lenstronomy.
grid_axis = np.linspace(-5.0, 5.0, num_pix) + 100
x = np.tile(grid_axis, num_pix)
y = np.repeat(grid_axis, num_pix)

def _time_flux_evaluation(
    jax_function, ref_function, kwargs, label, num_iterations=10000
):
    """Time repeated flux evaluations for jaxtronomy and lenstronomy and print a report.

    Both functions are evaluated on the module-level grids (x_jax, y_jax) and
    (x, y) respectively. One untimed call per library is made first so that
    compilation/warmup is excluded from the measurement.

    :param jax_function: jaxtronomy profile function, called as f(x_jax, y_jax, **kwargs)
    :param ref_function: lenstronomy profile function, called as f(x, y, **kwargs)
    :param kwargs: keyword arguments passed to both functions
    :param label: text identifying the benchmarked profile in the printed report
    :param num_iterations: number of timed calls per library
    :return: tuple (jaxtronomy time, lenstronomy time) in seconds
    """
    # Compile code/warmup
    block_until_ready(jax_function(x_jax, y_jax, **kwargs))
    ref_function(x, y, **kwargs)

    # Now time runtime after compilation/warmup.
    # block_until_ready makes each iteration wait for the JAX result, so the
    # measurement covers the computation and not only its dispatch.
    start_time = time.perf_counter()
    for _ in range(num_iterations):
        block_until_ready(jax_function(x_jax, y_jax, **kwargs))

    middle_time = time.perf_counter()

    for _ in range(num_iterations):
        ref_function(x, y, **kwargs)

    end_time = time.perf_counter()

    jax_time = middle_time - start_time
    ref_time = end_time - middle_time
    print(f"jaxtronomy execution time for {label}: {jax_time} seconds")
    print(f"lenstronomy execution time for {label}: {ref_time} seconds")
    print(f"jaxtronomy is {ref_time / jax_time:.1f}x faster\n")
    return jax_time, ref_time


for source_profile in JAXXED_SOURCE_PROFILES:
    if source_profile == "SHAPELETS":
        # Do this profile at the end; skip before building any models for it
        continue

    lightModel = LightModel([source_profile])
    lightModel_ref = LightModel_ref([source_profile])
    # Deep copy so that the overrides below never touch the profile's defaults
    kwargs_source = copy.deepcopy(lightModel.func_list[0].upper_limit_default)
    if source_profile in ["MULTI_GAUSSIAN", "MULTI_GAUSSIAN_ELLIPSE"]:
        # These profiles are benchmarked with 5-element amp and sigma arrays
        kwargs_source["amp"] = np.linspace(10, 20, 5)
        kwargs_source["sigma"] = np.linspace(0.3, 1.0, 5)

    jax_execution_time, lenstronomy_execution_time = _time_flux_evaluation(
        lightModel.func_list[0].function,
        lightModel_ref.func_list[0].function,
        kwargs_source,
        label=source_profile,
    )


lightModel = LightModel(["SHAPELETS"])
lightModel_ref = LightModel_ref(["SHAPELETS"])
kwargs_source = copy.deepcopy(lightModel.func_list[0].upper_limit_default)


# Now run comparisons for the SHAPELETS profile for different values of n_max
for n_max in [6, 10]:
    # Length of the amp array for a given n_max. The product of two consecutive
    # integers is even, so the integer division is exact.
    num_param = (n_max + 1) * (n_max + 2) // 2
    kwargs_source["n_max"] = n_max
    kwargs_source["amp"] = np.linspace(20.0, 30.0, num_param)

    jax_execution_time, lenstronomy_execution_time = _time_flux_evaluation(
        lightModel.func_list[0].function,
        lightModel_ref.func_list[0].function,
        kwargs_source,
        label=f"SHAPELETS with n_max={n_max}",
    )

## Comparing jaxtronomy and lenstronomy performance for FFT convolution

To change the grid size, the convolution kernel size, or the backend (CPU to GPU), first restart the Jupyter kernel, then modify the cell below.

In [None]:
import jax

# Backend to benchmark on. Change this single value (e.g. to "gpu") after
# restarting the kernel; the sanity check below follows it automatically.
platform_name = "cpu"

jax.config.update("jax_platform_name", platform_name)
# Fail early if JAX is running on a different backend than the one requested,
# since the timings would otherwise be attributed to the wrong device.
assert jax.default_backend() == platform_name, (
    f"Requested JAX backend {platform_name!r} but got {jax.default_backend()!r}"
)

# Side length in pixels of the square test image.
num_pix = 60

# Side length in pixels of the square convolution kernel.
kernel_size = 45

Run the tests. If many tests have already been run in this session, restarting the Jupyter kernel may be necessary to free up memory.

In [None]:
from jax import numpy as jnp

jax.config.update("jax_enable_x64", True)
import numpy as np
import time

from lenstronomy.ImSim.Numerics.convolution import (
    PixelKernelConvolution as PixelKernelConvolution_ref,
)
from jaxtronomy.ImSim.Numerics.convolution import PixelKernelConvolution

kernel_shape = (kernel_size, kernel_size)
image_shape = (num_pix, num_pix)

# Convolution kernel: every row is a linear ramp from 0.9 to 1.1, and the
# whole kernel is normalized so that its entries sum to one.
kernel_row_jax = jnp.linspace(0.9, 1.1, kernel_size)
kernel_jax = jnp.tile(kernel_row_jax, kernel_size).reshape(kernel_shape)
kernel_jax = kernel_jax / jnp.sum(kernel_jax)

# Same kernel built with NumPy for lenstronomy.
kernel_row = np.linspace(0.9, 1.1, kernel_size)
kernel = np.tile(kernel_row, kernel_size).reshape(kernel_shape)
kernel = kernel / np.sum(kernel)

# FFT-based convolution objects from each library.
jax_conv = PixelKernelConvolution(kernel=kernel_jax, convolution_type="fft")
lenstronomy_conv = PixelKernelConvolution_ref(kernel=kernel, convolution_type="fft")

# Test image: every row is a linear ramp from 1.0 to 20.0.
image_row_jax = jnp.linspace(1.0, 20.0, num_pix)
image_jax = jnp.tile(image_row_jax, num_pix).reshape(image_shape)

image_row = np.linspace(1.0, 20.0, num_pix)
image = np.tile(image_row, num_pix).reshape(image_shape)

# Compile code/warmup, and check that both libraries agree before timing them
result_jax = jax_conv.convolution2d(image_jax)
result = lenstronomy_conv.convolution2d(image)
np.testing.assert_allclose(result_jax, result, rtol=1e-5, atol=1e-5)

# Number of timed calls per library
num_iterations = 10000

# Now time runtime after compilation/warmup
start_time = time.perf_counter()

for _ in range(num_iterations):
    # JAX dispatches work asynchronously; wait for the result so that the
    # measurement covers the convolution itself and not only its dispatch.
    jax.block_until_ready(jax_conv.convolution2d(image_jax))

middle_time = time.perf_counter()

for _ in range(num_iterations):
    lenstronomy_conv.convolution2d(image)

end_time = time.perf_counter()

jax_execution_time = middle_time - start_time
lenstronomy_execution_time = end_time - middle_time
print(
    f"jaxtronomy execution time for fft convolution with kernel size {kernel_size}: {jax_execution_time} seconds"
)
print(
    f"lenstronomy execution time for fft convolution with kernel size {kernel_size}: {lenstronomy_execution_time} seconds"
)
print(
    f"jaxtronomy is {lenstronomy_execution_time / jax_execution_time:.1f}x faster\n"
)