# Distributed Sharding and Visualization in JAX
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://github.com/ASKabalan/Tutorials/blob/main/Cophy2024/Exercises/01_MultiDevice_With_JAX.ipynb)

## Overview
This notebook demonstrates how to use JAX’s sharding features with distributed computing across multiple devices. We explore various partitioning schemes and visually inspect the data distribution.


### Installing `rich` Library for Sharding Visualization

The `rich` library is required to visualize the sharding of JAX arrays.

In [11]:
!pip install -q rich



### Configuring JAX for TPU or CPU Execution

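First, we attempt to configure JAX to run on a TPU if one is available. If TPU setup fails, we fall back to the CPU by setting two environment variables: `JAX_PLATFORM_NAME="cpu"` selects the CPU platform, and the `--xla_force_host_platform_device_count=8` flag in `XLA_FLAGS` makes XLA expose 8 logical CPU devices for data sharding. Both must be set before JAX initializes its backend, so run this cell before any other JAX computation.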


In [58]:
import os
try:
    import jax.tools.colab_tpu
    jax.tools.colab_tpu.setup_tpu()
except RuntimeError:
    os.environ["JAX_PLATFORM_NAME"] = "cpu"
    os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

### Verifying Device Configuration in JAX

Next, we import JAX and confirm that it recognizes 8 available devices (logical CPUs in this case). We also list these devices to ensure they are correctly initialized for sharding.


In [59]:
import jax
import jax.numpy as jnp
from jax import lax

assert jax.device_count() == 8

jax.devices()

[CpuDevice(id=0),
 CpuDevice(id=1),
 CpuDevice(id=2),
 CpuDevice(id=3),
 CpuDevice(id=4),
 CpuDevice(id=5),
 CpuDevice(id=6),
 CpuDevice(id=7)]

### Defining Mesh and Sharding Specifications in JAX

In this section, we define a **mesh** and a **sharding specification** to control how data is partitioned across devices in JAX.

- **Mesh**: The `jax.make_mesh()` function creates a logical mesh, which is an arrangement of devices for parallel computation. Here, we define a 1D mesh with 8 devices along a single axis named `'x'`. This mesh will allow us to partition data across 8 logical CPUs.

- **PartitionSpec (`P`)**: The `PartitionSpec` object describes how each dimension of an array maps onto the mesh axes. Here, `P('x')` specifies that the first array dimension is split across the `'x'` mesh axis; dimensions not listed in the spec are not partitioned.

- **NamedSharding**: `NamedSharding(mesh, P('x'))` combines the mesh and partition specification to create a sharding configuration that will distribute data across the mesh in line with the specified partitioning scheme.

By defining the mesh and sharding in this way, we enable JAX to handle distributed data efficiently across devices according to our custom configuration.

For more details, refer to the [JAX sharding documentation](https://jax.readthedocs.io/en/latest/jax.sharding.html).


In [3]:
from jax.sharding import PartitionSpec as P, Mesh, NamedSharding
from jax.debug import visualize_array_sharding

# ('x',) is a one-element tuple of axis names; ('x') would be just the string 'x'.
mesh = jax.make_mesh((8,), ('x',))
sharding = NamedSharding(mesh, P('x'))
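To make the effect of a `NamedSharding` concrete, the following sketch inspects the per-device shard shape and the slice of the global array each device holds. It is an illustration under assumptions rather than part of the original notebook: it builds the mesh with the `Mesh` constructor directly (so it does not depend on `jax.make_mesh` being available), sizes the array from `jax.device_count()` so it runs with any number of devices, and places the data with `jax.device_put`.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

n_dev = jax.device_count()                      # 8 in this notebook
mesh = Mesh(np.array(jax.devices()), ("x",))    # 1D mesh over all devices
sharding = NamedSharding(mesh, P("x"))          # split array dimension 0 over 'x'

# Two rows per device, so dimension 0 divides evenly across the mesh.
global_shape = (2 * n_dev, 16)
shard_shape = sharding.shard_shape(global_shape)    # (2, 16)

x = jax.device_put(jnp.zeros(global_shape), sharding)
for s in x.addressable_shards:
    # s.index is the tuple of slices of the global array held by s.device.
    print(s.device, s.index, s.data.shape)
```

Each device holds a contiguous band of rows and all 16 columns, which is the layout that `visualize_array_sharding` draws as horizontal stripes.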

### Creating and Sharding an Array in JAX

Here, we create a random 16x16 array and apply a sharding constraint to it using `lax.with_sharding_constraint`.

- **Applying Sharding with `with_sharding_constraint`**: The `lax.with_sharding_constraint` function applies our predefined sharding specification (`sharding`) to the array. `jax.device_put` can also place an array according to a sharding, but it is typically used to move concrete data onto devices, whereas `with_sharding_constraint` is designed to also work inside a `jit`-compiled function, where it tells the compiler how an intermediate or output value should be laid out.

- **Visualizing Array Sharding**: Finally, we use `visualize_array_sharding()` to display the data distribution across devices, which provides a color-coded view of how the array is sharded based on our mesh and sharding specification.


In [4]:
a = jax.random.normal(jax.random.key(0) , (16 , 16))
a_sharded = lax.with_sharding_constraint(a , sharding)
visualize_array_sharding(a_sharded)
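The cell above applies the constraint outside of `jit`. As a minimal sketch of the `jit` use case mentioned earlier, the function below constrains its result so that the output comes back row-sharded. The function name `double_and_shard` and the array sizes are illustrative assumptions, and the mesh is built with the `Mesh` constructor over however many devices are available.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax import lax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

n_dev = jax.device_count()
mesh = Mesh(np.array(jax.devices()), ("x",))
row_sharding = NamedSharding(mesh, P("x"))

@jax.jit
def double_and_shard(x):
    y = 2.0 * x
    # Ask the compiler to lay out this value with rows split over 'x'.
    return lax.with_sharding_constraint(y, row_sharding)

x = jnp.ones((2 * n_dev, 4))        # created on the default device
y = double_and_shard(x)

# One shard per device, each holding 2 rows of the result.
shard_shapes = [s.data.shape for s in y.addressable_shards]
print(shard_shapes)
```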

### Sharding on a Subset of Devices

In this example, we create a smaller mesh to shard the array across only a subset of available devices.

- **Half Mesh**: We define a `half_mesh` with only 4 devices arranged along a single axis named `'x'`. This is achieved by specifying `devices=jax.devices()[:4]`, which selects the first 4 logical devices.

- **New Sharding Specification**: `half_sharding` combines `half_mesh` with the `PartitionSpec` `P('x')`. This sharding configuration distributes data across only the 4 selected devices rather than all 8.

- **Applying the Sharding Constraint**: We use `lax.with_sharding_constraint` again to apply `half_sharding` to the array `a`, resulting in a new array `a_2` that is sharded across only the first 4 devices.

- **Visualizing Array Sharding**: Using `visualize_array_sharding(a_2)`, we can see how data is partitioned across these 4 devices, providing a clear visual representation of the restricted sharding setup.


In [16]:
half_mesh = jax.make_mesh((4,), ('x',), devices=jax.devices()[:4])
half_sharding = NamedSharding(half_mesh, P('x'))
a_2 = lax.with_sharding_constraint(a, half_sharding)

visualize_array_sharding(a_2)

### Exercise 1: Find the correct sharding 

In these exercises, complete the missing code to achieve specific data distributions across devices. Each exercise shows the desired output after its code cell.


**Instructions:**
- Update the `sharding` and/or `mesh` configuration as specified below.
- If a cell does not define a mesh, reuse the mesh from the previous cell.
- Run the cell and check that the output matches the required output shown.

#### Exercise 1.1 : Transposed sharding

In [None]:
transposed_sharding = ...
b = lax.with_sharding_constraint(a , transposed_sharding)
visualize_array_sharding(b)

Required output 

<pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #d6616b">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #8ca252">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6">         </span><span style="color: #000000; text-decoration-color: #000000; background-color: #e7cb94">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #6b6ecf">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #a55194">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #8c6d31">         </span>
<span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #d6616b">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #8ca252">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6">         </span><span style="color: #000000; text-decoration-color: #000000; background-color: #e7cb94">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #6b6ecf">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #a55194">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #8c6d31">         </span>
<span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #d6616b">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #8ca252">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6">         </span><span style="color: #000000; text-decoration-color: #000000; background-color: #e7cb94">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #6b6ecf">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #a55194">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #8c6d31">         </span>
<span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #d6616b">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #8ca252">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6">         </span><span style="color: #000000; text-decoration-color: #000000; background-color: #e7cb94">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #6b6ecf">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #a55194">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #8c6d31">         </span>
<span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #d6616b">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #8ca252">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6">         </span><span style="color: #000000; text-decoration-color: #000000; background-color: #e7cb94">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #6b6ecf">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #a55194">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #8c6d31">         </span>
<span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79">  CPU 0  </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #d6616b">  CPU 1  </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #8ca252">  CPU 2  </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6">  CPU 3  </span><span style="color: #000000; text-decoration-color: #000000; background-color: #e7cb94">  CPU 4  </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #6b6ecf">  CPU 5  </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #a55194">  CPU 6  </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #8c6d31">  CPU 7  </span>
<span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #d6616b">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #8ca252">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6">         </span><span style="color: #000000; text-decoration-color: #000000; background-color: #e7cb94">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #6b6ecf">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #a55194">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #8c6d31">         </span>
<span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #d6616b">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #8ca252">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6">         </span><span style="color: #000000; text-decoration-color: #000000; background-color: #e7cb94">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #6b6ecf">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #a55194">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #8c6d31">         </span>
<span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #d6616b">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #8ca252">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6">         </span><span style="color: #000000; text-decoration-color: #000000; background-color: #e7cb94">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #6b6ecf">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #a55194">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #8c6d31">         </span>
<span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #d6616b">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #8ca252">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6">         </span><span style="color: #000000; text-decoration-color: #000000; background-color: #e7cb94">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #6b6ecf">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #a55194">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #8c6d31">         </span>
<span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #d6616b">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #8ca252">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6">         </span><span style="color: #000000; text-decoration-color: #000000; background-color: #e7cb94">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #6b6ecf">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #a55194">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #8c6d31">         </span>
</pre>


#### Exercise 1.2 : 2D Sharding

In [None]:
mesh2d = ...
sharding2d = ...
c = lax.with_sharding_constraint(a , sharding2d)

visualize_array_sharding(c)

Required output 

<pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #d6616b">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #8ca252">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6">         </span>
<span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #d6616b">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #8ca252">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6">         </span>
<span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79">  CPU 0  </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #d6616b">  CPU 1  </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #8ca252">  CPU 2  </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6">  CPU 3  </span>
<span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #d6616b">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #8ca252">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6">         </span>
<span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #d6616b">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #8ca252">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6">         </span>
<span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #d6616b">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #8ca252">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6">         </span>
<span style="color: #000000; text-decoration-color: #000000; background-color: #e7cb94">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #6b6ecf">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #a55194">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #8c6d31">         </span>
<span style="color: #000000; text-decoration-color: #000000; background-color: #e7cb94">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #6b6ecf">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #a55194">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #8c6d31">         </span>
<span style="color: #000000; text-decoration-color: #000000; background-color: #e7cb94">  CPU 4  </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #6b6ecf">  CPU 5  </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #a55194">  CPU 6  </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #8c6d31">  CPU 7  </span>
<span style="color: #000000; text-decoration-color: #000000; background-color: #e7cb94">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #6b6ecf">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #a55194">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #8c6d31">         </span>
<span style="color: #000000; text-decoration-color: #000000; background-color: #e7cb94">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #6b6ecf">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #a55194">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #8c6d31">         </span>
<span style="color: #000000; text-decoration-color: #000000; background-color: #e7cb94">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #6b6ecf">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #a55194">         </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #8c6d31">         </span>
</pre>


#### Exercise 1.2 (continued) : Transposed 2D Sharding

In [None]:
sharding2d_T = ...
d = lax.with_sharding_constraint(a , sharding2d_T)

visualize_array_sharding(d)

Required output

<pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79">            </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #d6616b">            </span>
<span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79">   CPU 0    </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #d6616b">   CPU 4    </span>
<span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79">            </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #d6616b">            </span>
<span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #8ca252">            </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6">            </span>
<span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #8ca252">   CPU 1    </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6">   CPU 5    </span>
<span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #8ca252">            </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6">            </span>
<span style="color: #000000; text-decoration-color: #000000; background-color: #e7cb94">            </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #6b6ecf">            </span>
<span style="color: #000000; text-decoration-color: #000000; background-color: #e7cb94">   CPU 2    </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #6b6ecf">   CPU 6    </span>
<span style="color: #000000; text-decoration-color: #000000; background-color: #e7cb94">            </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #6b6ecf">            </span>
<span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #a55194">            </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #8c6d31">            </span>
<span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #a55194">   CPU 3    </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #8c6d31">   CPU 7    </span>
<span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #a55194">            </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #8c6d31">            </span>
</pre>


#### Exercise 1.3 : Partial 2D mesh

In [None]:
mesh2d_partial = ...
sharding2d_partial = ...
e = lax.with_sharding_constraint(a , sharding2d_partial)

visualize_array_sharding(e)

Required output

<pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79">            </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6">            </span>
<span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79">            </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6">            </span>
<span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79">   CPU 0    </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6">   CPU 2    </span>
<span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79">            </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6">            </span>
<span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79">            </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6">            </span>
<span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79">            </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6">            </span>
<span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #ad494a">            </span><span style="color: #000000; text-decoration-color: #000000; background-color: #b5cf6b">            </span>
<span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #ad494a">            </span><span style="color: #000000; text-decoration-color: #000000; background-color: #b5cf6b">            </span>
<span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #ad494a">   CPU 4    </span><span style="color: #000000; text-decoration-color: #000000; background-color: #b5cf6b">   CPU 6    </span>
<span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #ad494a">            </span><span style="color: #000000; text-decoration-color: #000000; background-color: #b5cf6b">            </span>
<span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #ad494a">            </span><span style="color: #000000; text-decoration-color: #000000; background-color: #b5cf6b">            </span>
<span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #ad494a">            </span><span style="color: #000000; text-decoration-color: #000000; background-color: #b5cf6b">            </span>
</pre>


#### Exercise 1.4 : Mysterious mesh and sharding

In [None]:
mysterious_mesh = ...
mysterious_sharding = ...
f = lax.with_sharding_constraint(a , mysterious_sharding)

visualize_array_sharding(f)

Required output

<pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79">            </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #d6616b">            </span>
<span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79">   CPU 0    </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #d6616b">   CPU 2    </span>
<span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79">            </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #d6616b">            </span>
<span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #8ca252">            </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6">            </span>
<span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #8ca252">   CPU 4    </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6">   CPU 6    </span>
<span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #8ca252">            </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6">            </span>
<span style="color: #000000; text-decoration-color: #000000; background-color: #e7cb94">            </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #6b6ecf">            </span>
<span style="color: #000000; text-decoration-color: #000000; background-color: #e7cb94">   CPU 1    </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #6b6ecf">   CPU 3    </span>
<span style="color: #000000; text-decoration-color: #000000; background-color: #e7cb94">            </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #6b6ecf">            </span>
<span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #a55194">            </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #8c6d31">            </span>
<span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #a55194">   CPU 5    </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #8c6d31">   CPU 7    </span>
<span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #a55194">            </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #8c6d31">            </span>
</pre>


### Exercise 2: Batching on Parallel Devices

In the next example, we will explore how batching works across parallel devices and examine the performance benefits of distributing tasks over multiple devices.


### Batching Linear Solvers Across Devices

In this example, we have 8 separate $A$ matrices and 8 $b$ vectors.

We will demonstrate how to solve each system $Ax = b$ with `jnp.linalg.solve` on a different device, distributing the work across the devices.
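Before distributing the work, it can help to have a single-device reference for what "a batch of independent solves" means. The sketch below is not the exercise solution: it uses `jax.vmap` on one device, with illustrative matrices that are shifted by a multiple of the identity (an assumption made here so the systems are well conditioned and the residual check is reliable).

```python
import jax
import jax.numpy as jnp

n_systems, dim = 8, 16
key_A, key_b = jax.random.split(jax.random.PRNGKey(0), 2)

# Illustrative data: random matrices plus dim * I to keep them well conditioned.
As = jax.random.normal(key_A, (n_systems, dim, dim)) + dim * jnp.eye(dim)
bs = jax.random.normal(key_b, (n_systems, dim))

# vmap maps the single-system solver over the leading (batch) dimension.
solve_batch = jax.vmap(jnp.linalg.solve)
xs = solve_batch(As, bs)

# Check A_n @ x_n == b_n for every system n.
residual = jnp.einsum("nij,nj->ni", As, xs) - bs
max_residual = float(jnp.max(jnp.abs(residual)))
print(xs.shape, max_residual)
```

In the distributed version, the leading batch dimension is the one that gets split across devices, so each device handles its own slice of `As` and `bs`.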

In [None]:
from functools import partial
from jax.experimental.shard_map import shard_map
from jax.experimental.multihost_utils import process_allgather

As = [jax.random.normal(jax.random.PRNGKey(i), (16, 16)) for i in range(8)]
bs = [jax.random.normal(jax.random.PRNGKey(i + 8), (16,)) for i in range(8)]

mesh = ...
sharding = ...

As_sharded = lax.with_sharding_constraint(jnp.stack(As) , sharding)
bs_sharded = lax.with_sharding_constraint(jnp.stack(bs) , sharding)

@partial(shard_map, mesh=mesh, in_specs=...,out_specs=...)
def batch_solvers(A, b):
    pass

x_sharded = batch_solvers(As_sharded, bs_sharded)

visualize_array_sharding(x_sharded)

x_s = process_allgather(x_sharded)

b_hat = [jnp.dot(A, x) for A , x in zip(As , x_s)]

assert all(jnp.allclose(b_hat[i], bs[i] , rtol=1e-4 , atol=1e-4) for i in range(8))

### Exercise 3: Matrix multiplication 

In this exercise, we will explore distributed matrix multiplication using JAX.
When we compute a matrix product (`@`) on arrays allocated on the first device, the result is also stored on the first device by default.

In [None]:
A = jax.random.normal(jax.random.PRNGKey(0) , (16, 16))
B = jax.random.normal(jax.random.PRNGKey(1) , (16, 16))
visualize_array_sharding(A @ B )

### A and B Partitioned

In this example:  
- **`A` is row-partitioned** across the mesh's `x` axis.  
- **`B` is column-partitioned** across the mesh's `y` axis.  
- The result of `A.dot(B)` is **block-partitioned**, distributed across the devices.  

In [None]:
mesh = jax.make_mesh((4, 2), ('x', 'y'))

a_sharding = NamedSharding(mesh , P('x'))
b_sharding = NamedSharding(mesh , P(None , 'y'))
fully_replicated = NamedSharding(mesh , P())

A_sharded = lax.with_sharding_constraint(A , a_sharding)
B_sharded = lax.with_sharding_constraint(B , b_sharding)

visualize_array_sharding(A_sharded)
visualize_array_sharding(B_sharded)
visualize_array_sharding(A_sharded.dot(B_sharded))
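The block-partitioned result follows from the structure of the product: the block of `A @ B` at row block `i` and column block `j` depends only on row block `i` of `A` and column block `j` of `B`. The sketch below checks this on a single device with plain slicing; the block counts mirror the `(4, 2)` mesh, and the variable names are illustrative rather than taken from the notebook.

```python
import jax
import jax.numpy as jnp

A = jax.random.normal(jax.random.PRNGKey(0), (16, 16))
B = jax.random.normal(jax.random.PRNGKey(1), (16, 16))

n_row_blocks, n_col_blocks = 4, 2       # mirrors the (4, 2) mesh
rows = A.shape[0] // n_row_blocks       # rows of A per 'x' index
cols = B.shape[1] // n_col_blocks       # columns of B per 'y' index

C = A @ B
max_err = 0.0
for i in range(n_row_blocks):
    for j in range(n_col_blocks):
        A_rows = A[i * rows:(i + 1) * rows, :]      # what device (i, j) holds of A
        B_cols = B[:, j * cols:(j + 1) * cols]      # what device (i, j) holds of B
        block = A_rows @ B_cols                     # local product, shape (rows, cols)
        ref = C[i * rows:(i + 1) * rows, j * cols:(j + 1) * cols]
        max_err = max(max_err, float(jnp.max(jnp.abs(block - ref))))

print(max_err)
```

Because the contracted dimension is not split in this layout, each device can compute its output block from the data it already holds.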

### A Partitioned, B Replicated  

In this example:  
- **`A` is row-partitioned** across the mesh's `x` axis.  
- **`B` is fully replicated** across all devices.  
- The result of `A.dot(B)` inherits the row-partitioning of `A`.


In [None]:
A_sharded = lax.with_sharding_constraint(A , a_sharding)
B_replicated = lax.with_sharding_constraint(B , fully_replicated)

visualize_array_sharding(A_sharded)
visualize_array_sharding(B_replicated)
visualize_array_sharding(A_sharded.dot(B_replicated))

### A Replicated, B Partitioned  

In this example:  
- **`A` is fully replicated** across all devices.  
- **`B` is column-partitioned** along the mesh's `y` axis.  
- The result of `A.dot(B)` inherits the column-partitioning of `B`.

In [None]:
A_replicated = lax.with_sharding_constraint(A , fully_replicated)
B_sharded = lax.with_sharding_constraint(B , b_sharding)

visualize_array_sharding(A_replicated)
visualize_array_sharding(B_sharded)
visualize_array_sharding(A_replicated.dot(B_sharded))

### Exercise 3.1: Matrix Multiplication Using `shard_map`

In this exercise, you will manually perform matrix multiplication with **`shard_map`** for the three partitioning cases shown above:

1. **A row-partitioned, B column-partitioned**
2. **A row-partitioned, B replicated**
3. **A replicated, B column-partitioned**



In [None]:
from functools import partial
from jax.experimental.shard_map import shard_map
@partial(shard_map, mesh=..., in_specs=...,out_specs=...)
def row_col_matmul(A , B):
    pass

@partial(shard_map, mesh=..., in_specs=...,out_specs=...)
def row_replicated_matmul(A , B):
    pass

@partial(shard_map, mesh=..., in_specs=...,out_specs=...)
def replicated_col_matmul(A , B):
    pass


# Check row/column partitioning
C_row_col = row_col_matmul(A_sharded , B_sharded)
assert(jnp.allclose(C_row_col , A_sharded @ B_sharded))
assert((A_sharded @ B_sharded).sharding == C_row_col.sharding)
# Check row/replicated partitioning
C_row_replicated = row_replicated_matmul(A_sharded , B_replicated)
assert(jnp.allclose(C_row_replicated , A_sharded @ B_replicated))
assert((A_sharded @ B_replicated).sharding == C_row_replicated.sharding)
# Check replicated/column partitioning
C_replicated_col = replicated_col_matmul(A_replicated , B_sharded)
assert(jnp.allclose(C_replicated_col , A_replicated @ B_sharded))
assert((A_replicated @ B_sharded).sharding == C_replicated_col.sharding)
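
As an illustration of how `in_specs` and `out_specs` map onto per-device blocks, here is one possible sketch of the first case (row-partitioned `A`, column-partitioned `B`). It assumes 8 logical CPU devices set up before `jax` is imported, and it wraps the function in `jax.jit`. The helper names `local_matmul` and `row_col_matmul_sketch` are hypothetical, and the import fallback is there because the location of `shard_map` differs between JAX releases.

```python
import os

# Assumption: 8 logical CPU devices; these must be set before importing jax.
os.environ["JAX_PLATFORMS"] = "cpu"
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:  # newer JAX releases expose shard_map at the top level
    from jax import shard_map
except ImportError:  # older releases, as imported in this notebook
    from jax.experimental.shard_map import shard_map

mesh = Mesh(np.array(jax.devices()[:8]).reshape(4, 2), ("x", "y"))

A = jax.random.normal(jax.random.PRNGKey(0), (16, 16))
B = jax.random.normal(jax.random.PRNGKey(1), (16, 16))


def local_matmul(a_blk, b_blk):
    # Inside shard_map each device sees only its own blocks:
    # a_blk is (4, 16) (rows over x), b_blk is (16, 8) (columns over y).
    # The contraction axis is whole on every device, so no collective is needed.
    return a_blk @ b_blk


row_col_matmul_sketch = jax.jit(
    shard_map(
        local_matmul,
        mesh=mesh,
        in_specs=(P("x", None), P(None, "y")),
        out_specs=P("x", "y"),  # each local (4, 8) result is one tile of C
    )
)

C_sm = row_col_matmul_sketch(A, B)
print(C_sm.addressable_shards[0].data.shape)
```

The other two cases follow the same pattern: only the specs change, and the function body stays a plain local matmul.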


### Exercise 3.2: Matrix Multiplication Strategies with Communication Involved

In this exercise, we explore matrix multiplication strategies that involve communication between different partitions. Unlike the previous strategies, which didn't require communication, these strategies will necessitate the exchange of data across partitions.

The strategies to consider are:

1. **Block Partitioned Multiplied by Row or Column Partitioned**
2. **Block Partitioned Multiplied by Replicated Partitions**

For each strategy, ensure you account for the communication overhead involved when performing the matrix multiplication.

In [None]:
block_sharding = NamedSharding(mesh , P('x' , 'y'))
row_sharding = NamedSharding(mesh , P('y' , None))

A_block_sharded = lax.with_sharding_constraint(A , block_sharding)
B_row_sharded = lax.with_sharding_constraint(B , row_sharding)

visualize_array_sharding(A_block_sharded)
visualize_array_sharding(B_row_sharded)
visualize_array_sharding(A_block_sharded.dot(B_row_sharded))

Just like in the previous exercise, write the matrix multiplication and the required collective operations for each of the following strategies:

**Block Partitioned × Row or Column Partitioned**

In [None]:
from functools import partial
from jax.experimental.shard_map import shard_map

@partial(shard_map, mesh=..., in_specs=...,out_specs=...)
def A_block_matmul(A , B):
    pass


# Check A block matmul
C_A_Block = A_block_matmul(A_block_sharded , B_row_sharded)
assert(jnp.allclose(C_A_Block , A_block_sharded@ B_row_sharded))
assert((A_block_sharded@ B_row_sharded).sharding == C_A_Block.sharding)
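
To make the communication step concrete, the following sketch shows one way the block-by-row case can be written. With `A` sharded as `P('x', 'y')` and `B` as `P('y', None)`, the contraction axis is split over `y`, so each device only computes a partial sum that has to be reduced across `y`. The sketch assumes 8 logical CPU devices set up before `jax` is imported. The names `block_row_local` and `block_row_matmul_sketch` are hypothetical, and requesting a row-partitioned output via `out_specs=P('x', None)` is a choice made here for illustration.

```python
import os

# Assumption: 8 logical CPU devices; these must be set before importing jax.
os.environ["JAX_PLATFORMS"] = "cpu"
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import numpy as np
import jax
import jax.numpy as jnp
from jax import lax
from jax.sharding import Mesh, PartitionSpec as P

try:  # newer JAX releases expose shard_map at the top level
    from jax import shard_map
except ImportError:  # older releases, as imported in this notebook
    from jax.experimental.shard_map import shard_map

mesh = Mesh(np.array(jax.devices()[:8]).reshape(4, 2), ("x", "y"))

A = jax.random.normal(jax.random.PRNGKey(0), (16, 16))
B = jax.random.normal(jax.random.PRNGKey(1), (16, 16))


def block_row_local(a_blk, b_blk):
    # a_blk is (4, 8): a row slab over x, half of the contraction axis over y.
    # b_blk is (8, 16): the matching half of the rows of B.
    partial_sum = a_blk @ b_blk          # (4, 16), only half of the full sum
    # Collective: add the partial sums held by the devices along y.
    return lax.psum(partial_sum, "y")


block_row_matmul_sketch = jax.jit(
    shard_map(
        block_row_local,
        mesh=mesh,
        in_specs=(P("x", "y"), P("y", None)),
        out_specs=P("x", None),  # after psum the result is identical along y
    )
)

C_blk = block_row_matmul_sketch(A, B)
print(C_blk.addressable_shards[0].data.shape)
```

Without the `psum`, each device would return only its half of the sum over the contraction axis, and the result would not match `A @ B`.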



### Exercise 3.3: Best Strategy to Multiply 4 Matrices Together with the Least Amount of Collectives

To multiply four matrices $A$, $B$, $C$, and $D$ with the fewest collective operations, we need to consider how partitioning interacts with matrix multiplication in a distributed system. The goal is to choose the input shardings so that as few of the products as possible require communication. The multiplication expression is:

$$
E = A \times B \times C \times D
$$


#### Bonus:

In an iterative algorithm, is there a way to avoid using a collective?


In [None]:
from jax.experimental.multihost_utils import process_allgather

A = jax.random.normal(jax.random.PRNGKey(0) , (16, 16))
B = jax.random.normal(jax.random.PRNGKey(1) , (16, 16))
C = jax.random.normal(jax.random.PRNGKey(2) , (16, 16))
D = jax.random.normal(jax.random.PRNGKey(3) , (16, 16))

A_sharded = lax.with_sharding_constraint(A , ...)
B_sharded = lax.with_sharding_constraint(B , ...)
C_sharded = lax.with_sharding_constraint(C , ...)
D_sharded = lax.with_sharding_constraint(D , ...)

AB = A_sharded @ B_sharded
ABC = AB @ C_sharded
ABCD = ABC @ D_sharded

jnp.allclose(process_allgather(ABCD , tiled=True) , A @ B @ C @ D)

Array(True, dtype=bool)
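
One candidate strategy, offered as a sketch rather than as the definitive answer, is to row-partition `A` and replicate `B`, `C`, and `D`: every intermediate product then inherits the row partitioning, as in the "A Partitioned, B Replicated" case. The code below assumes 8 logical CPU devices set up before `jax` is imported and a 1D mesh over all of them. The names `chain`, `A_s`, `B_r`, `C_r`, `D_r`, and `collectives` are illustrative. It also lists which collective names appear in the compiled HLO text, in the same spirit as the `all-gather` check used later in this notebook.

```python
import os

# Assumption: 8 logical CPU devices; these must be set before importing jax.
os.environ["JAX_PLATFORMS"] = "cpu"
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()[:8]), ("x",))
row = NamedSharding(mesh, P("x", None))  # rows split over all 8 devices
rep = NamedSharding(mesh, P())           # fully replicated

A, B, C, D = (jax.random.normal(jax.random.PRNGKey(i), (16, 16)) for i in range(4))

A_s = jax.device_put(A, row)
B_r, C_r, D_r = (jax.device_put(M, rep) for M in (B, C, D))

# Keep the result row-partitioned so that nothing has to be gathered at the end.
chain = jax.jit(lambda a, b, c, d: a @ b @ c @ d, out_shardings=row)
E = chain(A_s, B_r, C_r, D_r)

# Look for collective op names in the compiled HLO text.
hlo = chain.lower(A_s, B_r, C_r, D_r).compile().as_text()
collectives = [
    op for op in ("all-gather", "all-reduce", "all-to-all", "collective-permute")
    if op in hlo
]
print("collectives found in compiled HLO:", collectives)
print("rows held per device:", E.addressable_shards[0].data.shape[0])
```

The price of this strategy is memory: three of the four matrices are stored in full on every device.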

### Example: Distributing a Neural Network

To distribute a neural network with Flax, the model's parameters must be sharded across devices. This means choosing a partitioning for the weights and activations and using JAX's `with_sharding_constraint` to pin intermediate values to that partitioning. The approach relies on creating a device mesh and applying a sharding strategy (e.g., data parallelism or model parallelism) to scale training across multiple devices. Sharding reduces the memory needed on each device, which makes it possible to train large models in parallel.

For more details, check the [Flax GSPMD guide](https://flax.readthedocs.io/en/latest/guides/flax_gspmd.html#initialize-a-sharded-model).
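
The same idea can be sketched in plain JAX, without Flax. The example below is a hypothetical two-layer MLP forward pass, not code from the Flax guide: the batch is split over a `data` mesh axis, the hidden dimension over a `model` mesh axis, and `with_sharding_constraint` pins the hidden activations. The axis names, layer sizes, and the function `forward` are assumptions made for illustration, as is the setup of 8 logical CPU devices before `jax` is imported.

```python
import os

# Assumption: 8 logical CPU devices; these must be set before importing jax.
os.environ["JAX_PLATFORMS"] = "cpu"
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import numpy as np
import jax
import jax.numpy as jnp
from jax import lax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 4-way data parallelism combined with 2-way model parallelism.
mesh = Mesh(np.array(jax.devices()[:8]).reshape(4, 2), ("data", "model"))

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
W1 = 0.1 * jax.random.normal(k1, (32, 64))  # input -> hidden
W2 = 0.1 * jax.random.normal(k2, (64, 16))  # hidden -> output
x = jax.random.normal(k3, (8, 32))          # batch of 8 examples

x_s = jax.device_put(x, NamedSharding(mesh, P("data", None)))     # split the batch
W1_s = jax.device_put(W1, NamedSharding(mesh, P(None, "model")))  # split hidden columns
W2_s = jax.device_put(W2, NamedSharding(mesh, P("model", None)))  # split hidden rows


@jax.jit
def forward(x, W1, W2):
    h = jax.nn.relu(x @ W1)
    # Pin the hidden activations: batch over "data", features over "model".
    h = lax.with_sharding_constraint(h, NamedSharding(mesh, P("data", "model")))
    # The hidden axis is contracted here, so the compiler has to insert a
    # reduction over "model" to produce the output.
    return h @ W2


y = forward(x_s, W1_s, W2_s)
print(y.shape, y.sharding)
```

The second matmul is the "block × row" pattern from Exercise 3.2, which is why model parallelism involves communication while pure data parallelism of the forward pass does not.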

### [BONUS] Exercise 4: Using Complex Operations on Distributed Arrays

In this section, we will explore the effects of using more advanced algorithms on distributed arrays, specifically focusing on operations that are not easily parallelizable. This includes operations such as:

- **Fast Fourier Transforms (FFT)**
- **Linear Solvers**

We will investigate how these complex operations behave when applied to distributed arrays, and analyze the challenges and optimizations involved in parallelizing them effectively.



#### Distributed Fast Fourier Transform

In this section, we will test the effects of applying a Fast Fourier Transform (FFT) on a distributed 3D volume and analyze the behavior and performance challenges associated with distributing this operation.

In [5]:
from jax import  numpy as jnp
from jax.experimental.multihost_utils import process_allgather

mesh = jax.make_mesh((4, 2), ('x', 'y'))
sharding = NamedSharding(mesh , P('x', 'y'))

A = jax.random.normal(jax.random.PRNGKey(0) , (16, 16 , 16))
A_sharded = lax.with_sharding_constraint(A , sharding)

local_fft = jnp.fft.fftn(A)
global_fft = jnp.fft.fftn(A_sharded)

jnp.allclose(process_allgather(global_fft , tiled=True) , local_fft)


Array(True, dtype=bool)

Great! It worked like a charm.  
Now, let's visualize the sharding of the output.

**Note**: You can extract slices from a globally distributed array, matrix, or volume. This functionality works even in a multi-host setup.

In [6]:
visualize_array_sharding(A_sharded[...,0])
visualize_array_sharding(global_fft[...,0])

This doesn't look right.

With embarrassingly parallel functions, the output sharding matches the input sharding.  
With a matrix multiplication (matmul), the output sharding follows the rules we saw in Exercise 3.

Here, the output appears to be replicated.

To understand what’s happening, we will look for an **all-gather** operation in the compiled code.

In [7]:
compiled_hlo = jax.jit(jnp.fft.fftn).lower(A).compile().as_text()
print(f"Running a FFT on a local array does an all gather : {'all-gather' in compiled_hlo}")
compiled_hlo = jax.jit(jnp.fft.fftn).lower(A_sharded).compile().as_text()
print(f"Running a FFT on a sharded array does an all gather : {'all-gather' in compiled_hlo}")

Running a FFT on a local array does an all gather : False
Running a FFT on a sharded array does an all gather : True


After printing the compiled code:

The compiled code for the sharded input contains an all-gather, so we now see that JAX does not know how to perform distributed FFTs.

Instead, it gathered the full array on every device, ran the FFT locally, and returned a replicated result,  
which is obviously not what we want.

Using **jaxdecomp**:

**jaxdecomp** implements properly distributed 3D FFTs, which behave much like **jnp.dot** does on sharded inputs.  
It supports multiple sharding strategies (1D or 2D),  
but not sharding of the last dimension (so 3D sharding is not supported).

**jaxdecomp** returns the same values as `jnp.fft.fftn`, but with the axes transposed: the checks below compare against `local_fft.transpose(1, 2, 0)`.

In [8]:
!pip install --quiet jaxdecomp



In [39]:
import jaxdecomp as jd

jd_global_fft = jd.fft.pfft3d(A_sharded)

jnp.allclose(process_allgather(jd_global_fft , tiled=True) , local_fft.transpose(1,2,0), atol=1e-5 , rtol=1e-5)

Array(True, dtype=bool)
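
The `transpose(1, 2, 0)` in the check above has a natural explanation. A common way to build a 3D FFT out of 1D FFTs is to transform the last axis, rotate the axes, and repeat, which leaves the output with permuted axes unless an extra transpose is added at the end. The single-device sketch below illustrates that general technique. It is an assumption about why such a layout arises, not a description of jaxdecomp's internals, and `fft3d_by_axis_rotation` is a hypothetical helper.

```python
import jax
import jax.numpy as jnp


def fft3d_by_axis_rotation(x):
    """3D FFT as three 1D FFTs over the last axis, rotating axes in between."""
    for step in range(3):
        x = jnp.fft.fft(x, axis=-1)   # transform whichever axis is last
        if step < 2:
            x = x.transpose(2, 0, 1)  # bring the next untransformed axis last
    return x


# Distinct sizes make the axis permutation visible in the output shape.
x = jax.random.normal(jax.random.PRNGKey(0), (4, 6, 8))
out = fft3d_by_axis_rotation(x)

# The values match jnp.fft.fftn, but the axes end up in (1, 2, 0) order.
ref = jnp.fft.fftn(x).transpose(1, 2, 0)
print(out.shape, bool(jnp.allclose(out, ref, rtol=1e-4, atol=1e-3)))
```

In a distributed setting the appeal of this scheme is that each 1D FFT runs along an axis that is entirely local to a device, and only the axis rotations require data exchange.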

In [10]:
visualize_array_sharding(jd_global_fft[...,0])

In [11]:
text = jax.jit(jd.fft.pfft3d).lower(A_sharded).compile().as_text()
print(f"Running a FFT on a sharded array does an all gather : {'all-gather' in text}")

Running a FFT on a sharded array does an all gather : False


In [41]:
mesh = jax.make_mesh((1,8), ('x','y'))
sharding = NamedSharding(mesh , P(None , 'y'))

A = jax.random.normal(jax.random.PRNGKey(0) , (16, 16 , 16))
A_sharded = lax.with_sharding_constraint(A , sharding)

local_fft = jnp.fft.fftn(A)
jd_global_fft = jd.fft.pfft3d(A_sharded)

jnp.allclose(process_allgather(jd_global_fft , tiled=True) , local_fft.transpose(1 , 2 , 0) , atol=1e-5 , rtol=1e-5)

Array(True, dtype=bool)

In [42]:
visualize_array_sharding(jd_global_fft[...,0])

#### Distributed Linear Solver

In Exercise 2, we saw how to batch linear solvers.  
Now, let's try running a linear solver on a distributed matrix.

In [47]:
A = jax.random.normal(jax.random.PRNGKey(0) , (16, 16))
               
b = jax.random.normal(jax.random.PRNGKey(1) , (16,))

x = jnp.linalg.solve(A, b)

assert jnp.allclose(jnp.dot(A, x), b , atol=1e-5 , rtol=1e-5)

Let’s now try running it on the distributed matrix.

In [57]:
mesh = jax.make_mesh((8, 1), ('x', 'y'))
sharding = NamedSharding(mesh , P('x', None))

A = lax.with_sharding_constraint(A , sharding)
visualize_array_sharding(A)

In [49]:
x = jnp.linalg.solve(A, b)

assert jnp.allclose(jnp.dot(A, x), b , atol=1e-5 , rtol=1e-5)

In [50]:
visualize_array_sharding(x)

After visualizing:

As expected, the output $x$ is replicated,  
which suggests that JAX does not know how to handle a distributed matrix for this operation.  

Let’s inspect the solver's compiled code and see where the problem lies.

In [53]:
text = jax.jit(jnp.linalg.solve).lower(A, b).compile().as_text()
print(f"Does the linear solve do an all gather : {'all-gather' in text}")
text = jax.jit(jax.lax.linalg.lu).lower(A).compile().as_text()
print(f"Does the LU decomposition do an all gather : {'all-gather' in text}")

Does the linear solve do an all gather : True
Does the LU decomposition do an all gather : True


Both the solve and the LU decomposition contain an all-gather, so the issue is that a distributed LU decomposition is not implemented.

**Bonus Exercise**: Try to think about how the LU decomposition can be distributed.
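
One common starting point for this exercise is the blocked form of LU. The sketch below ignores pivoting and is not something the notebook implements. Splitting $A$ into $2 \times 2$ blocks gives:

$$
\begin{pmatrix} A_{11} & A_{12} \\ A_{21} & A_{22} \end{pmatrix}
=
\begin{pmatrix} L_{11} & 0 \\ L_{21} & I \end{pmatrix}
\begin{pmatrix} U_{11} & U_{12} \\ 0 & S \end{pmatrix},
\qquad
\begin{aligned}
A_{11} &= L_{11} U_{11}, \\
U_{12} &= L_{11}^{-1} A_{12}, \\
L_{21} &= A_{21} U_{11}^{-1}, \\
S &= A_{22} - L_{21} U_{12}.
\end{aligned}
$$

The small factorization of $A_{11}$ stays sequential, while the two triangular solves and the update of the Schur complement $S$ are matrix operations that can be partitioned like the matmuls in Exercise 3. The same step is then applied recursively to $S$.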

In [65]:
import jax
import jax.numpy as jnp
from jax import lax

# LU Decomposition with Partial Pivoting
def lu_decomposition(A):
    """Performs LU decomposition with partial pivoting on matrix A using JAX operations."""
    n = A.shape[0]
    L = jnp.eye(n)  # Initialize L with the identity matrix
    U = A
    P = jnp.eye(n)  # Initialize P as an identity matrix for row pivoting

    for i in range(n):
        # Partial pivoting
        pivot = jnp.argmax(jnp.abs(U[i:, i])) + i
        U = lax.cond(pivot != i, lambda U: U.at[[i, pivot], :].set(U[[pivot, i], :]), lambda U: U, U)
        P = lax.cond(pivot != i, lambda P: P.at[[i, pivot], :].set(P[[pivot, i], :]), lambda P: P, P)
        L = lax.cond(pivot != i, lambda L: L.at[[i, pivot], :i].set(L[[pivot, i], :i]), lambda L: L, L)

        # Update L and U
        for j in range(i + 1, n):
            L = L.at[j, i].set(U[j, i] / U[i, i])
            U = U.at[j, :].set(U[j, :] - L[j, i] * U[i, :])
    
    return P, L, U

# Forward substitution for solving Ly = Pb
def forward_substitution(L, Pb):
    """Solves Ly = Pb for y using forward substitution."""
    n = Pb.shape[0]
    y = jnp.zeros_like(Pb)
    for i in range(n):
        y = y.at[i].set((Pb[i] - jnp.dot(L[i, :i], y[:i])) / L[i, i])
    return y

# Backward substitution for solving Ux = y
def backward_substitution(U, y):
    """Solves Ux = y for x using backward substitution."""
    n = y.shape[0]
    x = jnp.zeros_like(y)
    for i in range(n - 1, -1, -1):
        x = x.at[i].set((y[i] - jnp.dot(U[i, i+1:], x[i+1:])) / U[i, i])
    return x

# Custom solve function using LU decomposition
def jax_solve(A, b):
    """Solves Ax = b using custom LU decomposition with pivoting."""
    P, L, U = lu_decomposition(A)
    Pb = jnp.dot(P, b)  # Apply permutation to b
    y = forward_substitution(L, Pb)
    x = backward_substitution(U, y)
    return x

# Define the matrix A and vector b
A = jax.random.normal(jax.random.PRNGKey(0), (16, 16))
               
b = jax.random.normal(jax.random.PRNGKey(1), (16,))

# Solve using custom jax_solve
x_custom = jax_solve(A, b)

# Solve using jax.numpy.linalg.solve for comparison
x_jax = jax.numpy.linalg.solve(A, b)

# Check both solutions against b and against each other
assert jnp.allclose(A @ x_custom, b , rtol=1e-5 , atol=1e-5)
assert jnp.allclose(A @ x_jax, b , rtol=1e-5 , atol=1e-5)
assert jnp.allclose(x_jax, x_custom , rtol=1e-5 , atol=1e-5)