Feat: Two ravel implementations of naive all-sky imaging #43

Merged
Dantali0n merged 6 commits into main from imaging-jax-numba on Mar 16, 2026

@Dantali0n
Contributor

@Dantali0n Dantali0n commented Feb 5, 2026

This adds two implementations of the most basic all-sky imaging algorithm (mapping an image plane onto a celestial sphere). The algorithm uses antenna pair baselines (UVW) combined with correlated visibilities (sometimes called XSTs in the context of LOFAR) to produce a 2D image.

The algorithm is closely related to a 2D FFT and is often parallelizable with good scaling and speedup.
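For readers new to the topic, here is a minimal NumPy sketch of a direct Fourier sum over baselines. It is not the code in this PR: the function name `naive_all_sky_image`, its parameters, the pixel grid and the sign convention of the phase are all assumptions made for illustration.

```python
import numpy as np


def naive_all_sky_image(visibilities, uvw, npix, wavelength):
    """Direct Fourier sum over all baselines for every (l, m) pixel (illustrative)."""
    # Direction cosines covering the full sky disc.
    l = np.linspace(-1.0, 1.0, npix)
    m = np.linspace(-1.0, 1.0, npix)
    ll, mm = np.meshgrid(l, m, indexing="ij")
    r2 = ll**2 + mm**2
    visible = r2 <= 1.0  # pixels outside the unit circle are below the horizon
    nn = np.sqrt(np.where(visible, 1.0 - r2, 0.0))

    # Baseline coordinates expressed in wavelengths.
    u, v, w = (uvw / wavelength).T
    # phase[b, i, j] = 2*pi * (u_b * l_ij + v_b * m_ij + w_b * (n_ij - 1))
    phase = 2.0 * np.pi * (
        u[:, None, None] * ll + v[:, None, None] * mm + w[:, None, None] * (nn - 1.0)
    )
    # Sum visibilities times phase factors over the baseline axis, then average.
    img = np.real(np.tensordot(visibilities, np.exp(1j * phase), axes=(0, 0)))
    img /= len(visibilities)
    return np.where(visible, img, np.nan)


# Made-up input: 32 random baselines observing a unit point source at zenith.
rng = np.random.default_rng(0)
uvw = rng.uniform(-50.0, 50.0, size=(32, 3))
vis = np.ones(32, dtype=complex)
image = naive_all_sky_image(vis, uvw, npix=5, wavelength=5.0)
```

Every pixel is an independent sum over baselines, which is why this kind of imager tends to parallelize well.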

In this basic implementation we add two versions, one using the popular library Numba and the other using the lesser-known library Jax.

In earlier trials we observed that Jax typically achieves orders of magnitude better performance than Numba, hence our reasoning for adding these.

The goal is to also add Jax implementations for the IDG Python kernels.

I will fix the packaging of this sub-project and add the linting to pre-commit once #42 is merged.

Below is a table with some rudimentary benchmarks. The environment wasn't noise-free, so take them with a grain of salt.

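| Device   | Algorithm        | JIT | X   | Y   | Min    | Max    | Mean   | std dev | Joules / It |
|----------|------------------|-----|-----|-----|--------|--------|--------|---------|-------------|
| 155H     | Numpy Ravel Real | N.A | 256 | 256 | 6.29   | 6.51   | 6.39   | 0.07    | 187.18      |
| 155H     | Numba Ravel Real | Yes | 256 | 256 | 2.65   | 2.83   | 2.71   | 0.04    | 149.49      |
| 155H     | Jax Ravel Real   | Yes | 256 | 256 | 0.51   | 0.56   | 0.53   | 0.01    | 22.48       |
| 1050 TI  | Jax Ravel Real   | Yes | 256 | 256 | 0.058  | 0.0788 | 0.073  | 0.006   |             |
| 7900 XTX | Jax Ravel Real   | Yes | 256 | 256 | 0.0059 | 0.0333 | 0.0073 | 0.0049  | 1.38        |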

@Dantali0n Dantali0n self-assigned this Feb 5, 2026
Comment thread imaging/python/pyproject.toml Outdated
@csbnw
Contributor

csbnw commented Feb 6, 2026

Consider renaming imaging (which may be a bit too generic) to ravel?

@csbnw
Contributor

csbnw commented Feb 6, 2026

How did you measure runtime and energy?
I would suggest adding something like timeit that we have for the IDG Python code. If you have another way to measure the runtime (preferably of individual sub-parts of the code), that is also fine by me. As long as we have a way to upload the results to Bencher later, see also here.

@csbnw
Contributor

csbnw commented Feb 6, 2026

Finally, as someone unfamiliar with Jax (other than what you told me about it), I don't know how to run the code using a GPU. This is a great opportunity to (briefly) describe how it works on the wiki.

@Dantali0n
Contributor Author

Dantali0n commented Feb 6, 2026

Consider renaming imaging (which may be a bit too generic) to ravel?

'ravel' is also too generic; it is the process of taking a multi-dimensional array and flattening it into a vector. See: https://numpy.org/doc/stable/reference/generated/numpy.ravel.html

The process of ravel in Numpy (and Jax) is very interesting because you can use the result of ravel() as indices for assignment like so:

img = np.zeros((X, Y,))
data = ....
img[data.ravel()] = data
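One way to read the snippet above is as assignment through flat (raveled) indices. The following is a small runnable sketch of that idea, not the code in this PR: the array names and values are made up, and `np.ravel_multi_index` together with `ndarray.flat` is just one possible way to express it in NumPy.

```python
import numpy as np

X, Y = 4, 3
img = np.zeros((X, Y))

# Hypothetical pixel coordinates and the values that belong at them.
rows = np.array([0, 1, 3])
cols = np.array([2, 0, 1])
values = np.array([1.0, 2.0, 3.0])

# Convert (row, col) pairs into indices of the flattened image.
flat_idx = np.ravel_multi_index((rows, cols), img.shape)

# Assign through the flat view; this writes into img itself.
img.flat[flat_idx] = values
```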

Perhaps we can call it all-sky or something like that.

@Dantali0n
Contributor Author

How did you measure runtime and energy? I would suggest adding something like timeit that we have for the IDG Python code. If you have another way to measure the runtime (preferably of individual sub-parts of the code), that is also fine by me. As long as we have a way to upload the results to Bencher later, see also here.

Measuring sub-parts of the code is not possible with Jax and timeit: such instrumentation is not pure and will be optimized out when the function is jitted. You can only measure the wall time of the jitted function as a whole. Instead, you can use the Jax tracers to analyze performance at a sub-function level.

See: https://docs.jax.dev/en/latest/tracing.html
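As a small illustration of why impure instrumentation does not survive `jax.jit`: Python side effects inside a jitted function run while the function is traced, not on every call. The sketch below uses a hypothetical function `scaled_sum` and a list append as a stand-in for a timing call.

```python
import jax
import jax.numpy as jnp

trace_log = []  # stand-in for an impure timing or logging call


@jax.jit
def scaled_sum(x):
    # This side effect runs only while JAX traces the function.
    trace_log.append("traced")
    return jnp.sum(x * 2.0)


x = jnp.arange(4.0)
first = scaled_sum(x)   # traces and compiles, so the append runs
second = scaled_sum(x)  # reuses the compiled function, no append
```

After both calls `trace_log` holds a single entry, so a timer placed inside the function would not observe the second run at all.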

@Dantali0n
Contributor Author

Dantali0n commented Feb 6, 2026

How did you measure runtime and energy? I would suggest adding something like timeit that we have for the IDG Python code. If you have another way to measure the runtime (preferably of individual sub-parts of the code), that is also fine by me. As long as we have a way to upload the results to Bencher later, see also here.

Benchmarking is done using unit tests in the lofty project, see: https://git.astron.nl/bassa/lofty/-/blob/main/tests/test_xst_imaging.py?ref_type=heads#L55

And the measurement class is in: https://gitlab.dantalion.nl/astron/loafty/-/blob/etrs-to-baselines/tests/measurements/measure.py#L61

Runtime energy is measured using PMT and RAPL. Note that the lofty repo does not have the PMT integration; only my fork does.

@Dantali0n
Contributor Author

Finally, as someone unfamiliar with Jax (other than what you told me about it), I don't know how to run the code using a GPU. This is a great opportunity to (briefly) describe how it works on the wiki.

I'd rather not regurgitate existing documentation of external projects, as (A) it has a tendency to go out of date and (B) it has already been written down by someone much more familiar with the topic than I am. But I would be okay with creating a wiki page with several references and pointers on the subject that link to the official documentation.

@Dantali0n Dantali0n requested a review from csbnw February 6, 2026 14:03
@csbnw
Contributor

csbnw commented Feb 13, 2026

Let me reply to the individual topics in one go:

  • I like the name 'all-sky'.
  • In the IDG Python code, we also measure the runtime of the jitted function, which seems to work just fine.
  • It doesn't really matter what timing method you use, as long as we can export to JSON and eventually upload to Bencher.
  • I now realise that this imager doesn't really have sub-parts to measure separately, so just having the timing of sky_imager_*_ravel is ok.
  • Should we Lofty for the IDG Python code too?
  • I agree with you that regurgitation is undesirable, but having a (concise) wiki page with some pointers is still very helpful.
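To make the JSON export point concrete, here is a minimal sketch using only the standard library. The nested layout (benchmark name, then a `latency` measure with `value`, `lower_value` and `upper_value`) is modelled loosely on Bencher's JSON metric format and should be treated as an assumption to verify against the Bencher documentation, including the expected units. The helper name `timings_to_json`, the benchmark name and the sample values are hypothetical.

```python
import json
import statistics


def timings_to_json(name: str, samples: list[float]) -> str:
    """Summarise repeated timings of one benchmark as a JSON string."""
    record = {
        name: {
            "latency": {
                "value": statistics.mean(samples),
                "lower_value": min(samples),
                "upper_value": max(samples),
            }
        }
    }
    return json.dumps(record, indent=2)


# Made-up samples in seconds; not measurements from this PR.
report = timings_to_json("sky_imager_jax_ravel", [0.5, 0.6, 0.7])
parsed = json.loads(report)
```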

@Dantali0n
Contributor Author

Dantali0n commented Feb 16, 2026

Lofty IDG?

I don't fully understand this question, could you elaborate? "Should we Lofty for the IDG Python code too?"

Timing

All the timing and measuring problems are indeed easily overcome; we can just use timeit.
There is just one small detail to think about, which I will spell out here.

A Jax gotcha we need to deal with:
Jax uses asynchronous dispatch when running on accelerated devices like GPU / NPU. The compute call returns immediately, and the result is lazily loaded once it becomes available. This means the return of the call by itself can't be used for timing measurements.

Results of asynchronously dispatched Jax calls have a block_until_ready() method to await completion. However, results of calls that are not asynchronously dispatched do not have this method.

The result is that for generalized measurements you will need some wrapping:

if hasattr(obj, 'block_until_ready'):
    obj.block_until_ready()

Now suppose the CPU is really, really slow and both the dispatch and the kernel have long completed before the if condition is evaluated. The results of timeit will then also absorb the runtime cost of the very slow CPU.

Until now I have not dealt with this, but I think we should. I propose that for measuring Jax kernels we create some logic based on the chosen Jax runtime (CPU, GPU, NPU, etc.), where we define two methods and dynamically select which one is measured. That way one method can always call block_until_ready and the other never does, without the cost of the if branch in either definition.
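A minimal sketch of that proposal, using only the standard library. All names (`measure_blocking`, `measure_plain`, `select_measure`, `FakeAsyncResult`) are hypothetical, and the `async_dispatch` flag is assumed to be decided once by the caller, for example from the configured Jax backend.

```python
import time


def measure_blocking(kernel, *args):
    """Time a kernel whose result must be awaited with block_until_ready()."""
    start = time.perf_counter()
    result = kernel(*args)
    result.block_until_ready()
    return result, time.perf_counter() - start


def measure_plain(kernel, *args):
    """Time a kernel that returns only once its result is complete."""
    start = time.perf_counter()
    result = kernel(*args)
    return result, time.perf_counter() - start


def select_measure(async_dispatch: bool):
    # The branch is taken once here, outside of any timed region.
    return measure_blocking if async_dispatch else measure_plain


class FakeAsyncResult:
    """Stand-in for a device array, for illustration only."""

    def __init__(self, value):
        self.value = value
        self.waited = False

    def block_until_ready(self):
        self.waited = True
        return self


measure = select_measure(async_dispatch=True)
result, seconds = measure(FakeAsyncResult, 42)
```

Because the choice is made before timing starts, neither measurement function contains the `hasattr` check inside its timed section.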

Wiki

Jax wiki page: https://github.com/astron-rd/PACE/wiki/Jax

@csbnw
Contributor

csbnw commented Feb 26, 2026

Could we do a warm-up run first and take a timing measurement for a second run?

How about using pytest-benchmark?

Extending https://github.com/astron-rd/PACE/blob/main/idg/python/main.py#L185, something like this could work:

import time
from typing import Callable

import jax

# Module-level store for results; assumed to exist alongside the timer.
_timings: dict[str, float] = {}


class JAXTimer:
    def __init__(self, description: str, warmup: int = 1):
        self.description = description
        self.warmup = warmup
        self.duration = None

    def __enter__(self):
        return self

    def __exit__(self, *args):
        pass

    def run(self, operation: Callable, *args, **kwargs):
        # Warm-up runs trigger JIT compilation so it is excluded from the timing.
        for _ in range(self.warmup):
            result = operation(*args, **kwargs)
            jax.block_until_ready(result)

        # Timed run: wait for the asynchronously dispatched result to be ready.
        start = time.perf_counter()
        result = operation(*args, **kwargs)
        jax.block_until_ready(result)
        self.duration = time.perf_counter() - start

        print(f"{self.description:<38} {self.duration:>9.6f} s")
        _timings[self.description] = self.duration
        return result

    def __call__(self, operation: Callable):
        return self.run(operation)

Or even fancier, add a decorator?
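A decorator variant could look roughly like the sketch below. It uses only the standard library so it runs without Jax; the names `timed`, `timings` and `sum_of_squares` are hypothetical. The `sync` argument defaults to a no-op, and for Jax kernels one would presumably pass `jax.block_until_ready`.

```python
import functools
import time
from typing import Any, Callable

timings: dict[str, float] = {}  # hypothetical results store


def timed(description: str, warmup: int = 1,
          sync: Callable[[Any], Any] = lambda result: result):
    """Decorator: run warm-up calls, then time one call and record its duration."""

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Warm-up calls are executed but not timed.
            for _ in range(warmup):
                sync(func(*args, **kwargs))

            # Timed call; sync waits for the result to be complete.
            start = time.perf_counter()
            result = sync(func(*args, **kwargs))
            timings[description] = time.perf_counter() - start
            return result

        return wrapper

    return decorator


calls = []


@timed("sum of squares", warmup=2)
def sum_of_squares(n):
    calls.append(n)
    return sum(i * i for i in range(n))


total = sum_of_squares(10)
```

Note that every call of the decorated function repeats the warm-up, so this shape fits benchmark scripts better than production code paths.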

Contributor

@mickveldhuis mickveldhuis left a comment

Cool stuff! I mainly left some comments to improve the readability 😁 It's always a little tricky when you decompose a problem further and further to optimise.

Comment thread all-sky/python/all_sky_jax.py Outdated
Comment thread all-sky/python/all_sky_jax.py
Comment thread all-sky/python/all_sky_numba.py Outdated
Comment thread all-sky/python/all_sky_numba.py
Comment thread all-sky/python/all_sky_numba.py Outdated
@Dantali0n
Contributor Author

Could we do a warm-up run first and take a timing measurement for a second run?

I agree we should, but I think implementing the benchmarking code is beyond the scope of this merge. Ideally we create some classes and wrappers that can be used from both IDG and all-sky in one go.

@Dantali0n
Contributor Author

Hello @csbnw and @mickveldhuis, I have addressed your review comments and hope we can round this off. Note that I also added a cheeky extra option to the Jax config that makes it even faster on CPU.

Contributor

@mickveldhuis mickveldhuis left a comment

@Dantali0n I'm happy to approve this PR.

@Dantali0n Dantali0n merged commit a4a849d into main Mar 16, 2026
@Dantali0n Dantali0n deleted the imaging-jax-numba branch March 16, 2026 08:54
@bhaaksema bhaaksema added this to the M2 milestone May 7, 2026
@bhaaksema bhaaksema linked an issue May 7, 2026 that may be closed by this pull request

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

D2.2 Initial baseline benchmarks

4 participants