In [104]:
import jax
import jax.numpy as jnp
from jax import jit, vmap, grad
import jax.scipy as jsp

import jaxlib.xla_extension as xla_ext

Example Ref - https://github.com/google/jax/discussions/7068
```python
# You could use the presets
option = xla_ext.HloPrintOptions.short_parsable()
print(module.to_string(option))

option = xla_ext.HloPrintOptions.canonical()
print(module.to_string(option))

option = xla_ext.HloPrintOptions.fingerprint()
print(module.to_string(option))

# Or set each option manually
option = xla_ext.HloPrintOptions()
option.print_metadata = False
option.include_layout_in_shapes = False
print(module.to_string(option))
```

In [8]:
def f(x):
  return jnp.sin(x)

c = jax.xla_computation(f)(1.)

backend = jax.lib.xla_bridge.get_backend()
computy = backend.compile(c)

In [20]:
def b(a, b, c):
    return a * b * c

In [32]:
def matty(a, b):
  return jnp.matmul(a,b) + 1 - 1 + 1

In [97]:
def get_optimized_func(f, args, static_argnums=None):
  if static_argnums is not None:
    c = jax.xla_computation(f, static_argnums=static_argnums)(*args)
  else:
    c = jax.xla_computation(f)(*args)

  backend = jax.lib.xla_bridge.get_backend()
  e = backend.compile(c)
  # hlo python object binding lives at
  # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/python/xla_compiler.cc
  print_options = xla_ext.HloPrintOptions
  standard_option_types = [print_options.short_parsable(), print_options.canonical(), print_options.fingerprint()]
  option = standard_option_types[0]
  #option = xla_ext.HloPrintOptions()
  #option.print_metadata = False
  #option.include_layout_in_shapes = False
  return e
  #return e.hlo_modules()[0].to_string(option)


In [84]:
print(get_optimized_func(f, (1.0,)))

<jaxlib.xla_extension.Executable object at 0x7f4118720630>


<function NoneType.__dir__>

In [103]:
jax.lax.conv(jnp.ones((5,5,3)), jnp.ones((5,5,3)), (1,1), "SAME")

ValueError: ignored

In [110]:
x = jnp.linspace(-3, 3, 7)
window = jsp.stats.norm.pdf(x) * jsp.stats.norm.pdf(x[:, None])
input_img = jnp.ones((32,32))
print(get_optimized_func(
    jsp.signal.convolve, 
    (input_img, window)
    ).hlo_modules()[0].to_string())

HloModule xla_computation_convolve.143, entry_computation_layout={(f32[32,32]{1,0},f32[7,7]{1,0})->(f32[38,38]{1,0})}

%fused_computation (param_0.2: f32[7,7]) -> f32[7,7,1,1] {
  %param_0.2 = f32[7,7]{1,0} parameter(0)
  %reverse.1 = f32[7,7]{1,0} reverse(f32[7,7]{1,0} %param_0.2), dimensions={0,1}, metadata={op_name="xla_computation(convolve)/jit(main)/jit(_flip)/rev[dimensions=(0, 1)]" source_file="<ipython-input-97-ae23534e1a9a>" source_line=5}
  %reshape.6 = f32[7,7,1,1]{1,0,3,2} reshape(f32[7,7]{1,0} %reverse.1)
  ROOT %copy.2 = f32[7,7,1,1]{3,2,1,0} copy(f32[7,7,1,1]{1,0,3,2} %reshape.6)
}

%fused_computation.1 (param_0.4: f32[32,32]) -> f32[1,32,32,1] {
  %param_0.4 = f32[32,32]{1,0} parameter(0)
  %reshape.9 = f32[1,32,32,1]{2,1,3,0} reshape(f32[32,32]{1,0} %param_0.4)
  ROOT %copy.3 = f32[1,32,32,1]{3,2,1,0} copy(f32[1,32,32,1]{2,1,3,0} %reshape.9)
}

