<a href="https://colab.research.google.com/github/ColmTalbot/wcosmo/blob/timing-notebook/examples/wcosmo_timing.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install wcosmo

Collecting wcosmo
  Downloading wcosmo-0.1.1-py3-none-any.whl (9.4 kB)
Collecting gwpopulation (from wcosmo)
  Downloading gwpopulation-1.1.0-py3-none-any.whl (32 kB)
Collecting bilby>=2.2.0 (from gwpopulation->wcosmo)
  Downloading bilby-2.3.0-py3-none-any.whl (2.3 MB)
Collecting cached-interpolate (from gwpopulation->wcosmo)
  Downloading cached_interpolate-0.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl (402 kB)
Collecting bilby.cython>=0.3.0 (from bilby>=2.2.0->gwpopulation->wcosmo)
  Downloading bilby.cython-0.5.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)

## Compare `wcosmo` and `astropy` timing.

The three primary pieces of functionality we use are converting from luminosity distance to redshift, calculating the luminosity distance to redshift Jacobian, and calculating the differential comoving volume. Below we time the first and the last of these.
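For reference, assuming a flat wCDM cosmology with no radiation (the model used below is `FlatwCDM(67, 0.3, -1)`, i.e. $H_0 = 67$, $\Omega_m = 0.3$, $w = -1$), these quantities follow from the standard definitions below. This is our own summary, not taken from the `wcosmo` source; the volume is per unit solid angle, following `astropy`'s `differential_comoving_volume` convention.

$$
E(z) = \sqrt{\Omega_m (1+z)^3 + (1 - \Omega_m)(1+z)^{3(1+w)}}, \qquad
d_C(z) = \frac{c}{H_0} \int_0^z \frac{dz'}{E(z')}, \qquad
d_L(z) = (1+z)\, d_C(z),
$$

$$
\frac{d d_L}{dz} = d_C(z) + \frac{c\,(1+z)}{H_0\, E(z)}, \qquad
\frac{dV_c}{dz\, d\Omega} = \frac{c}{H_0} \frac{d_C(z)^2}{E(z)}.
$$

Converting a density in luminosity distance to one in redshift multiplies by the Jacobian, $p(z) = p(d_L)\, \frac{d d_L}{dz}$.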

Timing the `wcosmo` implementation with `JAX` is non-trivial because we rely on JIT compilation and because `JAX` dispatches work asynchronously, so we need to wait until the evaluation is complete before stopping the clock.
The steps are:

- JIT-compile a wrapper function to call.
- Make one throwaway ("burn") call to trigger the compilation.
- Run the function and call `block_until_ready` on the result so the timing includes the full evaluation.

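We also time `wcosmo` with the `numpy` and `cupy` backends.
Note that `cupy` also requires a burn call to compile the underlying `CUDA` kernels, and since `cupy` launches kernels asynchronously we synchronize the current stream before stopping the clock.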
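The recipe above generalizes to any backend if the synchronization step is passed in as a function. Below is a minimal sketch of a hypothetical `time_call` helper (not part of `wcosmo` or this notebook); unlike a single `%%time` run it reports the best of several repeats, which reduces noise. For `JAX` one would pass `sync=lambda out: out.block_until_ready()`, and for `cupy` `sync=lambda out: cupy.cuda.stream.get_current_stream().synchronize()`.

```python
import time

import numpy as np


def _no_sync(out):
    """numpy evaluates eagerly, so there is nothing to wait for."""


def time_call(func, *args, sync=_no_sync, n_warmup=1, n_repeat=5):
    """Return the best wall-clock time in seconds over ``n_repeat`` calls.

    ``sync(out)`` is called on every result so that asynchronous backends
    (JAX, cupy) have finished before the clock stops. The ``n_warmup``
    untimed calls absorb one-off costs such as JIT or kernel compilation.
    """
    for _ in range(n_warmup):
        sync(func(*args))
    best = float("inf")
    for _ in range(n_repeat):
        start = time.perf_counter()
        sync(func(*args))
        best = min(best, time.perf_counter() - start)
    return best


# numpy example; for JAX pass sync=lambda out: out.block_until_ready()
data = np.random.uniform(1, 10, 1_000_000)
print(f"np.sqrt: {time_call(np.sqrt, data) * 1e3:.2f} ms")
```

Taking the minimum follows the same reasoning as the standard-library `timeit` module: slower repeats are usually caused by interference from other processes rather than by the code being timed.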

We manually switch backends, although this can be done automatically using `GWPopulation`.

In [2]:
import numpy as np
import wcosmo


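def set_backend(backend):
    """Point wcosmo's array namespace (``xp``) and ``toeplitz`` at ``backend``.

    ``backend`` is one of "numpy", "jax" or "cupy".
    """
    from importlib import import_module

    # array namespace for each backend
    np_modules = dict(
        numpy="numpy",
        jax="jax.numpy",
        cupy="cupy",
    )
    # matching linear-algebra module that provides ``toeplitz``
    linalg_modules = dict(
        numpy="scipy.linalg",
        jax="jax.scipy.linalg",
        cupy="cupyx.scipy.linalg",
    )
    # rebind the module-level names used by subsequent wcosmo calls
    setattr(wcosmo.wcosmo, "xp", import_module(np_modules[backend]))
    setattr(wcosmo.utils, "xp", import_module(np_modules[backend]))
    toeplitz = getattr(import_module(linalg_modules[backend]), "toeplitz")
    setattr(wcosmo.utils, "toeplitz", toeplitz)


# one million samples, used as luminosity distances (Mpc) in the redshift
# benchmarks and as redshifts in the comoving-volume benchmarks
ndata = np.random.uniform(1, 10, 1000000)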
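`set_backend` relies on `wcosmo` looking up `xp` (and `toeplitz`) from its module namespace each time a function runs, so rebinding those attributes changes the backend for later calls. A minimal, self-contained sketch of that pattern follows; `efunc` (the flat ΛCDM $E(z)$) and this `set_backend` are hypothetical stand-ins, not `wcosmo` code.

```python
from importlib import import_module

import numpy as np

# module-level array namespace, resolved when functions run
xp = import_module("numpy")


def efunc(z, om0=0.3):
    """Flat LambdaCDM E(z), computed with whichever backend ``xp`` names."""
    return xp.sqrt(om0 * (1 + z) ** 3 + 1 - om0)


def set_backend(name):
    """Rebind the global ``xp`` so later efunc calls use ``name``."""
    global xp
    xp = import_module(name)


set_backend("numpy")
print(type(efunc(np.linspace(0, 1, 3))))  # <class 'numpy.ndarray'>
set_backend("math")
print(type(efunc(1.0)))  # <class 'float'>, since math.sqrt returns a float
set_backend("numpy")
```

Using the standard-library `math` module as a second "backend" keeps the sketch runnable without a GPU; in the notebook the swap is between `numpy`, `jax.numpy` and `cupy`.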

### wcosmo + jax + GPU

In [3]:
import jax.numpy as jnp
import numpy as np
from jax import jit


set_backend("jax")

jdata = jnp.array(ndata)


@jit
def time_jax_redshift(jdata):
    return wcosmo.z_at_value(wcosmo.FlatwCDM(67, 0.3, -1).luminosity_distance, jdata)


@jit
def time_jax_dvcdz(jdata):
    return wcosmo.FlatwCDM(67, 0.3, -1).differential_comoving_volume(jdata)


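# warm-up ("burn") calls trigger compilation; block so that no work is still
# queued on the device when the timed cells start
burn_vals = time_jax_redshift(jdata).block_until_ready()
burn_vals = time_jax_dvcdz(jdata).block_until_ready()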

In [4]:
%%time

_ = time_jax_redshift(jdata).block_until_ready()

CPU times: user 1.78 ms, sys: 0 ns, total: 1.78 ms
Wall time: 2.43 ms


In [5]:
%%time

_ = time_jax_dvcdz(jdata).block_until_ready()

CPU times: user 1.52 ms, sys: 0 ns, total: 1.52 ms
Wall time: 1.42 ms


### astropy + CPU

This is very slow, so we only use one percent of the data (10,000 points).
Since this is `numpy`-based, the run time scales linearly with the amount of data.

In practice, most `astropy` users evaluate `z_at_value` at many points by interpolating over a precomputed grid, as is done in `wcosmo`.
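A minimal sketch of that interpolation approach in plain `numpy`, assuming a flat wCDM model with no radiation: tabulate $d_L(z)$ once on a redshift grid (cumulative trapezoid rule for the comoving distance) and invert it with `np.interp`. The function names, grid range and grid size are our own illustrative choices, not `wcosmo`'s or `astropy`'s.

```python
import numpy as np

C_KM_S = 299_792.458  # speed of light in km/s


def luminosity_distance_grid(z_grid, H0=67.0, Om0=0.3, w0=-1.0):
    """Luminosity distance in Mpc for flat wCDM on a sorted grid starting at z=0.

    Assumes no radiation and integrates 1/E(z) with the cumulative trapezoid rule.
    """
    efunc = np.sqrt(
        Om0 * (1 + z_grid) ** 3 + (1 - Om0) * (1 + z_grid) ** (3 * (1 + w0))
    )
    integrand = 1 / efunc
    steps = np.diff(z_grid) * (integrand[1:] + integrand[:-1]) / 2
    comoving = C_KM_S / H0 * np.concatenate([[0.0], np.cumsum(steps)])
    return (1 + z_grid) * comoving


def z_at_luminosity_distance(dl, z_max=10.0, n_grid=10_001, **cosmo):
    """Invert d_L(z) by linear interpolation on a precomputed redshift grid."""
    z_grid = np.linspace(0.0, z_max, n_grid)
    dl_grid = luminosity_distance_grid(z_grid, **cosmo)
    # d_L increases monotonically with z, so np.interp can invert it directly;
    # inputs beyond the grid are clipped to the endpoint redshifts
    return np.interp(dl, dl_grid, z_grid)


# same inputs as the benchmarks: one million distances in Mpc
redshifts = z_at_luminosity_distance(np.random.uniform(1, 10, 1_000_000))
```

The cost is one grid evaluation plus a vectorized interpolation, rather than a separate root find per point; the accuracy is controlled by `n_grid`.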

In [6]:
from astropy import cosmology, units

In [7]:
%%time

_ = cosmology.z_at_value(
    cosmology.FlatwCDM(67, 0.3, -1).luminosity_distance,
    ndata[:10000] * units.Mpc,
).value

CPU times: user 32.6 s, sys: 79.3 ms, total: 32.7 s
Wall time: 32.9 s


In [8]:
%%time

_ = cosmology.FlatwCDM(67, 0.3, -1).differential_comoving_volume(
    ndata[:10000],
).value

CPU times: user 103 ms, sys: 0 ns, total: 103 ms
Wall time: 104 ms


### wcosmo + numpy + CPU

In [9]:
set_backend("numpy")

In [10]:
%%time

_ = wcosmo.z_at_value(
    wcosmo.FlatwCDM(67, 0.3, -1).luminosity_distance, ndata
)

CPU times: user 81.4 ms, sys: 64 ms, total: 145 ms
Wall time: 80.9 ms


In [11]:
%%time

_ = wcosmo.FlatwCDM(67, 0.3, -1).differential_comoving_volume(ndata)

CPU times: user 74.6 ms, sys: 75.1 ms, total: 150 ms
Wall time: 93.8 ms


### wcosmo + cupy + GPU

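The final test uses the `cupy` backend on the GPU.
Typically this is much faster than `numpy` but slower than the `JAX` GPU code; in the runs below this holds for `z_at_value`, while `differential_comoving_volume` is slower than the `numpy` version.
Note also that these timings do not include transfer between CPU and GPU memory, which can be slow with `cupy`.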

In [12]:
import cupy

set_backend("cupy")

cdata = cupy.asarray(ndata)

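# warm-up calls compile the CUDA kernels; synchronize so that no work is
# still queued on the GPU when the timed cells start
_ = wcosmo.z_at_value(
    wcosmo.FlatwCDM(67, 0.3, -1).luminosity_distance, cdata
)
_ = wcosmo.FlatwCDM(67, 0.3, -1).differential_comoving_volume(cdata)
cupy.cuda.stream.get_current_stream().synchronize()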

In [13]:
%%time

_ = wcosmo.z_at_value(
    wcosmo.FlatwCDM(67, 0.3, -1).luminosity_distance, cdata
)
cupy.cuda.stream.get_current_stream().synchronize()

CPU times: user 6.31 ms, sys: 0 ns, total: 6.31 ms
Wall time: 6.73 ms


In [14]:
%%time

_ = wcosmo.FlatwCDM(67, 0.3, -1).differential_comoving_volume(cdata)
cupy.cuda.stream.get_current_stream().synchronize()

CPU times: user 109 ms, sys: 935 µs, total: 110 ms
Wall time: 124 ms
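As a rough summary, the sketch below collects the wall times printed above and scales the `astropy` results (measured on 10,000 points) to the full 1,000,000 points, assuming the linear scaling noted earlier. These are single `%%time` runs on unstated hardware, so the ratios are indicative only.

```python
# Wall times in seconds and number of points, copied from the %%time outputs above.
N_FULL = 1_000_000

wall_times = {
    "z_at_value": {
        "wcosmo + jax + GPU": (2.43e-3, N_FULL),
        "astropy + CPU": (32.9, 10_000),
        "wcosmo + numpy + CPU": (80.9e-3, N_FULL),
        "wcosmo + cupy + GPU": (6.73e-3, N_FULL),
    },
    "differential_comoving_volume": {
        "wcosmo + jax + GPU": (1.42e-3, N_FULL),
        "astropy + CPU": (104e-3, 10_000),
        "wcosmo + numpy + CPU": (93.8e-3, N_FULL),
        "wcosmo + cupy + GPU": (124e-3, N_FULL),
    },
}


def per_million(seconds, n_points):
    """Scale a run time to N_FULL points, assuming linear scaling."""
    return seconds * N_FULL / n_points


for task, runs in wall_times.items():
    print(task)
    for label, (seconds, n_points) in runs.items():
        print(f"  {label:<22} {per_million(seconds, n_points):>10.4f} s")
```

With this extrapolation, `astropy`'s `z_at_value` would take roughly 3,300 s (about 55 minutes) on the full data set, compared with about 81 ms for `wcosmo` with `numpy` and about 2.4 ms with `JAX` on the GPU.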
