In [1]:
import os, sys 
sys.setdlopenflags(os.RTLD_NOW | os.RTLD_GLOBAL)

# os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'

from mpi4py import MPI
from petsc4py import PETSc

import jax
import jax.lax
import jax.numpy as jnp
import numpy as np
jax.config.update("jax_enable_x64", True)
import time
from timeit import timeit
import matplotlib.pyplot as plt
from jax.sharding import PartitionSpec as P
from jax._src import distributed
from functools import partial

from dolfinx import mesh, fem
import basix 

In [2]:
# jax.distributed.initialize() 
print(f"Backend: {jax.default_backend()}")
cpus = jax.devices("cpu")
print(f"Global devices: {cpus}")
print(f"Local devices: {jax.local_devices()}\n")

Backend: cpu
Global devices: [CpuDevice(id=0)]
Local devices: [CpuDevice(id=0)]



In [3]:
import ipyparallel as ipp
import logging

def run_via_ipyparallel(function, n=8, verbose=True):
    with ipp.Cluster(engines="mpi", n=n, log_level=logging.ERROR) as cluster:
        query = cluster[:].apply_async(function)
        query.wait()
        assert query.successful(), query.error
        if verbose:
            print("".join(query.stdout))
    output = query.get()[0]
    return output

In [4]:
def f():
    import jax
    jax.distributed.initialize()
    print(f"Backend: {jax.default_backend()}")
    print(f"Global devices: {jax.devices()}")
    print(f"Local devices: {jax.local_devices()}\n")

run_via_ipyparallel(f, n=4)

100%|██████████| 4/4 [00:08<00:00,  2.24s/engine]
Backend: cpu
Global devices: [CpuDevice(id=0), CpuDevice(id=131072), CpuDevice(id=262144), CpuDevice(id=393216)]
Local devices: [CpuDevice(id=0)]

Backend: cpu
Global devices: [CpuDevice(id=0), CpuDevice(id=131072), CpuDevice(id=262144), CpuDevice(id=393216)]
Local devices: [CpuDevice(id=131072)]

Backend: cpu
Global devices: [CpuDevice(id=0), CpuDevice(id=131072), CpuDevice(id=262144), CpuDevice(id=393216)]
Local devices: [CpuDevice(id=262144)]

Backend: cpu
Global devices: [CpuDevice(id=0), CpuDevice(id=131072), CpuDevice(id=262144), CpuDevice(id=393216)]
Local devices: [CpuDevice(id=393216)]




In [23]:
def basic_work_with_distributed_arrays():
    import jax
    import jax.numpy as jnp

    print(f"Backend: {jax.default_backend()}")
    print(f"Global devices: {jax.devices()}")
    print(f"Local devices: {jax.local_devices()}\n")

    a = jnp.array([1.0, 2.0, 3.0])
    print(f"Devices: a = {a.devices()}")
run_via_ipyparallel(basic_work_with_distributed_arrays, n=4)

100%|██████████| 4/4 [00:05<00:00,  1.49s/engine]
Backend: cpu
Global devices: [CpuDevice(id=0)]
Local devices: [CpuDevice(id=0)]

Devices: a = {CpuDevice(id=0)}
Backend: cpu
Global devices: [CpuDevice(id=0)]
Local devices: [CpuDevice(id=0)]

Devices: a = {CpuDevice(id=0)}
Backend: cpu
Global devices: [CpuDevice(id=0)]
Local devices: [CpuDevice(id=0)]

Devices: a = {CpuDevice(id=0)}
Backend: cpu
Global devices: [CpuDevice(id=0)]
Local devices: [CpuDevice(id=0)]

Devices: a = {CpuDevice(id=0)}



In [None]:
def basic_work_with_distributed_arrays():
    import jax
    import jax.numpy as jnp

    jax.distributed.initialize()
    print(f"Backend: {jax.default_backend()}")
    print(f"Global devices: {jax.devices()}")
    print(f"Local devices: {jax.local_devices()}\n")

    a = jnp.array([1.0, 2.0, 3.0])
    print(f"Devices: a = {a.devices()}")
run_via_ipyparallel(basic_work_with_distributed_arrays, n=4)

100%|██████████| 4/4 [00:05<00:00,  1.48s/engine]
Backend: cpu
Global devices: [CpuDevice(id=0), CpuDevice(id=131072), CpuDevice(id=262144), CpuDevice(id=393216)]
Local devices: [CpuDevice(id=0)]

Devices: a = {CpuDevice(id=0)}
Backend: cpu
Global devices: [CpuDevice(id=0), CpuDevice(id=131072), CpuDevice(id=262144), CpuDevice(id=393216)]
Local devices: [CpuDevice(id=131072)]

