<link rel="stylesheet" type="text/css" href="./style.css">

<div class="tutorial-header">

## Tutorial \#5: Sharding

</div>

<div class="tutorial-text">

Although <b><i>Spark</i></b> is relatively fast, some pipelines simply require more compute power to achieve acceptable performance, or even to run at all. This is often achieved by distributing the computation across multiple devices.

Transforming a model from a single-device paradigm into a multi-device paradigm is often challenging and extremely time-consuming. However, this can be easily achieved in <b><i>Spark</i></b> thanks to JAX's sharding system. In this tutorial we will not go into the details of sharding, since the JAX documentation already does a fantastic job; for a detailed explanation we refer the reader to:

1. [Distributed arrays and automatic parallelization](https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html)
2. [Explicit sharding](https://docs.jax.dev/en/latest/notebooks/explicit-sharding.html)


Let's start by instructing XLA to simulate 6 physical devices!

</div>

In [1]:
# Set XLA flags
import os
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=6'

# Tell JAX to use the CPU
os.environ['JAX_PLATFORM_NAME'] = 'cpu'

import jax
import jax.numpy as jnp
print(jax.devices())

[CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5)]


<div class="tutorial-text">

Let's also import the model from the previous tutorial.

</div>

In [2]:
import sys
sys.path.append('..')
import spark

# Load the Brain.
brain_config = spark.nn.BrainConfig.from_file('example_ab_model.scfg')

# Initialize the Brain.
brain = spark.nn.Brain(config=brain_config)

# Build the Brain.
brain(drive=spark.FloatArray(jnp.zeros((4,), dtype=jnp.float16)))

# Inspect the Brain.
brain.inspect()

Brain
├── spiker (TopologicalLinearSpiker)
├── A_ex (ALIFNeuron)
│   ├── soma (AdaptiveLeakySoma)
│   ├── delays (N2NDelays)
│   ├── synapses (LinearSynapses)
│   └── learning_rule (ZenkeRule)
├── A_in (ALIFNeuron)
│   ├── soma (AdaptiveLeakySoma)
│   ├── delays (N2NDelays)
│   ├── synapses (LinearSynapses)
│   └── learning_rule (ZenkeRule)
├── B_ex (ALIFNeuron)
│   ├── soma (AdaptiveLeakySoma)
│   ├── delays (N2NDelays)
│   ├── synapses (LinearSynapses)
│   └── learning_rule (ZenkeRule)
├── B_in (ALIFNeuron)
│   ├── soma (AdaptiveLeakySoma)
│   ├── delays (N2NDelays)
│   ├── synapses (LinearSynapses)
│   └── learning_rule (ZenkeRule)
└── integrator (ExponentialIntegrator)


<div class="tutorial-text">

By default, JAX arrays and, as a consequence, <b><i>Spark</i></b> models are not committed to any particular device. To see this, we can inspect the internal values of the brain and check where its variables are located. Since there are many variables inside a Brain, let's just look at four of them.

</div>

In [3]:
print('Array: brain.A_ex.soma.potential')
jax.debug.visualize_array_sharding(brain.A_ex.soma.potential)

print('Array: brain.B_ex.synapses.kernel')
jax.debug.visualize_array_sharding(brain.B_ex.synapses.kernel)

print('Array: brain.integrator.trace.trace_1')
jax.debug.visualize_array_sharding(brain.integrator.trace.trace_1)

print('Array: brain._cache["B_in"]["out_spikes"]')
jax.debug.visualize_array_sharding(brain._cache['B_in']['out_spikes'].value.value) # <-- The cache has a weird internal structure ¯\_(ツ)_/¯ 

Array: brain.A_ex.soma.potential


Array: brain.B_ex.synapses.kernel


Array: brain.integrator.trace.trace_1


Array: brain._cache["B_in"]["out_spikes"]


<div class="tutorial-text">

As we can see, everything is located on CPU 0. This happens because we have not let JAX know that we wish to use multiple devices. In JAX we often use a <b>Mesh</b> (of devices) to let the JIT compiler know how we want to distribute the data. A mesh is a simple way to define how to parallelize your pipeline across different abstract axes, e.g., the batch and the model.
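
To make the idea of abstract axes concrete, here is a minimal sketch that builds a two-axis mesh and splits an array along both axes. It is independent of <b><i>Spark</i></b>: the axis names `batch` and `model`, the array shape, and the way the devices are arranged are illustrative assumptions, not part of the tutorial's model.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Arrange the available devices in a 2D grid (assumed layout for illustration).
devices = np.array(jax.devices())
n_model = 2 if devices.size % 2 == 0 else 1
n_batch = devices.size // n_model
mesh = Mesh(devices.reshape(n_batch, n_model), axis_names=('batch', 'model'))

# Split the rows along the 'batch' axis and the columns along the 'model' axis.
x = jnp.zeros((4 * n_batch, 8 * n_model))
x = jax.device_put(x, NamedSharding(mesh, PartitionSpec('batch', 'model')))

# Every device holds one tile of the full array.
shard_shape = x.addressable_shards[0].data.shape
print(dict(mesh.shape), shard_shape)
```

With this layout each device holds a single `(4, 8)` tile, whatever the number of devices.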

In our case we don't have batches, so let's just start by distributing the model across all devices. This can be easily achieved by creating a mesh and letting JAX know that we want to use this mesh of devices.

</div>

In [4]:
# Create mesh
auto_mesh = jax.sharding.Mesh(
    jax.devices(), 
    axis_names=('device',),
    axis_types=(jax.sharding.AxisType.Auto,)
)

# Set mesh
jax.set_mesh(auto_mesh)

# Rebuild the Brain.
brain = spark.nn.Brain(config=brain_config)
brain(drive=spark.FloatArray(jnp.zeros((4,), dtype=jnp.float16)));

<div class="tutorial-text">

We can now inspect the brain again and see that everything is now on all our devices!

</div>

In [5]:
print('Array: brain.A_ex.soma.potential')
jax.debug.visualize_array_sharding(brain.A_ex.soma.potential)

print('Array: brain.B_ex.synapses.kernel')
jax.debug.visualize_array_sharding(brain.B_ex.synapses.kernel)

print('Array: brain.integrator.trace.trace_1')
jax.debug.visualize_array_sharding(brain.integrator.trace.trace_1)

print('Array: brain._cache["B_in"]["out_spikes"]')
jax.debug.visualize_array_sharding(brain._cache['B_in']['out_spikes'].value.value) # <-- The cache has a weird internal structure ¯\_(ツ)_/¯ 

Array: brain.A_ex.soma.potential


Array: brain.B_ex.synapses.kernel


Array: brain.integrator.trace.trace_1


Array: brain._cache["B_in"]["out_spikes"]


<div class="tutorial-text">

However, it is important to note that this is not necessarily a good layout. When working with multiple devices, communication matters; a poor assignment of arrays to devices may yield no gain at all from having multiple devices. Therefore, in some cases it may be beneficial to manually assign the devices to sections of our model.

Fortunately, this can still be achieved easily, although it is a little more convoluted. For simplicity, let's map all the arrays within a <b><i>Spark</i></b> module to the same device and everything else to all the devices.

</div>

In [6]:
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Shard to device maps.
devices = jax.devices()
shard_to_all_devices = NamedSharding(Mesh(devices, axis_names=('device',)), PartitionSpec())
shard_to_device = [
	NamedSharding(Mesh([devices[i]], axis_names=('device',)), PartitionSpec()) for i in range(len(devices))
]

<div class="tutorial-text">

Next, we need to create a map between arrays and shards, but before that let's explore what this map looks like. We achieve this through the method <b>tree_map_with_path</b>, which allows us to create a map for a pytree using the structure of that pytree. However, in order to use this function, we need to split our model into its graph and its state components.
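
As a quick illustration of what the paths look like, here is a minimal sketch that applies `jax.tree_util.tree_map_with_path` to a toy nested dictionary rather than to a real model state. The dictionary `toy_state` and the helper `describe` are hypothetical names used only for this example.

```python
import jax
import jax.numpy as jnp

# A toy pytree standing in for a model state (hypothetical structure).
toy_state = {
    'soma': {'potential': jnp.zeros(3), 'threshold': jnp.ones(3)},
    'synapses': {'kernel': jnp.zeros((3, 3))},
}

def describe(path, leaf):
    # Join the dictionary keys along the path into a dotted name.
    name = '.'.join(str(p.key) for p in path if hasattr(p, 'key'))
    return f'{name}: shape={leaf.shape}'

# The result has the same structure as toy_state, with one string per leaf.
described = jax.tree_util.tree_map_with_path(describe, toy_state)
print(described['soma']['potential'])
print(described['synapses']['kernel'])
```

The function receives the path and the leaf, and whatever it returns takes the place of that leaf in the output tree.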

Now, let's create a dummy map that returns None, just to print all the leaf nodes in our pytree, which correspond to all the arrays in our model. The thing to notice here is the path, which we can use to distribute our computation across devices.
</div>

In [7]:
def dummy_map(path: tuple[jax.tree_util.DictKey | jax.tree_util.GetAttrKey, ...], leaf: jax.Array):
	# Build a dotted name from the keys along the path
	leaf_path = '.'.join([p.key for p in path if hasattr(p, 'key')])
	# Print the path of the leaf
	print(leaf_path)
	return None

# Split the brain into its static graph and its state (the arrays)
graph, state = spark.split(brain)
sharding_rules = jax.tree_util.tree_map_with_path(dummy_map, state)

A_ex.delays._bitmask
A_ex.delays._current_idx
A_ex.delays.rng
A_ex.learning_rule.post_slow_trace.rng
A_ex.learning_rule.post_slow_trace.trace
A_ex.learning_rule.post_trace.rng
A_ex.learning_rule.post_trace.trace
A_ex.learning_rule.pre_trace.rng
A_ex.learning_rule.pre_trace.trace
A_ex.learning_rule.rng
A_ex.learning_rule.target_trace.rng
A_ex.learning_rule.target_trace.trace
A_ex.rng
A_ex.soma.is_ready
A_ex.soma.potential
A_ex.soma.refractory
A_ex.soma.rng
A_ex.soma.threshold.rng
A_ex.soma.threshold.trace
A_ex.synapses.kernel
A_ex.synapses.rng
A_in.delays._bitmask
A_in.delays._current_idx
A_in.delays.rng
A_in.learning_rule.post_slow_trace.rng
A_in.learning_rule.post_slow_trace.trace
A_in.learning_rule.post_trace.rng
A_in.learning_rule.post_trace.trace
A_in.learning_rule.pre_trace.rng
A_in.learning_rule.pre_trace.trace
A_in.learning_rule.rng
A_in.learning_rule.target_trace.rng
A_in.learning_rule.target_trace.trace
A_in.rng
A_in.soma.is_ready
A_in.soma.potential
A_in.soma.refractory
A_in.

<div class="tutorial-text">

In our current context, it makes sense to allocate specific <b><i>Spark</i></b> modules to specific devices: most modules take a similar time to compute and, once a module receives its input, most of the computation is performed locally, so no external variables should be required beyond the input.

As we can see from the previous dummy map, we can use the name of the main modules of the brain to create this mapping.

</div>

In [8]:
def get_sharding_map(path: tuple[jax.tree_util.DictKey | jax.tree_util.GetAttrKey, ...], leaf: jax.Array) -> NamedSharding:
	leaf_path = [p.key for p in path if hasattr(p, 'key')]
	# Sharding rules
	if 'spiker' in leaf_path[0]:
		return shard_to_device[0]
	elif 'integrator' in leaf_path[0]:
		return shard_to_device[1]
	elif 'A_ex' in leaf_path[0]:
		return shard_to_device[2]
	elif 'A_in' in leaf_path[0]:
		return shard_to_device[3]
	elif 'B_ex' in leaf_path[0]:
		return shard_to_device[4]
	elif 'B_in' in leaf_path[0]:
		return shard_to_device[5]
	else:
		return shard_to_all_devices


# Get sharding map
graph, state = spark.split(brain)
sharding_map = jax.tree_util.tree_map_with_path(get_sharding_map, state)

# Apply the sharding map
state = jax.device_put(state, sharding_map)

# Reconstruct the brain
brain = spark.merge(graph, state)
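
<div class="tutorial-text">

The same pattern can be tried in isolation on a toy pytree. The sketch below is not <b><i>Spark</i></b> code: `toy_state`, `module_to_device` and `rule` are hypothetical names, and it uses `SingleDeviceSharding` in place of the one-device `NamedSharding` objects built earlier. The device index is taken modulo the number of devices so that the example also runs on a machine with fewer devices.

```python
import jax
import jax.numpy as jnp
from jax.sharding import SingleDeviceSharding

devices = jax.devices()

# A toy state with two top-level "modules" (hypothetical structure).
toy_state = {
    'A_ex': {'potential': jnp.zeros(4)},
    'B_ex': {'potential': jnp.ones(4)},
}

# Lookup table from top-level module name to device index (assumed rule).
module_to_device = {'A_ex': 0, 'B_ex': 1}

def rule(path, leaf):
    # The first path entry is the top-level module name.
    index = module_to_device[path[0].key] % len(devices)
    return SingleDeviceSharding(devices[index])

# Build a pytree of shardings with the same structure as the state,
# then move every array to the device chosen for it.
sharding_map = jax.tree_util.tree_map_with_path(rule, toy_state)
placed = jax.device_put(toy_state, sharding_map)

print(placed['A_ex']['potential'].devices())
print(placed['B_ex']['potential'].devices())
```

A lookup table keeps the rule in one place, which may be easier to maintain than a chain of conditions as the number of modules grows.

</div>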

<div class="tutorial-text">

Finally, let's just see the final split of our model!

</div>

In [9]:
print('Array: brain.A_ex.soma.potential')
jax.debug.visualize_array_sharding(brain.A_ex.soma.potential)

print('Array: brain.B_ex.synapses.kernel')
jax.debug.visualize_array_sharding(brain.B_ex.synapses.kernel)

print('Array: brain.integrator.trace.trace_1')
jax.debug.visualize_array_sharding(brain.integrator.trace.trace_1)

print('Array: brain._cache["B_in"]["out_spikes"]')
jax.debug.visualize_array_sharding(brain._cache['B_in']['out_spikes'].value.value) # <-- The cache has a weird internal structure ¯\_(ツ)_/¯ 

Array: brain.A_ex.soma.potential


Array: brain.B_ex.synapses.kernel


Array: brain.integrator.trace.trace_1


Array: brain._cache["B_in"]["out_spikes"]
