Feat: Two ravel implementations of naive all-sky imaging #43
|
Consider renaming |
|
How did you measure runtime and energy? |
|
Finally, as someone unfamiliar with Jax (other than what you told me about it), I don't know how to run the code using a GPU. This is a great opportunity to (briefly) describe how it works on the wiki. |
'ravel' is also too generic; it is the process of mapping a multi-dimensional array and then flattening it into a vector. See: https://numpy.org/doc/stable/reference/generated/numpy.ravel.html The process of ravel in NumPy (and Jax) is very interesting because you can use the result of
img = np.zeros((X, Y,))
data = ....
img[data.ravel()] = data
Perhaps we can call it |
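To illustrate what ravel does, here is a small sketch using NumPy only. The array shape, the pixel coordinates and the name `flat_idx` are made up for this example and are not taken from the PR.

```python
import numpy as np

# A 2x3 image; for a contiguous array, ravel() returns a flattened 1-D view
# in row-major order.
img = np.zeros((2, 3))
flat = img.ravel()  # shape (6,)

# Hypothetical (row, col) pixel coordinates converted to flat indices.
rows = np.array([0, 1, 1])
cols = np.array([2, 0, 2])
flat_idx = np.ravel_multi_index((rows, cols), img.shape)

# Writing through the flattened view updates the original 2-D array.
flat[flat_idx] = [10.0, 20.0, 30.0]
```

Because the flat index is a single integer per pixel, scattered writes into an image become one 1-D indexing operation, which is presumably why the term came up for these kernels.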
Measuring sub-parts of the code is not possible with Jax and timeit: those instrumentation calls are not pure and will be optimized out when the function is jitted. You can only measure the wall time of the jitted function. Instead, you can use the Jax tracers to analyze performance at a sub-function level. |
Benchmarking is done using unit tests in the lofty project, see: https://git.astron.nl/bassa/lofty/-/blob/main/tests/test_xst_imaging.py?ref_type=heads#L55 The measurement class is in: https://gitlab.dantalion.nl/astron/loafty/-/blob/etrs-to-baselines/tests/measurements/measure.py#L61 Runtime energy is measured using PMT and RAPL; note that the lofty repo does not have the PMT integration, only my fork does. |
I'd rather not regurgitate existing documentation of external projects, as (A) it has a tendency to go out of date and (B) it has already been written down by someone else much more familiar with the topic than I am. But I would be okay with creating a Wiki page with several references and pointers on the subject that link to the official documentation. |
|
Let me reply to the individual topics in one go:
|
Lofty IDG?
I don't fully understand this question, could you elaborate? "Should we Lofty for the IDG Python code too?"
Timing
All the timing and measuring problems are indeed easily overcome, we can just use
Some Jax gotchas we need to deal with: Jax asynchronous dispatch calls can use the
The result is that for generalized measurements you will need some wrapping.
if hasattr(obj, 'block_until_ready'):
    obj.block_until_ready()
Now suppose the CPU is really really slow, and both the dispatch and kernel have long completed before the
Until now, I have not dealt with this but I think we should, so I propose that for measuring Jax kernels we create some logic based on the chosen Jax runtime (CPU, GPU, NPU, etc.) where we create two methods and dynamically switch between the one that is measured. That way one method can always call
Wiki
Jax wiki page: https://github.com/astron-rd/PACE/wiki/Jax |
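The wrapping mentioned under Timing could be generalized along these lines. This is only a sketch: `wait_until_ready` and `FakeAsyncArray` are hypothetical names introduced here, and the fake class merely stands in for an asynchronously dispatched array so the example runs without Jax installed.

```python
def wait_until_ready(obj):
    """Block on obj, or on every item of a list/tuple/dict, if supported."""
    if isinstance(obj, (list, tuple)):
        for item in obj:
            wait_until_ready(item)
    elif isinstance(obj, dict):
        for item in obj.values():
            wait_until_ready(item)
    elif hasattr(obj, "block_until_ready"):
        # Objects without this method (plain floats, NumPy arrays) are
        # already complete, so nothing needs to happen for them.
        obj.block_until_ready()
    return obj


class FakeAsyncArray:
    """Stand-in for an asynchronously dispatched array (illustration only)."""

    def __init__(self):
        self.ready = False

    def block_until_ready(self):
        self.ready = True
        return self


a, b = FakeAsyncArray(), FakeAsyncArray()
wait_until_ready((a, {"x": b}, 3.0))
```

Recent Jax versions also provide `jax.block_until_ready` for nested results, which may make a hand-written helper like this unnecessary when only Jax outputs are measured.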
|
Could we do a warm-up run first and take a timing measurement for a second run? How about using pytest-benchmark? Extending https://github.com/astron-rd/PACE/blob/main/idg/python/main.py#L185, something like this could work: Or even fancier, add a decorator? |
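A warm-up run followed by timed runs, wrapped in a decorator, might look roughly like the sketch below. The decorator name `benchmark`, its parameters and the toy function `sum_of_squares` are assumptions made for illustration; none of them come from the PACE or lofty code, and this is not a substitute for pytest-benchmark.

```python
import functools
import time


def benchmark(warmup=1, repeats=5):
    """Decorator: run warm-up calls, then time `repeats` calls (seconds)."""

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            def call():
                result = func(*args, **kwargs)
                # Wait for asynchronously dispatched results, if applicable.
                if hasattr(result, "block_until_ready"):
                    result.block_until_ready()
                return result

            # Warm-up calls absorb one-off costs such as JIT compilation.
            for _ in range(warmup):
                call()

            timings = []
            for _ in range(repeats):
                start = time.perf_counter()
                result = call()
                timings.append(time.perf_counter() - start)
            return result, timings

        return wrapper

    return decorator


@benchmark(warmup=2, repeats=3)
def sum_of_squares(n):
    return sum(i * i for i in range(n))


result, timings = sum_of_squares(1000)
```

Returning the raw list of timings rather than a single mean leaves the choice of statistic (minimum, median, spread) to the caller, which helps when the environment is noisy.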
mickveldhuis
left a comment
Cool stuff! I mainly left some comments on improving the readability 😁 It's always a little tricky when you decompose a problem further and further to optimise.
I agree we should, but I think implementing the benchmarking code is beyond the scope of this merge. Ideally we create some classes and wrappers that can be used from both IDG and all-sky in one go. |
|
Hello @csbnw and @mickveldhuis, I have addressed your review comments, hoping we can round this off. Note, I also added a cheeky extra config option to the Jax config that makes it even faster on CPU. |
mickveldhuis
left a comment
@Dantali0n I'm happy to approve this PR.
This adds two implementations of the most basic all-sky imaging algorithm (mapping an image plane onto a celestial sphere). The algorithm uses antenna-pair baselines (UVW) combined with correlated visibilities (sometimes called XSTs in the context of LOFAR) to produce a 2D image.
The algorithm is closely related to a 2D FFT and is often parallelizable with good scaling and speedup.
In this basic implementation we add two versions, one using the popular library Numba and the other using the less well-known library Jax.
In earlier trials we observed that Jax typically achieves orders of magnitude better performance than Numba, hence our reasoning for adding these.
The goal is to also add Jax implementations for the IDG Python kernels.
I will fix the packaging of this sub-project and add the linting to pre-commit once #42 is merged.
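For readers unfamiliar with the algorithm described above, a direct Fourier sum over baselines can be sketched in plain NumPy as follows. This is not the Numba or Jax code from this PR: the function name `naive_image`, the grid layout, the sign of the phase and the choice to ignore the w coordinate are all assumptions made to keep the example short.

```python
import numpy as np


def naive_image(vis, u, v, npix=32):
    """Direct Fourier sum of visibilities onto an (l, m) direction-cosine grid.

    vis: complex visibilities, one per baseline
    u, v: baseline coordinates in wavelengths (w is ignored in this sketch)
    """
    l = np.linspace(-1.0, 1.0, npix)
    m = np.linspace(-1.0, 1.0, npix)
    ll, mm = np.meshgrid(l, m, indexing="ij")  # each (npix, npix)

    # Phase per (pixel, baseline): 2*pi*(u*l + v*m); the sign is a convention.
    phase = 2.0 * np.pi * (ll[..., None] * u + mm[..., None] * v)
    image = np.real(np.sum(vis * np.exp(1j * phase), axis=-1)) / vis.size

    # Pixels outside the unit circle lie below the horizon.
    image[ll**2 + mm**2 > 1.0] = np.nan
    return image


# Synthetic check: a unit point source at zenith gives vis == 1 on every baseline.
rng = np.random.default_rng(0)
u = rng.uniform(-5.0, 5.0, 50)
v = rng.uniform(-5.0, 5.0, 50)
image = naive_image(np.ones(50, dtype=complex), u, v, npix=33)
```

Every pixel is an independent sum over baselines, which is what makes the algorithm straightforward to parallelize.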
Below is a table with some rudimentary benchmarks; the environment wasn't noise-free, so take them with a grain of salt.