Devices: a = {CpuDevice(id=131072)}
Backend: cpu
Global devices: [CpuDevice(id=0), CpuDevice(id=131072), CpuDevice(id=262144), CpuDevice(id=393216)]
Local devices: [CpuDevice(id=262144)]

Devices: a = {CpuDevice(id=262144)}
Backend: cpu
Global devices: [CpuDevice(id=0), CpuDevice(id=131072), CpuDevice(id=262144), CpuDevice(id=393216)]
Local devices: [CpuDevice(id=393216)]

Devices: a = {CpuDevice(id=393216)}



In [28]:
def basic_work_with_distributed_arrays():
    import os
    os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=4'
    import jax
    import jax.numpy as jnp
    from jax.sharding import PartitionSpec as P


    print(f"Backend: {jax.default_backend()}")
    print(f"Global devices: {jax.devices()}")
    print(f"Local devices: {jax.local_devices()}\n")

    device_mesh = jax.make_mesh((4,), ('x',))
    sharding = jax.sharding.NamedSharding(device_mesh, P())
    a = jnp.array([1.0, 2.0, 3.0])
    a = jax.device_put(a, sharding)
    print(f"Devices: a = {a.devices()}")
run_via_ipyparallel(basic_work_with_distributed_arrays, n=1)

100%|██████████| 1/1 [00:06<00:00,  6.01s/engine]
Backend: cpu
Global devices: [CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3)]
Local devices: [CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3)]

Devices: a = {CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3)}



## mpi rank == 1

In [None]:
def data_transfer():
    import os
    os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=4'
    from mpi4py import MPI
    from dolfinx import mesh, fem
    import basix
    import jax
    from jax.sharding import PartitionSpec as P

    jax.distributed.initialize()
    print(f"Backend: {jax.default_backend()}")
    print(f"Global devices: {jax.devices()}")
    print(f"Local devices: {jax.local_devices()}\n")

    N = 10
    domain = mesh.create_unit_square(MPI.COMM_WORLD, N, N, mesh.CellType.triangle)
    Q_element = basix.ufl.quadrature_element(domain.topology.cell_name(), degree=1, value_shape=())
    Q = fem.functionspace(domain, Q_element)
    scale_var = fem.Function(Q)

    if MPI.COMM_WORLD.rank == 0:
        print(f"rank = {MPI.COMM_WORLD.rank} Globally: #DoFs(Q): {Q.dofmap.index_map.size_global:6d}\n", flush=True)
    print(f"rank = {MPI.COMM_WORLD.rank} Locally: #DoFs(V_alpha): {Q.dofmap.index_map.size_local:6d} scale_var {scale_var.x.array.shape}\n", flush=True)

    device_mesh = jax.make_mesh((4,), ('x',))
    sharding = jax.sharding.NamedSharding(device_mesh, P('x'))
    scale_var_values_jax = jax.device_put(scale_var.x.array, sharding)  # measure JAX device transfer time
    print(f"Devices: scale_var_values_jax = {scale_var_values_jax.devices()}")
run_via_ipyparallel(data_transfer, n=1)

100%|██████████| 1/1 [00:05<00:00,  5.93s/engine]
Backend: cpu
Global devices: [CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3)]
Local devices: [CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3)]

rank = 0 Globally: #DoFs(Q):    200

rank = 0 Locally: #DoFs(V_alpha):    200 scale_var (200,)

Devices: eps_jax = {CpuDevice(id=0), CpuDevice(id=3), CpuDevice(id=1), CpuDevice(id=2)}



## Real application

In [10]:
def constitutive_response(sigma_local, sigma_n_local):
    deps_elas = S_elas @ sigma_local
    sigma_corrected, state = return_mapping(deps_elas, sigma_n_local)
    yielding = state[2]
    return sigma_corrected, (sigma_corrected, yielding)

constitutive_response_v = jax.jit(jax.vmap(constitutive_response, in_axes=(0, 0)))
dconstitutive_response = jax.jacfwd(constitutive_response, has_aux=True)
dconstitutive_response_v = jax.jit(jax.vmap(dconstitutive_response, in_axes=(0, 0)))

In [11]:
sigma_n_local = np.zeros(stress_dim, dtype=PETSc.ScalarType)

NameError: name 'stress_dim' is not defined

In [10]:
stress_dim = 4
R = 0.7
dsigma_path = np.zeros(stress_dim)
angle = 0
# formulas for angle \in [-pi/6, pi/6]
dsigma_path[0] = (R / np.sqrt(2)) * (np.cos(angle) + np.sin(angle) / np.sqrt(3))
dsigma_path[1] = (R / np.sqrt(2)) * (-2 * np.sin(angle) / np.sqrt(3))
dsigma_path[2] = (R / np.sqrt(2)) * (np.sin(angle) / np.sqrt(3) - np.cos(angle))
dsigma_path

