# 5.4. XLA and PJIT

> **Author**: Gustavo Leite / **Date**: March 2022.

In this lecture we will take a look at the XLA compiler that compiles JAX functions. XLA is also developed by Google and lives in the Tensorflow repository.

In [1]:
import os
import sys
import jax
import jax.numpy as jnp
import flax.linen as nn
import numpy as np

from jax import jit, grad, make_jaxpr, random

In [2]:
# XLA will see 2 logical CPUs instead of one.
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'
jax.config.update('jax_platform_name', 'cpu')

In [3]:
jax.devices()

[CpuDevice(id=0),
 CpuDevice(id=1),
 CpuDevice(id=2),
 CpuDevice(id=3),
 CpuDevice(id=4),
 CpuDevice(id=5),
 CpuDevice(id=6),
 CpuDevice(id=7)]

<hr />

## XLA Compiler and IR

<center>
    <br />
    <img src="images/flows.png" width="80%" />
</center>

In [4]:
def sigmoid(x):
    return 1 / (1 + jnp.exp(-x))

### JAXPRs

In [6]:
make_jaxpr(sigmoid)(1.0)

{ [34m[22m[1mlambda [39m[22m[22m; a[35m:f32[][39m. [34m[22m[1mlet
    [39m[22m[22mb[35m:f32[][39m = neg a
    c[35m:f32[][39m = exp b
    d[35m:f32[][39m = add c 1.0
    e[35m:f32[][39m = div 1.0 d
  [34m[22m[1min [39m[22m[22m(e,) }

In [7]:
X = jnp.ones(10)
make_jaxpr(sigmoid)(X)

{ [34m[22m[1mlambda [39m[22m[22m; a[35m:f32[10][39m. [34m[22m[1mlet
    [39m[22m[22mb[35m:f32[10][39m = neg a
    c[35m:f32[10][39m = exp b
    d[35m:f32[10][39m = add c 1.0
    e[35m:f32[10][39m = div 1.0 d
  [34m[22m[1min [39m[22m[22m(e,) }

In [8]:
make_jaxpr(jit(sigmoid))(1.0)
#          ============

{ [34m[22m[1mlambda [39m[22m[22m; a[35m:f32[][39m. [34m[22m[1mlet
    [39m[22m[22mb[35m:f32[][39m = xla_call[
      call_jaxpr={ [34m[22m[1mlambda [39m[22m[22m; c[35m:f32[][39m. [34m[22m[1mlet
          [39m[22m[22md[35m:f32[][39m = neg c
          e[35m:f32[][39m = exp d
          f[35m:f32[][39m = add e 1.0
          g[35m:f32[][39m = div 1.0 f
        [34m[22m[1min [39m[22m[22m(g,) }
      name=sigmoid
    ] a
  [34m[22m[1min [39m[22m[22m(b,) }

In [13]:
make_jaxpr(grad(sigmoid))(1.0)
#          =============

{ [34m[22m[1mlambda [39m[22m[22m; a[35m:f32[][39m. [34m[22m[1mlet
    [39m[22m[22mb[35m:f32[][39m = neg a
    c[35m:f32[][39m = exp b
    d[35m:f32[][39m = add c 1.0
    _[35m:f32[][39m = div 1.0 d
    e[35m:f32[][39m = integer_pow[y=-2] d
    f[35m:f32[][39m = mul 1.0 e
    g[35m:f32[][39m = mul f 1.0
    h[35m:f32[][39m = neg g
    i[35m:f32[][39m = mul h c
    j[35m:f32[][39m = neg i
  [34m[22m[1min [39m[22m[22m(j,) }

### HLO IR

In [18]:
sigmoid_jit = jit(sigmoid)
x = np.ones(10)

ir = sigmoid_jit.lower(x).compiler_ir('hlo')
#                =====      ===========

print(f"Type = {type(ir)}")
print("=" * 70)
print(ir.as_hlo_text())

Type = <class 'jaxlib.xla_extension.XlaComputation'>
HloModule jit_sigmoid.2

ENTRY main.8 {
  constant.2 = f32[] constant(1)
  broadcast.3 = f32[10]{0} broadcast(constant.2), dimensions={}
  Arg_0.1 = f32[10]{0} parameter(0)
  negate.4 = f32[10]{0} negate(Arg_0.1)
  exponential.5 = f32[10]{0} exponential(negate.4)
  add.6 = f32[10]{0} add(exponential.5, broadcast.3)
  ROOT divide.7 = f32[10]{0} divide(broadcast.3, add.6)
}




In [17]:
try:
    ir = grad(sigmoid).lower(1.0).compiler_ir('hlo')
except AttributeError as error:
    print(f"Error: {error}", file=sys.stderr)
    print(f" Note: Lowering is only available when the function is compiled!", file=sys.stderr)

Error: 'function' object has no attribute 'lower'
 Note: Lowering is only available when the function is compiled!


In [19]:
ir = jit(grad(sigmoid)).lower(1.0).compiler_ir('hlo')
#    ==================

print(f"Type = {type(ir)}")
print("=" * 70)
print(ir.as_hlo_text())

Type = <class 'jaxlib.xla_extension.XlaComputation'>
HloModule jit_sigmoid.3

ENTRY main.11 {
  constant.2 = f32[] constant(1)
  Arg_0.1 = f32[] parameter(0)
  negate.3 = f32[] negate(Arg_0.1)
  exponential.4 = f32[] exponential(negate.3)
  add.5 = f32[] add(exponential.4, constant.2)
  multiply.6 = f32[] multiply(add.5, add.5)
  divide.7 = f32[] divide(constant.2, multiply.6)
  negate.8 = f32[] negate(divide.7)
  multiply.9 = f32[] multiply(negate.8, exponential.4)
  ROOT negate.10 = f32[] negate(multiply.9)
}




In [20]:
make_jaxpr(grad(sigmoid))(1.0)

{ [34m[22m[1mlambda [39m[22m[22m; a[35m:f32[][39m. [34m[22m[1mlet
    [39m[22m[22mb[35m:f32[][39m = neg a
    c[35m:f32[][39m = exp b
    d[35m:f32[][39m = add c 1.0
    _[35m:f32[][39m = div 1.0 d
    e[35m:f32[][39m = integer_pow[y=-2] d
    f[35m:f32[][39m = mul 1.0 e
    g[35m:f32[][39m = mul f 1.0
    h[35m:f32[][39m = neg g
    i[35m:f32[][39m = mul h c
    j[35m:f32[][39m = neg i
  [34m[22m[1min [39m[22m[22m(j,) }

### MHLO IR

In [22]:
x = np.ones(10)
ir = jit(sigmoid).lower(x).compiler_ir('mhlo')

print(f"Type = {type(ir)}")
print("=" * 70)
print(ir)

Type = <class 'jaxlib.mlir._mlir_libs._mlir.ir.Module'>
module @jit_sigmoid.5 {
  func public @main(%arg0: tensor<10xf32>) -> tensor<10xf32> {
    %0 = mhlo.negate %arg0 : tensor<10xf32>
    %1 = mhlo.exponential %0 : tensor<10xf32>
    %2 = mhlo.constant dense<1.000000e+00> : tensor<f32>
    %3 = "mhlo.broadcast_in_dim"(%2) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<10xf32>
    %4 = mhlo.add %1, %3 : tensor<10xf32>
    %5 = mhlo.constant dense<1.000000e+00> : tensor<f32>
    %6 = "mhlo.broadcast_in_dim"(%5) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<10xf32>
    %7 = mhlo.divide %6, %4 : tensor<10xf32>
    return %7 : tensor<10xf32>
  }
}



In [23]:
ir = jit(grad(sigmoid)).lower(1.0).compiler_ir('mhlo')

print(f"Type = {type(ir)}")
print("=" * 70)
print(ir)

Type = <class 'jaxlib.mlir._mlir_libs._mlir.ir.Module'>
module @jit_sigmoid.6 {
  func public @main(%arg0: tensor<f32>) -> tensor<f32> {
    %0 = mhlo.negate %arg0 : tensor<f32>
    %1 = mhlo.exponential %0 : tensor<f32>
    %2 = mhlo.constant dense<1.000000e+00> : tensor<f32>
    %3 = mhlo.add %1, %2 : tensor<f32>
    %4 = mhlo.constant dense<1.000000e+00> : tensor<f32>
    %5 = mhlo.divide %4, %3 : tensor<f32>
    %6 = mhlo.multiply %3, %3 : tensor<f32>
    %7 = mhlo.constant dense<1.000000e+00> : tensor<f32>
    %8 = mhlo.divide %7, %6 : tensor<f32>
    %9 = mhlo.constant dense<1.000000e+00> : tensor<f32>
    %10 = mhlo.multiply %9, %8 : tensor<f32>
    %11 = mhlo.constant dense<1.000000e+00> : tensor<f32>
    %12 = mhlo.multiply %10, %11 : tensor<f32>
    %13 = mhlo.negate %12 : tensor<f32>
    %14 = mhlo.multiply %13, %1 : tensor<f32>
    %15 = mhlo.negate %14 : tensor<f32>
    return %15 : tensor<f32>
  }
}



In [24]:
ir = jit(sigmoid).lower(jnp.eye(2)).compiler_ir('mhlo')
#                      ==========

print(f"Type = {type(ir)}")
print("=" * 70)
print(ir)

Type = <class 'jaxlib.mlir._mlir_libs._mlir.ir.Module'>
module @jit_sigmoid.12 {
  func public @main(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
    %0 = mhlo.negate %arg0 : tensor<2x2xf32>
    %1 = mhlo.exponential %0 : tensor<2x2xf32>
    %2 = mhlo.constant dense<1.000000e+00> : tensor<f32>
    %3 = "mhlo.broadcast_in_dim"(%2) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<2x2xf32>
    %4 = mhlo.add %1, %3 : tensor<2x2xf32>
    %5 = mhlo.constant dense<1.000000e+00> : tensor<f32>
    %6 = "mhlo.broadcast_in_dim"(%5) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<2x2xf32>
    %7 = mhlo.divide %6, %4 : tensor<2x2xf32>
    return %7 : tensor<2x2xf32>
  }
}



### Revisiting our CNN

In [25]:
class CNN(nn.Module):   
    @nn.compact
    def __call__(self, x):
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # flatten
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)
        x = nn.log_softmax(x)
        return x

In [26]:
key    = random.PRNGKey(0)
image  = jnp.zeros((28, 28, 1))
model  = CNN()
params = model.init(key, image)
apply  = jit(model.apply)

ir = apply.lower(params, image).compiler_ir("hlo")
#                =============

print(ir.as_hlo_text())

HloModule jit_apply.63

relu.23 {
  Arg_0.24 = f32[28,28,32]{2,1,0} parameter(0)
  constant.25 = f32[] constant(0)
  broadcast.26 = f32[28,28,32]{2,1,0} broadcast(constant.25), dimensions={}
  ROOT maximum.27 = f32[28,28,32]{2,1,0} maximum(Arg_0.24, broadcast.26)
}

region_0.30 {
  Arg_0.31 = f32[] parameter(0)
  Arg_1.32 = f32[] parameter(1)
  ROOT add.33 = f32[] add(Arg_0.31, Arg_1.32)
}

relu_0.45 {
  Arg_0.46 = f32[14,14,64]{2,1,0} parameter(0)
  constant.47 = f32[] constant(0)
  broadcast.48 = f32[14,14,64]{2,1,0} broadcast(constant.47), dimensions={}
  ROOT maximum.49 = f32[14,14,64]{2,1,0} maximum(Arg_0.46, broadcast.48)
}

region_1.52 {
  Arg_0.53 = f32[] parameter(0)
  Arg_1.54 = f32[] parameter(1)
  ROOT add.55 = f32[] add(Arg_0.53, Arg_1.54)
}

relu_1.66 {
  Arg_0.67 = f32[7,256]{1,0} parameter(0)
  constant.68 = f32[] constant(0)
  broadcast.69 = f32[7,256]{1,0} broadcast(constant.68), dimensions={}
  ROOT maximum.70 = f32[7,256]{1,0} maximum(Arg_0.67, broadcast.69)
}

regi

<hr />

## Peeking under the hood

In [27]:
# Create a 2x2 matrix
M = jnp.zeros((2, 2))

# Create XLA module from the "dot" operator
computation = jax.xla_computation(jnp.dot)(M, M)

# Print HLO
print(type(computation))
print("=" * 70)
print(computation.as_hlo_text())

<class 'jaxlib.xla_extension.XlaComputation'>
HloModule xla_computation_dot.65

ENTRY main.5 {
  Arg_0.1 = f32[2,2]{1,0} parameter(0)
  Arg_1.2 = f32[2,2]{1,0} parameter(1)
  dot.3 = f32[2,2]{1,0} dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  ROOT tuple.4 = (f32[2,2]{1,0}) tuple(dot.3)
}




### CPU

When compiling for the CPU, notice how the compiler simply used the `dot` operator.

In [28]:
cpu_backend = jax.lib.xla_bridge.get_backend('cpu')
executable = cpu_backend.compile(computation)
module = executable.hlo_modules()[0]

print(module.to_string())

HloModule xla_computation_dot.65

ENTRY %main.5 (Arg_0.1: f32[2,2], Arg_1.2: f32[2,2]) -> (f32[2,2]) {
  %Arg_0.1 = f32[2,2]{1,0} parameter(0)
  %Arg_1.2 = f32[2,2]{1,0} parameter(1)
  %dot.3 = f32[2,2]{1,0} dot(f32[2,2]{1,0} %Arg_0.1, f32[2,2]{1,0} %Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}, metadata={op_name="xla_computation(dot)/jit(main)/dot_general[dimension_numbers=(((1,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/tmp/ipykernel_4480/1430351131.py" source_line=5}
  ROOT %tuple.4 = (f32[2,2]{1,0}) tuple(f32[2,2]{1,0} %dot.3)
}




### GPU

However, when compiling for the GPU, the compiler delegated the work to the CuBLAS library.

In [29]:
gpu_backend = jax.lib.xla_bridge.get_backend('gpu')
executable = gpu_backend.compile(computation)
module = executable.hlo_modules()[0]

print(module.to_string())

HloModule xla_computation_dot.65

ENTRY %main.5 (Arg_0.1: f32[2,2], Arg_1.2: f32[2,2]) -> (f32[2,2]) {
  %Arg_0.1 = f32[2,2]{1,0} parameter(0)
  %Arg_1.2 = f32[2,2]{1,0} parameter(1)
  %cublas-gemm.1 = f32[2,2]{1,0} custom-call(f32[2,2]{1,0} %Arg_0.1, f32[2,2]{1,0} %Arg_1.2), custom_call_target="__cublas$gemm", metadata={op_name="xla_computation(dot)/jit(main)/dot_general[dimension_numbers=(((1,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/tmp/ipykernel_4480/1430351131.py" source_line=5}, backend_config="{\"alpha_real\":1,\"alpha_imag\":0,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"selected_algorithm\":\"9\"}"
  ROOT %tuple.4 = (f32[2,2]{1,0}) tuple(f32[2,2]{1,0} %cublas-gemm.1)
}




2022-06-29 08:50:23.664881: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_helper.cc:56] Can't find libdevice directory ${CUDA_DIR}/nvvm/libdevice. This may result in compilation or runtime failures, if the program we try to run uses routines from libdevice.
Searched for CUDA in the following directories:
  ./cuda_sdk_lib
  /usr/local/cuda-11.4
  /usr/local/cuda
  .
You can choose the search directory by setting xla_gpu_cuda_data_dir in HloModule's DebugOptions.  For most apps, setting the environment variable XLA_FLAGS=--xla_gpu_cuda_data_dir=/path/to/cuda will work.


In [32]:
computation = jax.xla_computation(jit(grad(sigmoid)))(1.0)

gpu_backend = jax.lib.xla_bridge.get_backend('cpu')
executable = gpu_backend.compile(computation)
module = executable.hlo_modules()[0]

print(module.to_string())

HloModule xla_computation_sigmoid.67

%fused_computation (param_0.4: f32[]) -> f32[] {
  %constant.1 = f32[] constant(1)
  %param_0.4 = f32[] parameter(0)
  %negate.6 = f32[] negate(f32[] %param_0.4), metadata={op_name="xla_computation(sigmoid)/jit(main)/jit(sigmoid)/neg" source_file="/tmp/ipykernel_4480/1433209252.py" source_line=2}
  %exponential.1 = f32[] exponential(f32[] %negate.6), metadata={op_name="xla_computation(sigmoid)/jit(main)/jit(sigmoid)/exp" source_file="/tmp/ipykernel_4480/1433209252.py" source_line=2}
  %add.1 = f32[] add(f32[] %exponential.1, f32[] %constant.1), metadata={op_name="xla_computation(sigmoid)/jit(main)/jit(sigmoid)/add" source_file="/tmp/ipykernel_4480/1433209252.py" source_line=2}
  %multiply.3 = f32[] multiply(f32[] %add.1, f32[] %add.1), metadata={op_name="xla_computation(sigmoid)/jit(main)/jit(sigmoid)/mul" source_file="/tmp/ipykernel_4480/1433209252.py" source_line=2}
  %divide.1 = f32[] divide(f32[] %constant.1, f32[] %multiply.3), metadata={op_na

<hr />

## Advanced Partitioning

In this section we will look how JAX enables data partitioning across devices.

In [33]:
devices = jax.devices()
devices

[CpuDevice(id=0),
 CpuDevice(id=1),
 CpuDevice(id=2),
 CpuDevice(id=3),
 CpuDevice(id=4),
 CpuDevice(id=5),
 CpuDevice(id=6),
 CpuDevice(id=7)]

<hr>

### Revisiting PMAP

In [34]:
data = np.ones((8, 4))
data

array([[1., 1., 1., 1.],
       [1., 1., 1., 1.],
       [1., 1., 1., 1.],
       [1., 1., 1., 1.],
       [1., 1., 1., 1.],
       [1., 1., 1., 1.],
       [1., 1., 1., 1.],
       [1., 1., 1., 1.]])

In [35]:
result = jax.pmap(jnp.sum, in_axes=0)(data)

print(f"Result = {result!r}\n")
for i, buffer in enumerate(result.device_buffers):
    print(f"Device {i}: {buffer!r}")

Result = ShardedDeviceArray([4., 4., 4., 4., 4., 4., 4., 4.], dtype=float32)

Device 0: DeviceArray(4., dtype=float32)
Device 1: DeviceArray(4., dtype=float32)
Device 2: DeviceArray(4., dtype=float32)
Device 3: DeviceArray(4., dtype=float32)
Device 4: DeviceArray(4., dtype=float32)
Device 5: DeviceArray(4., dtype=float32)
Device 6: DeviceArray(4., dtype=float32)
Device 7: DeviceArray(4., dtype=float32)


In [36]:
result = jax.pmap(jnp.sum, in_axes=1)(data)

print(f"Result = {result!r}\n")
for i, buffer in enumerate(result.device_buffers):
    print(f"Device {i}: {buffer!r}")

Result = ShardedDeviceArray([8., 8., 8., 8.], dtype=float32)

Device 0: DeviceArray(8., dtype=float32)
Device 1: DeviceArray(8., dtype=float32)
Device 2: DeviceArray(8., dtype=float32)
Device 3: DeviceArray(8., dtype=float32)


<hr>

### Getting Started with PJIT

In [37]:
from jax.experimental import PartitionSpec
from jax.experimental.maps import Mesh
from jax.experimental.pjit import pjit

In [38]:
device_array = np.asarray(devices).reshape((4, 2))
device_array

array([[CpuDevice(id=0), CpuDevice(id=1)],
       [CpuDevice(id=2), CpuDevice(id=3)],
       [CpuDevice(id=4), CpuDevice(id=5)],
       [CpuDevice(id=6), CpuDevice(id=7)]], dtype=object)

In [39]:
# Create a device mesh
# ------------------------------------------
device_mesh = Mesh(device_array, ("x", "y"))

print("DEVICE MESH:")
print(device_mesh)
print()

# Create 8x4 matrix
# ------------------------------------------
data = jnp.arange(8 * 4).reshape(8, 4)

print("DATA ARRAY:")
print(data)

DEVICE MESH:
Mesh(array([[0, 1],
       [2, 3],
       [4, 5],
       [6, 7]]), ('x', 'y'))

DATA ARRAY:
[[ 0  1  2  3]
 [ 4  5  6  7]
 [ 8  9 10 11]
 [12 13 14 15]
 [16 17 18 19]
 [20 21 22 23]
 [24 25 26 27]
 [28 29 30 31]]


<center>
    <img src="images/pjit_data_devices.png" width="50%" />
</center>

In [40]:
fn = pjit(
        lambda x: x,                               # The function to be transformed
        in_axis_resources=None,                    # How the inputs are partitioned
        out_axis_resources=PartitionSpec("x", "y") # How the outputs are partitioned
)

In [41]:
with device_mesh:
    output = fn(data)
    
print(f"Type = {type(output).__name__!r}")
print(output)
print("-" * 80)

for i, buffer in enumerate(output.device_buffers):
    print(f"Device {i} has buffer:\n{buffer}\n")

Type = 'ShardedDeviceArray'
[[ 0  1  2  3]
 [ 4  5  6  7]
 [ 8  9 10 11]
 [12 13 14 15]
 [16 17 18 19]
 [20 21 22 23]
 [24 25 26 27]
 [28 29 30 31]]
--------------------------------------------------------------------------------
Device 0 has buffer:
[[0 1]
 [4 5]]

Device 1 has buffer:
[[2 3]
 [6 7]]

Device 2 has buffer:
[[ 8  9]
 [12 13]]

Device 3 has buffer:
[[10 11]
 [14 15]]

Device 4 has buffer:
[[16 17]
 [20 21]]

Device 5 has buffer:
[[18 19]
 [22 23]]

Device 6 has buffer:
[[24 25]
 [28 29]]

Device 7 has buffer:
[[26 27]
 [30 31]]



<center>
    <img src="images/pjit_partitioning.png" width="100%" />
</center>

### What does PJIT HLO look like?

In [43]:
def square(x):
    return x * x

fn = pjit(
        square,
        in_axis_resources=PartitionSpec("x", "y"),
        out_axis_resources=PartitionSpec("x", "y")
)

with device_mesh:
    ir = fn.lower(data).compiler_ir('hlo')

print(ir.as_hlo_text())

HloModule pjit_square.75

ENTRY main.5 {
  Arg_0.1 = s32[8,4]{1,0} parameter(0), sharding={devices=[4,2]0,1,2,3,4,5,6,7}
  multiply.2 = s32[8,4]{1,0} multiply(Arg_0.1, Arg_0.1)
  tuple.3 = (s32[8,4]{1,0}) tuple(multiply.2)
  ROOT get-tuple-element.4 = s32[8,4]{1,0} get-tuple-element(tuple.3), index=0, sharding={devices=[4,2]0,1,2,3,4,5,6,7}
}




### Functions with multiple parameters

In [44]:
M = jnp.eye(8)
v = jnp.arange(8).reshape((8, 1))

print(M)
print()
print(v)

[[1. 0. 0. 0. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0.]
 [0. 0. 0. 0. 1. 0. 0. 0.]
 [0. 0. 0. 0. 0. 1. 0. 0.]
 [0. 0. 0. 0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 0. 0. 0. 1.]]

[[0]
 [1]
 [2]
 [3]
 [4]
 [5]
 [6]
 [7]]


In [45]:
spec = PartitionSpec(("x", "y"), None)

f = pjit(jnp.dot,
         in_axis_resources=(spec, None),
         out_axis_resources=spec)

with device_mesh:
    ir = f.lower(M, v).compiler_ir("hlo")
    output = f(M, v)
    
print(ir.as_hlo_text())
for i, buffer in enumerate(output.device_buffers):
    print(i, buffer)

HloModule pjit_dot.83

ENTRY main.7 {
  Arg_0.1 = f32[8,8]{1,0} parameter(0), sharding={devices=[8,1]0,1,2,3,4,5,6,7}
  Arg_1.2 = s32[8,1]{1,0} parameter(1), sharding={replicated}
  convert.3 = f32[8,1]{1,0} convert(Arg_1.2)
  dot.4 = f32[8,1]{1,0} dot(Arg_0.1, convert.3), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  tuple.5 = (f32[8,1]{1,0}) tuple(dot.4)
  ROOT get-tuple-element.6 = f32[8,1]{1,0} get-tuple-element(tuple.5), index=0, sharding={devices=[8,1]0,1,2,3,4,5,6,7}
}


0 [[0.]]
1 [[1.]]
2 [[2.]]
3 [[3.]]
4 [[4.]]
5 [[5.]]
6 [[6.]]
7 [[7.]]


<hr />

<center style="font-size: 14pt; font-weight: bold;">
    That's all folks!
</center>