In [1]:
from typing import Optional

import jax
import jax.numpy as jnp
import numpy as np

In [2]:
jax.local_devices()

[CudaDevice(id=0), CudaDevice(id=1)]

In [3]:
import jax
import jax.numpy as jnp

# Check available devices
print(jax.devices())

# Create two example arrays
x = jnp.arange(8).reshape(2, 4)
y = jnp.ones((2, 4))

# Define a simple function to run on each device
def add_fn(x, y):
    return x + y

# pmap distributes computation over devices
p_add = jax.pmap(add_fn, axis_name='i')

# Run distributed addition
z = p_add(x, y)

print(z)
print(z.devices())


[CudaDevice(id=0), CudaDevice(id=1)]
[[1. 2. 3. 4.]
 [5. 6. 7. 8.]]
{CudaDevice(id=0), CudaDevice(id=1)}


In [7]:
import jax
import jax.numpy as jnp

# Make sure you have 2 devices
print("Devices:", jax.devices())

# Simple matrix multiply function
def matmul_fn(W, X):
    return jnp.dot(W, X)

# Loss function (mean squared output)
def loss_fn(W, X):
    Y = matmul_fn(W, X)
    return jnp.mean(Y ** 2)

# pmap-ed gradient function (distributed)
p_grad_fn = jax.pmap(jax.grad(loss_fn), axis_name='i')

# Initialize data â€” one batch per GPU
key = jax.random.PRNGKey(0)
W = jax.random.normal(key, (2, 4, 4))   # shape (num_devices, ...)
X = jax.random.normal(key, (2, 4, 4))

# Compute distributed gradients
grads = p_grad_fn(W, X)

print("\nGradient results:")
print(grads)
print("\nDevices:", grads.devices())


Devices: [CudaDevice(id=0), CudaDevice(id=1)]

Gradient results:
[[[ 1.1103526   0.11962527 -1.1234112   0.01695487]
  [ 0.45118243 -0.14754063 -0.5328977   0.25843456]
  [ 0.2522853  -0.5790242   2.7051687  -0.99971235]
  [ 0.06262801 -0.19462773  1.4745142   0.7654087 ]]

 [[ 0.8668243   0.16151243  0.43444693 -1.0751637 ]
  [-1.852232    2.1205912   0.22519185  0.86963576]
  [-0.3140492   0.3806017   0.16959918 -0.01568799]
  [-0.16057253 -0.5774332  -0.46479216  0.675294  ]]]

Devices: {CudaDevice(id=0), CudaDevice(id=1)}


In [5]:
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec
import jax.numpy as jnp

devices = jax.devices()
mesh = Mesh(mesh_utils.create_device_mesh((2,)), ('dp',))

x = jnp.ones((10, 10))
sharding = NamedSharding(mesh, PartitionSpec('dp', None))
x_sharded = jax.device_put(x, sharding)

from jax import jit

@jit
def f(x):
    return x ** 2

f(x_sharded)

Array([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]], dtype=float32)

In [7]:
import jaxquantum as jqt

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
devices = jax.devices()
mesh = Mesh(mesh_utils.create_device_mesh((2,)), ('dp',))

x = jqt.identity(10).data
sharding = NamedSharding(mesh, PartitionSpec('dp', None))
x_sharded = jax.device_put(x, sharding)

import jax.scipy as jsp

jsp.linalg.eigh(x_sharded)

2025-10-16 11:50:01.121452: E external/xla/xla/service/rendezvous.cc:92] [id=0] This thread has been waiting for `initialize clique for rank 1; clique=devices=[0,1]; stream=0; groups=[[0,1]]; root_device=-1; num_local_participants=2; incarnations=[]; run_id=1126119018` for 10 seconds and may be stuck. All 2 threads joined the rendezvous, however the leader has not marked the rendezvous as completed. Leader can be deadlocked inside the rendezvous callback.


In [None]:
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec
import diffrax

# --- Setup 2-GPU mesh ---
devices = jax.devices()
print("Available devices:", devices)
mesh = Mesh(mesh_utils.create_device_mesh((2,)), ('d',))

# --- Define a small linear system dy/dt = A y ---
n = 8
key = jax.random.PRNGKey(0)
A = jax.random.normal(key, (n, n), dtype=jnp.float32)
y0 = jax.random.normal(key, (n,), dtype=jnp.float32)

# --- Shard A and y across GPUs ---
A_sharding = NamedSharding(mesh, PartitionSpec('d', None))
y_sharding = NamedSharding(mesh, PartitionSpec('d'))
A = jax.device_put(A, A_sharding)
y0 = jax.device_put(y0, y_sharding)

# --- Define RHS and solver ---
def rhs(t, y, A):
    return A @ y  # local matmul per shard


def run_solver(A, y0):
    term = diffrax.ODETerm(rhs)
    solver = diffrax.Tsit5()
    sol = diffrax.diffeqsolve(
        term,
        solver,
        t0=0.0,
        t1=1.0,
        dt0=0.1,
        y0=y0,
        args=A,
        saveat=diffrax.SaveAt(t1=True),
    )
    return sol.ys

# --- Run across 2 GPUs ---
with mesh:
    y_final = run_solver(A, y0)

print("Final y(t=1):", y_final)
print("Sharded across devices:", y_final.devices())


Available devices: [CudaDevice(id=0), CudaDevice(id=1)]


E1016 11:54:40.489626 1700608 spmd_partitioner.cc:630] [spmd] Involuntary full rematerialization. The compiler was not able to go from sharding {devices=[1,2]<=[2]} to {maximal device=0} without doing a full rematerialization of the tensor for HLO operation: %get-tuple-element = f32[1,4]{1,0} get-tuple-element(%param), index=3, sharding={devices=[1,2]<=[2]}. You probably want to enrich the sharding annotations to prevent this from happening.