array([ 0.49497475, -0.        , -0.49497475,  0.        ])

In [5]:

def solve_standard_problem(N):
    import os
    # os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=4'
    from mpi4py import MPI
    from dolfinx import mesh, fem, common
    import basix
    import jax
    jax.config.update("jax_enable_x64", True)
    import jax.numpy as jnp
    from jax.sharding import PartitionSpec as P
    from constitutive_model import constitutive_response
    import numpy as np

    domain = mesh.create_unit_square(MPI.COMM_WORLD, N, N, mesh.CellType.triangle)
    stress_dim = 4
    Q_element = basix.ufl.quadrature_element(domain.topology.cell_name(), degree=1, value_shape=(stress_dim,))
    Q = fem.functionspace(domain, Q_element)
    sigma_n = fem.Function(Q)
    sigma = fem.Function(Q)
    sigma_n_np = sigma_n.x.array.reshape((-1, stress_dim))

    local_size = int(sigma_n.x.array.shape[0]/stress_dim)
    dsigma_path_np = np.zeros((local_size, stress_dim))
    R = 0.1
    angle = 0
    # formulas for angle \in [-pi/6, pi/6]
    for i in range(local_size):
        angle = np.random.uniform(-np.pi/6, np.pi/6)
        dsigma_path_np[i,0] = (R / np.sqrt(2)) * (np.cos(angle) + np.sin(angle) / np.sqrt(3))
        dsigma_path_np[i,1] = (R / np.sqrt(2)) * (-2 * np.sin(angle) / np.sqrt(3))
        dsigma_path_np[i,2] = (R / np.sqrt(2)) * (np.sin(angle) / np.sqrt(3) - np.cos(angle))
    # input data

    if MPI.COMM_WORLD.rank == 0:
        print(f"Backend: {jax.default_backend()}")
        print(f"Global devices: {jax.devices()}")
        print(f"Globally: #DoFs(Q): {Q.dofmap.index_map.size_global:6d}\n", flush=True)
    print(f"rank = {MPI.COMM_WORLD.rank} Locally: #DoFs(Q): {Q.dofmap.index_map.size_local:6d} shape(sigma_n_np): {sigma_n_np.shape}", flush=True)
    print(f"rank = {MPI.COMM_WORLD.rank} Local devices: {jax.local_devices()}", flush=True)

    timer = common.Timer("Total_timer")
    dconstitutive_response = jax.jacfwd(constitutive_response, has_aux=True)
    dconstitutive_response_v = jax.jit(jax.vmap(dconstitutive_response, in_axes=(0, 0)))
    timer.start()
    N_loads = 100  # number of loadings or paths
    for i in range(N_loads):
        _, (sigma_corrected, yielding) = dconstitutive_response_v(dsigma_path_np, sigma_n_np)
        sigma_n_np[:] = sigma_corrected
        if MPI.COMM_WORLD.rank == 0:
            print(f"rank = {MPI.COMM_WORLD.rank} yielding max: {jnp.max(yielding)}")
    timer.stop()
    total_time = MPI.COMM_WORLD.allreduce(timer.elapsed()[0], op=MPI.MAX)
    if MPI.COMM_WORLD.rank == 0:
        print(f"rank = {MPI.COMM_WORLD.rank} Total time: {total_time} \n", flush=True)
    print(f"rank = {MPI.COMM_WORLD.rank} sigma_corrected is on: {sigma_corrected.devices()}", flush=True)

    return total_time

In [None]:
N_list = [10, 100, 1000]
times = np.zeros(len(N_list))

for i, N in enumerate(N_list):
    times[i] = run_via_ipyparallel(partial(solve_standard_problem, N), n=4)

100%|██████████| 4/4 [00:09<00:00,  2.25s/engine]
Backend: cpu
Global devices: [CpuDevice(id=0)]
Globally: #DoFs(Q):    200

rank = 0 Locally: #DoFs(Q):     50 shape(sigma_n_np): (58, 4)
rank = 0 Local devices: [CpuDevice(id=0)]
rank = 0 yielding max: -2.207484232215931
rank = 0 yielding max: -2.197140185666778
rank = 0 yielding max: -2.18019448018146
rank = 0 yielding max: -2.1570510075898444
rank = 0 yielding max: -2.1282102327968544
rank = 0 yielding max: -2.0942236459700188
rank = 0 yielding max: -2.0556539426208116
rank = 0 yielding max: -2.0130450206043164
rank = 0 yielding max: -1.9669025121528019
rank = 0 yielding max: -1.9176834206025226
rank = 0 yielding max: -1.8657925587242774
rank = 0 yielding max: -1.8115834866466611
rank = 0 yielding max: -1.7553620705434334
rank = 0 yielding max: -1.697391311710658
rank = 0 yielding max: -1.6378965666893268
rank = 0 yielding max: -1.5770706373821117
rank = 0 yielding max: -1.5150784565297208
rank = 0 yielding max: -1.4520612507586752
ra

