In [1]:
import jax.numpy as np
import jax.lax as jl
import jax
import re
import difflib

In [2]:
def get_mhlo(
        func: callable, 
        *args: object, 
        **kwargs: object
    ) -> str:
    mhlo: str = func.lower(*args, **kwargs).compile().as_text()
    return re.sub("metadata={.*}", "", mhlo)

In [3]:
def print_diff(mhlo: str, comp: str) -> list:
    mhlo_lines: list = mhlo.splitlines()
    comp_lines: list = comp.splitlines()
    diff: iter = difflib.unified_diff(mhlo_lines, comp_lines)
    for line in diff:
        print(line)

In [4]:
def mesh(npix: int) -> float:
    centre: float = (npix - 1.0) / 2.0
    shape: tuple =  (1, npix, npix) 
    x: float = jl.broadcasted_iota(float, shape, 1)
    y: float = jl.broadcasted_iota(float, shape, 2)
    return jl.concatenate([x, y], 0) - centre

In [5]:
def hypotenuse(coordinates: float) -> float:
    pythags: float = jl.integer_pow(coordinates, 2)
    return jl.reduce(pythags, 0., jl.add, [0])

In [26]:
radius: float = 1.
npix: int = 1024
nsoft: int = 3
x: float = 0.
y: float = 0.
rotation: float = 0.
pixel_scale: float = 2. * radius / npix

In [27]:
def circ_ap_func(
        radius: float, 
        x: float, 
        y: float,
        rotation: float, 
        nsoft: float,
        pixel_scale: float) -> float:
    # Passing arguments to safe types. 
    centre: float = np.asarray([x, y]).astype(float)
    radius: float = np.asarray(radius).astype(float)
    rotation: float = np.asarray(rotation).astype(float)
    nsoft: float = np.asarray(nsoft).astype(float)
    
    # Organising coords
    ccoords: float = mesh(npix) / npix * 2.0 * radius
    
    # Translation 
    ccoords: float = ccoords - centre[:, None, None]
        
    # Rotation
    sin_alpha: float = jl.sin(rotation)
    cos_alpha: float = jl.cos(rotation)
    x: float = jl.index_in_dim(ccoords, 0)
    y: float = jl.index_in_dim(ccoords, 1)
    new_x: float = x * cos_alpha - y * sin_alpha
    new_y: float = x * sin_alpha + y * cos_alpha
    ccoords: float = jl.concatenate([new_x, new_y], 0)        
        
    # Transformation 
    rho: float = hypotenuse(ccoords)
        
    # Linear softening
    distances: float = radius - rho
    lower: float = jl.full_like(distances, 0., dtype=float)
    upper: float = jl.full_like(distances, 1., dtype=float)
    inside: float = jl.max(distances, lower)
    scaled: float = inside / nsoft / pixel_scale
    aperture: float = jl.min(scaled, upper)
    return aperture

In [28]:
dynamic: callable = jax.jit(circ_ap_func, inline=True)
static: callable = jax.jit(circ_ap_func, inline=True, static_argnums=(1, 2, 3, 4, 5))

In [33]:
%%timeit
dynamic(radius, x, y, rotation, nsoft, pixel_scale).block_until_ready()

4.4 ms ± 97.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [34]:
%%timeit
static(radius, x, y, rotation, nsoft, pixel_scale).block_until_ready()

4.42 ms ± 174 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [13]:
dynamic_mhlo: str = get_mhlo(static, radius, x, y, rotation, nsoft, pixel_scale)
static_mhlo: str = get_mhlo(dynamic, radius, x, y, rotation, nsoft, pixel_scale)

In [14]:
print_diff(dynamic_mhlo, static_mhlo)

--- 

+++ 

@@ -1,83 +1,112 @@

-HloModule jit_circ_ap_func, entry_computation_layout={(f32[])->f32[1024,1024]{1,0}}, allow_spmd_sharding_propagation_to_output=true
+HloModule jit_circ_ap_func, entry_computation_layout={(f32[],f32[],f32[],f32[],s32[],f32[])->f32[1024,1024]{1,0}}, allow_spmd_sharding_propagation_to_output=true
 
-%region_0.40 (Arg_0.41: f32[], Arg_1.42: f32[]) -> f32[] {
-  %Arg_0.41 = f32[] parameter(0)
-  %Arg_1.42 = f32[] parameter(1)
-  ROOT %add.43 = f32[] add(f32[] %Arg_0.41, f32[] %Arg_1.42), 
+%fused_computation.2 (param_0.10: f32[], param_1.18: f32[]) -> f32[2] {
+  %param_1.18 = f32[] parameter(1)
+  %reshape.6 = f32[1]{0} reshape(f32[] %param_1.18), 
+  %param_0.10 = f32[] parameter(0)
+  %reshape.5 = f32[1]{0} reshape(f32[] %param_0.10), 
+  ROOT %concatenate.1 = f32[2]{0} concatenate(f32[1]{0} %reshape.6, f32[1]{0} %reshape.5), dimensions={0}, 
 }
 
-%parallel_reduce.44 (p: f32[2,1024,1024], p.1: f32[]) -> f32[1024,1024] {
+%region_0.53 (Arg_0.54: f32[], Ar