ENTRY %main.12 (Arg_0.1: f32[32,32], Arg_1.2: f32[7,7]) -> (f32[38,38]) {
  %Arg_0.1 = f32[32,32]{1,0} parameter(0)
  %fusion.

In [100]:

get_optimized_func(
    jax.lax.conv, 
    (jnp.ones((5,5)), jnp.ones((5,5)), 1, "SAME")
    ).hlo_modules()[0].to_string()

TypeError: ignored

In [None]:
print(get_optimized_func(, (1.0,)))

In [41]:
print(get_optimized_func(b, (1.0,2.0, 3.0)))

HloModule xla_computation_b.20, entry_computation_layout={(f32[],f32[],f32[])->(f32[])}

fused_computation {
  param_1.1 = f32[] parameter(1)
  param_2 = f32[] parameter(2)
  multiply.1 = f32[] multiply(param_1.1, param_2)
  param_0.1 = f32[] parameter(0)
  ROOT multiply.0 = f32[] multiply(multiply.1, param_0.1)
}

ENTRY main.7 {
  Arg_2.3 = f32[] parameter(2)
  Arg_0.1 = f32[] parameter(0)
  Arg_1.2 = f32[] parameter(1)
  fusion = f32[] fusion(Arg_2.3, Arg_0.1, Arg_1.2), kind=kLoop, calls=fused_computation
  ROOT tuple.6 = (f32[]) tuple(fusion)
}




In [43]:
print(get_optimized_func(matty, (jnp.ones((5,5)), jnp.ones((5,5))) ))

HloModule xla_computation_matty.22, entry_computation_layout={(f32[5,5]{1,0},f32[5,5]{1,0})->(f32[5,5]{1,0})}

fused_computation {
  param_0 = f32[5,5]{1,0} parameter(0)
  constant.2 = f32[] constant(1)
  broadcast.2 = f32[5,5]{1,0} broadcast(constant.2), dimensions={}
  ROOT add.3 = f32[5,5]{1,0} add(param_0, broadcast.2)
}

ENTRY main.10 {
  Arg_0.1 = f32[5,5]{1,0} parameter(0)
  Arg_1.2 = f32[5,5]{1,0} parameter(1)
  dot.5 = f32[5,5]{1,0} dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  fusion = f32[5,5]{1,0} fusion(dot.5), kind=kLoop, calls=fused_computation
  ROOT tuple.9 = (f32[5,5]{1,0}) tuple(fusion)
}




In [45]:
print(get_optimized_func(matty, (jnp.ones((5000,5000)), jnp.ones((5000,5000))) ))

HloModule xla_computation_matty.24, entry_computation_layout={(f32[5000,5000]{1,0},f32[5000,5000]{1,0})->(f32[5000,5000]{1,0})}

fused_computation.clone {
  param_0.1 = f32[5000,5000]{1,0} parameter(0)
  constant.4 = f32[] constant(1)
  broadcast.3 = f32[5000,5000]{1,0} broadcast(constant.4), dimensions={}
  ROOT add.4 = f32[5000,5000]{1,0} add(param_0.1, broadcast.3)
}

parallel_fusion {
  p = f32[5000,5000]{1,0} parameter(0)
  ROOT fusion.clone = f32[5000,5000]{1,0} fusion(p), kind=kLoop, calls=fused_computation.clone, outer_dimension_partitions={2}
}

ENTRY main.10 {
  Arg_0.1 = f32[5000,5000]{1,0} parameter(0)
  Arg_1.2 = f32[5000,5000]{1,0} parameter(1)
  dot.5 = f32[5000,5000]{1,0} dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  call = f32[5000,5000]{1,0} call(dot.5), to_apply=parallel_fusion
  ROOT tuple.9 = (f32[5000,5000]{1,0}) tuple(call)
}




In [46]:
def render_pixel(uvt):
  uv = uvt[0:2]
  time = uvt[2]

  def vec3(x, y, z):
    return jnp.array([x, y, z])

  def mag(v):
    return jnp.sqrt(v.dot(v))

  def normalize(v):
    return v / mag(v)

  def sdf_sphere(v, r):
    return mag(v) - r

  def scene(p):
    return sdf_sphere(p, 0.6)

  sdf_normal = grad(scene)
  ray_dir = normalize(vec3(*uv, 1.0))
  ray_pos = vec3(0.0, 0.0, -2.0)
  eps = 0.001
  max_dist = 100
  bg_color = vec3(0.0, 0.0, 0.0)

  def raymarch_step(d):
    d_s = scene(ray_pos + ray_dir * d[0] * 0.95)
    return d[0] + d_s, d_s

  run_condition = lambda d : jnp.where((d[1] > eps) & (d[1] < max_dist), True, False)

  dist_i = scene(ray_pos)
  dist_o, dist_s = jax.lax.while_loop(run_condition, raymarch_step, (dist_i, dist_i))

  intersect_mask = (dist_s < eps).astype(jnp.float32)

  def color(p):
    light_dir = normalize(vec3(0.2,-0.4,-1.0))
    normal = sdf_normal(p)
    light = jnp.maximum(normal.dot(light_dir), 0.0)
    return jnp.array([0.8, 0.1, 0.15]) * light

  intersect_pos = ray_pos + ray_dir * dist_o
  obj_color = color(intersect_pos)

  return intersect_mask * obj_color + (1.0 - intersect_mask) * bg_color

In [49]:
img_res = 128
ax_coords = jnp.arange(img_res)/img_res-0.5
img_coords = jnp.meshgrid(ax_coords, ax_coords)
img_coords = jnp.stack([*img_coords, jnp.full((img_res, img_res), 1.0)], axis=2)
# x = jnp.array([-0.4, -0.4, -2.0])
# x.dot(x)
#render_pixel(jnp.array([0.0, 0.0, 1.0]))
render_pixels = jit(vmap(vmap(render_pixel)))
#test_img = render_pixels(img_coords)

In [61]:
#print(get_optimized_func(render_pixel, (jnp.array([0.0, 0.0, 1.0]),) ))

In [62]:
#print(get_optimized_func(jit(render_pixel), (jnp.array([0.0, 0.0, 1.0]),) ))

In [63]:
print(get_optimized_func(render_pixels, (img_coords,) ))

HloModule xla_computation_render_pixel.84, entry_computation_layout={(f32[128,128,3]{2,1,0})->(f32[128,128,3]{2,1,0})}

fused_computation.1 {
  param_0.4 = f32[128,128]{1,0} parameter(0)
  constant.68 = f32[] constant(0.001)
  broadcast.110 = f32[128,128]{1,0} broadcast(constant.68), dimensions={}
  compare.16 = pred[128,128]{1,0} compare(param_0.4, broadcast.110), direction=GT
  constant.67 = f32[] constant(100)
  broadcast.109 = f32[128,128]{1,0} broadcast(constant.67), dimensions={}
  compare.15 = pred[128,128]{1,0} compare(param_0.4, broadcast.109), direction=LT
  and.7 = pred[128,128]{1,0} and(compare.16, compare.15)
  constant.66 = pred[] constant(true)
  broadcast.108 = pred[128,128]{1,0} broadcast(constant.66), dimensions={}
  constant.65 = pred[] constant(false)
  broadcast.107 = pred[128,128]{1,0} broadcast(constant.65), dimensions={}
  select.17 = pred[128,128]{1,0} select(and.7, broadcast.108, broadcast.107)
  param_1.8 = f32[128,128]{1,0} parameter(1)
  constant.43.clone.4