In [26]:
times

array([ 14.33,  16.98, 273.75])

In [4]:
def solve_distributed_problem(N):
    import os
    # os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=4'
    from mpi4py import MPI
    from dolfinx import mesh, fem, common
    import basix
    import jax
    jax.config.update("jax_enable_x64", True)
    import jax.numpy as jnp
    from jax.sharding import PartitionSpec as P
    from constitutive_model import constitutive_response
    import numpy as np

    jax.distributed.initialize()

    domain = mesh.create_unit_square(MPI.COMM_WORLD, N, N, mesh.CellType.triangle)
    stress_dim = 4
    Q_element = basix.ufl.quadrature_element(domain.topology.cell_name(), degree=1, value_shape=(stress_dim,))
    Q = fem.functionspace(domain, Q_element)
    sigma_n = fem.Function(Q)
    sigma = fem.Function(Q)

    R = 0.1
    dsigma_path = np.zeros(stress_dim)
    angle = 0
    # formulas for angle \in [-pi/6, pi/6]
    dsigma_path[0] = (R / np.sqrt(2)) * (np.cos(angle) + np.sin(angle) / np.sqrt(3))
    dsigma_path[1] = (R / np.sqrt(2)) * (-2 * np.sin(angle) / np.sqrt(3))
    dsigma_path[2] = (R / np.sqrt(2)) * (np.sin(angle) / np.sqrt(3) - np.cos(angle))

    # input data
    local_size = int(sigma_n.x.array.shape[0]/stress_dim)
    dsigma_path_np = np.tile(dsigma_path, (local_size, 1))
    sigma_n_np = sigma_n.x.array.reshape((-1, stress_dim))

    if MPI.COMM_WORLD.rank == 0:
        print(f"Backend: {jax.default_backend()}")
        print(f"Global devices: {jax.devices()}")
        print(f"Globally: #DoFs(Q): {Q.dofmap.index_map.size_global:6d}\n", flush=True)
    print(f"rank = {MPI.COMM_WORLD.rank} Locally: #DoFs(Q): {Q.dofmap.index_map.size_local:6d} shape(sigma_n_np): {sigma_n_np.shape}", flush=True)
    print(f"rank = {MPI.COMM_WORLD.rank} Local devices: {jax.local_devices()}", flush=True)

    timer = common.Timer("Total_timer")
    dconstitutive_response = jax.jacfwd(constitutive_response, has_aux=True)
    dconstitutive_response_v = jax.jit(jax.vmap(dconstitutive_response, in_axes=(0, 0)))
    timer.start()
    N_loads = 100  # number of loadings or paths
    for i in range(N_loads):
        _, (sigma_corrected, yielding) = dconstitutive_response_v(dsigma_path_np, sigma_n_np)
        sigma_n_np[:] = sigma_corrected
        print(f"rank = {MPI.COMM_WORLD.rank} yielding max: {jnp.max(yielding)}")
    timer.stop()
    total_time = MPI.COMM_WORLD.allreduce(timer.elapsed()[0], op=MPI.MAX)
    if MPI.COMM_WORLD.rank == 0:
        print(f"rank = {MPI.COMM_WORLD.rank} Total time: {total_time} \n", flush=True)
    print(f"rank = {MPI.COMM_WORLD.rank} sigma_corrected is on: {sigma_corrected.devices()}", flush=True)

    # device_mesh = jax.make_mesh((4,), ('x',))
    # sharding = jax.sharding.NamedSharding(device_mesh, P('x'))
    # scale_var_values_jax = jax.device_put(scale_var.x.array, sharding)  # measure JAX device transfer time
    # print(f"Devices: scale_var_values_jax = {scale_var_values_jax.devices()}")
    return total_time_global

In [7]:
N_list = [10, 100, 1000]
times = np.zeros(len(N_list))

for i, N in enumerate(N_list):
    times[i] = run_via_ipyparallel(partial(solve_distributed_problem, N), n=4)

100%|██████████| 4/4 [00:08<00:00,  2.24s/engine]


KeyboardInterrupt: 

In [11]:
times

array([ 6.4 ,  6.87, 45.94])