<a href="https://colab.research.google.com/github/ColmTalbot/wcosmo/blob/timing-notebook/examples/wcosmo_timing.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install git+https://github.com/ColmTalbot/wcosmo.git

Collecting git+https://github.com/ColmTalbot/wcosmo.git
  Cloning https://github.com/ColmTalbot/wcosmo.git to /tmp/pip-req-build-aitnw024
  Running command git clone --filter=blob:none --quiet https://github.com/ColmTalbot/wcosmo.git /tmp/pip-req-build-aitnw024
  Resolved https://github.com/ColmTalbot/wcosmo.git to commit 7a44bce6fce29e1aa35007cf1a97b0747106c104
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Installing backend dependencies ... done
  Preparing metadata (pyproject.toml) ... done


## Compare `wcosmo` and `astropy` timing.

The main pieces of functionality we use are converting from luminosity distance to redshift, calculating the Jacobian of that transformation, and calculating the differential comoving volume. Below we time the first and the last of these.
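For reference, these quantities have standard closed forms for a flat wCDM cosmology with Hubble constant $H_0$, matter density $\Omega_m$, constant dark-energy equation of state $w$, and radiation neglected. This matches the model `FlatwCDM(67, 0.3, -1)` used throughout. The notation is ours, and we have not checked whether `wcosmo` evaluates these in exactly this form. For example, `astropy` normalizes the differential comoving volume per steradian, as written here, and we have not verified the `wcosmo` convention.

$$
\begin{aligned}
E(z) &= \sqrt{\Omega_m (1+z)^3 + (1-\Omega_m)(1+z)^{3(1+w)}}, \\
d_C(z) &= \frac{c}{H_0} \int_0^z \frac{\mathrm{d}z'}{E(z')}, \qquad d_L(z) = (1+z)\, d_C(z), \\
\frac{\mathrm{d}d_L}{\mathrm{d}z} &= d_C(z) + \frac{c\,(1+z)}{H_0\, E(z)}, \\
\frac{\mathrm{d}V_c}{\mathrm{d}z\,\mathrm{d}\Omega} &= \frac{c}{H_0} \frac{d_C(z)^2}{E(z)}.
\end{aligned}
$$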

Timing the `wcosmo` implementation with `JAX` is non-trivial. It relies on JIT compilation, and because `JAX` dispatches work asynchronously, we also need to make sure we wait until the evaluation is complete.
The steps are:

- JIT-compile a wrapper function to call.
- Burn one evaluation to trigger the compilation.
- Run the function and call `block_until_ready` so that the measured time includes the full evaluation.
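The three steps above can be wrapped in a small, backend-agnostic helper. This is a sketch. `time_warm` is a hypothetical name and is not part of `wcosmo` or `JAX`. It also takes the best of several repeats rather than the single run that `%%time` gives in this notebook.

```python
import time


def time_warm(func, *args, block=lambda out: out, n_repeat=5):
    """Hypothetical helper: warm up `func`, then return the best wall time in seconds.

    `block` is applied to each output so that asynchronous work has finished
    before the clock stops (for JAX arrays: ``lambda out: out.block_until_ready()``).
    """
    # Burn one evaluation: for a jitted function this triggers compilation.
    block(func(*args))
    timings = []
    for _ in range(n_repeat):
        start = time.perf_counter()
        block(func(*args))  # wait for the result before reading the clock
        timings.append(time.perf_counter() - start)
    return min(timings)
```

With the jitted functions in the next section, one would pass `block=lambda out: out.block_until_ready()`. For `cupy`, one could pass something like `block=lambda out: cupy.cuda.stream.get_current_stream().synchronize()`.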

We also time `wcosmo` with the `numpy` and `cupy` backends.
Note that `cupy` also requires burning a call to compile the underlying `CUDA` code.

We manually switch backends, although this can be done automatically using `GWPopulation`.

In [None]:
from importlib import import_module

import numpy as np
import wcosmo


def set_backend(backend):
    """Point wcosmo's array namespace (`xp`) and `toeplitz` at the chosen backend."""
    np_modules = dict(
        numpy="numpy",
        jax="jax.numpy",
        cupy="cupy",
    )
    linalg_modules = dict(
        numpy="scipy.linalg",
        jax="jax.scipy.linalg",
        cupy="cupyx.scipy.linalg",
    )
    xp = import_module(np_modules[backend])
    # Swap the module-level array library used inside wcosmo.
    setattr(wcosmo.wcosmo, "xp", xp)
    setattr(wcosmo.utils, "xp", xp)
    # wcosmo.utils also needs a backend-specific toeplitz implementation.
    toeplitz = getattr(import_module(linalg_modules[backend]), "toeplitz")
    setattr(wcosmo.utils, "toeplitz", toeplitz)

ndata = np.random.uniform(1, 10, 1000000)
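`set_backend` works because, as its use of `setattr` implies, the functions in `wcosmo.wcosmo` and `wcosmo.utils` look up the module-level name `xp` each time they are called. The toy module here is a hypothetical stand-in, not `wcosmo` itself. It illustrates that Python mechanism using only `math` and `numpy`.

```python
import types

import numpy as np

# Hypothetical stand-in for a backend-agnostic module such as `wcosmo.wcosmo`:
# its function refers to the array library only through the global name `xp`.
toy = types.ModuleType("toy")
exec(
    "import math as xp\n"
    "\n"
    "def root(x):\n"
    "    return xp.sqrt(x)\n",
    toy.__dict__,
)

scalar = toy.root(4.0)  # resolved via math.sqrt, works on Python floats only

# Rebind the module global, as `set_backend` does with `setattr`.
setattr(toy, "xp", np)
vector = toy.root(np.array([4.0, 9.0]))  # now resolved via np.sqrt

print(scalar, vector)
```

A corollary is that code which copied the name with something like `from toy import xp` keeps its own reference. That code would therefore keep using the old library after the swap.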

### wcosmo + jax + GPU

In [None]:
import jax.numpy as jnp
import numpy as np
from jax import jit


set_backend("jax")

jdata = jnp.array(ndata)


@jit
def time_jax_redshift(jdata):
    return wcosmo.z_at_value(wcosmo.FlatwCDM(67, 0.3, -1).luminosity_distance, jdata)


@jit
def time_jax_dvcdz(jdata):
    return wcosmo.FlatwCDM(67, 0.3, -1).differential_comoving_volume(jdata)


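# Warm-up calls: trigger JIT compilation and wait for the results.
burn_vals = time_jax_redshift(jdata).block_until_ready()
burn_vals = time_jax_dvcdz(jdata).block_until_ready()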

In [None]:
%%time

_ = time_jax_redshift(jdata).block_until_ready()

CPU times: user 732 µs, sys: 222 µs, total: 954 µs
Wall time: 6.78 ms


In [None]:
%%time

_ = time_jax_dvcdz(jdata).block_until_ready()

CPU times: user 0 ns, sys: 627 µs, total: 627 µs
Wall time: 638 µs


### astropy + CPU

Note that this is very slow, so we only use one percent of the data (10,000 of the 1,000,000 points).
Since this is `numpy`-based, the time scales linearly with the amount of data, so the times below should be multiplied by 100 for a like-for-like comparison.

In practice, most people using `astropy` evaluate `z_at_value` at many points via interpolation, as is done in `wcosmo`.
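The interpolation approach can be sketched as follows. We tabulate $d_L(z)$ once on a redshift grid and invert it for all points with a single `np.interp` call. This is not the `wcosmo` implementation. The function names, the grid range (`z_max=10`), and the resolution are hypothetical choices. The sketch assumes flat wCDM without radiation, $H_0$ in km/s/Mpc, and distances in Mpc.

```python
import numpy as np

SPEED_OF_LIGHT_KM_S = 299_792.458


def luminosity_distance_grid(H0, Om0, w0, z_max=10.0, n_grid=100_000):
    """Hypothetical helper: d_L (Mpc) on a redshift grid for flat wCDM, no radiation."""
    z = np.linspace(0.0, z_max, n_grid)
    ez = np.sqrt(Om0 * (1 + z) ** 3 + (1 - Om0) * (1 + z) ** (3 * (1 + w0)))
    # Cumulative trapezoid rule for the comoving-distance integral of 1 / E(z).
    steps = 0.5 * (1 / ez[1:] + 1 / ez[:-1]) * np.diff(z)
    comoving = SPEED_OF_LIGHT_KM_S / H0 * np.concatenate([[0.0], np.cumsum(steps)])
    return z, (1 + z) * comoving


def z_at_luminosity_distance(dl, H0=67.0, Om0=0.3, w0=-1.0):
    """Hypothetical helper: invert d_L(z) by linear interpolation on the grid."""
    z_grid, dl_grid = luminosity_distance_grid(H0, Om0, w0)
    # d_L increases monotonically with z, so np.interp can invert it directly.
    # Inputs beyond the grid are clamped to z = 0 or z = z_max.
    return np.interp(dl, dl_grid, z_grid)
```

The cost is one grid evaluation plus one vectorized interpolation, however many points are requested. For example, `z_at_luminosity_distance(np.random.uniform(1, 10, 1_000_000))` handles a million points in a single call. A real implementation would cache the grid rather than rebuild it on every call.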

In [None]:
from astropy import cosmology, units

In [None]:
%%time

_ = cosmology.z_at_value(
    cosmology.FlatwCDM(67, 0.3, -1).luminosity_distance,
    ndata[:10000] * units.Mpc,
).value

CPU times: user 35.8 s, sys: 130 ms, total: 36 s
Wall time: 43.6 s


In [None]:
%%time

_ = cosmology.FlatwCDM(67, 0.3, -1).differential_comoving_volume(
    ndata[:10000],
).value

CPU times: user 177 ms, sys: 942 µs, total: 178 ms
Wall time: 181 ms


### wcosmo + numpy + CPU

In [None]:
set_backend("numpy")

In [None]:
%%time

_ = wcosmo.z_at_value(
    wcosmo.FlatwCDM(67, 0.3, -1).luminosity_distance, ndata
)

CPU times: user 75.3 ms, sys: 2.95 ms, total: 78.2 ms
Wall time: 92.2 ms


In [None]:
%%time

_ = wcosmo.FlatwCDM(67, 0.3, -1).differential_comoving_volume(ndata)

CPU times: user 109 ms, sys: 18.9 ms, total: 128 ms
Wall time: 130 ms


### wcosmo + cupy + GPU

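The final test uses the `cupy` backend on the GPU.
Typically this is much faster than `numpy` but slower than the `JAX` GPU code. In the results below, the gain over `numpy` is large for `z_at_value` but small for the differential comoving volume.
One cost not tested here is transferring data between the CPU and GPU, which can be significant for `cupy`.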

In [None]:
import cupy

set_backend("cupy")

cdata = cupy.asarray(ndata)

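# Warm-up calls: compile the underlying CUDA code before timing.
_ = wcosmo.z_at_value(
    wcosmo.FlatwCDM(67, 0.3, -1).luminosity_distance, cdata
)
_ = wcosmo.FlatwCDM(67, 0.3, -1).differential_comoving_volume(cdata)
cupy.cuda.stream.get_current_stream().synchronize()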

In [None]:
%%time

_ = wcosmo.z_at_value(
    wcosmo.FlatwCDM(67, 0.3, -1).luminosity_distance, cdata
)
cupy.cuda.stream.get_current_stream().synchronize()

CPU times: user 8.13 ms, sys: 0 ns, total: 8.13 ms
Wall time: 7.54 ms


In [None]:
%%time

_ = wcosmo.FlatwCDM(67, 0.3, -1).differential_comoving_volume(cdata)
cupy.cuda.stream.get_current_stream().synchronize()

CPU times: user 109 ms, sys: 26 µs, total: 109 ms
Wall time: 110 ms
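The single-run wall times reported above can be put on a common footing. We scale the `astropy` numbers by 100, using the linear-scaling assumption stated in the `astropy` section, and take ratios. This is only a rough back-of-the-envelope comparison, because each number comes from a single `%%time` run.

```python
# Single-run wall times (seconds) copied from the %%time outputs above.
wall_times = {
    "wcosmo + jax + GPU": {"z_at_value": 6.78e-3, "dvc_dz": 638e-6},
    "astropy + CPU": {"z_at_value": 43.6, "dvc_dz": 181e-3},
    "wcosmo + numpy + CPU": {"z_at_value": 92.2e-3, "dvc_dz": 130e-3},
    "wcosmo + cupy + GPU": {"z_at_value": 7.54e-3, "dvc_dz": 110e-3},
}

# astropy only processed 10,000 of the 1,000,000 points; assume linear scaling.
ASTROPY_FRACTION = 10_000 / 1_000_000
full_times = {
    name: {
        task: t / ASTROPY_FRACTION if name.startswith("astropy") else t
        for task, t in times.items()
    }
    for name, times in wall_times.items()
}

# Speed-up of each configuration relative to the extrapolated astropy times.
reference = full_times["astropy + CPU"]
speedup = {
    name: {task: reference[task] / t for task, t in times.items()}
    for name, times in full_times.items()
}

for name, ratios in speedup.items():
    print(f"{name:22s} " + "  ".join(f"{task}: {r:10.1f}x" for task, r in ratios.items()))
```

On these numbers, the extrapolated `astropy` `z_at_value` call would take about 4,360 s (roughly 73 minutes) for the full million points.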
