In [1]:
import os, sys 
sys.setdlopenflags(os.RTLD_NOW | os.RTLD_GLOBAL)

# os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'

from mpi4py import MPI
from petsc4py import PETSc

import jax
import jax.lax
import jax.numpy as jnp
import numpy as np
jax.config.update("jax_enable_x64", True)
import time
from timeit import timeit
import matplotlib.pyplot as plt
from jax.sharding import PartitionSpec as P
from jax._src import distributed


from dolfinx import mesh, fem
import basix 

In [2]:
# Single-process baseline: the distributed runtime is intentionally left
# uninitialised here, so only this kernel's own devices are reported.
# jax.distributed.initialize()
cpus = jax.devices("cpu")
print(
    f"Backend: {jax.default_backend()}",
    f"Global devices: {cpus}",
    f"Local devices: {jax.local_devices()}\n",
    sep="\n",
)

Backend: cpu
Global devices: [CpuDevice(id=0)]
Local devices: [CpuDevice(id=0)]



In [2]:
import ipyparallel as ipp
import logging

def run_via_ipyparallel(function, n=8, verbose=True):
    """Run `function` on every engine of a temporary MPI-backed ipyparallel cluster.

    A cluster with `n` engines (launched with ``engines="mpi"``) is started for
    the duration of the call. `function` is submitted without arguments to all
    engines, and the call blocks until every engine has finished.

    Args:
        function: Zero-argument callable executed on each engine. It presumably
            has to import everything it needs itself, since it does not run in
            this notebook's kernel -- confirm against the ipyparallel docs.
        n: Number of engines to start.
        verbose: When true, print the concatenated stdout of all engines.

    Raises:
        AssertionError: If the submitted call did not succeed; the message
            carries `query.error`.
    """
    cluster_manager = ipp.Cluster(engines="mpi", n=n, log_level=logging.ERROR)
    with cluster_manager as cluster:
        all_engines = cluster[:]
        query = all_engines.apply_async(function)
        query.wait()
        assert query.successful(), query.error
        if not verbose:
            return
        # `query.stdout` holds one captured string per engine.
        print("".join(query.stdout))


In [17]:
def f():
    """Initialise JAX's distributed runtime on this engine and report its devices.

    Prints the default backend, the devices visible globally and the devices
    local to the calling process.

    NOTE: the name `f` is reused further down in this notebook for the yield
    function; whichever cell ran last wins.
    """
    import jax

    jax.distributed.initialize()
    report = [
        f"Backend: {jax.default_backend()}",
        f"Global devices: {jax.devices()}",
        f"Local devices: {jax.local_devices()}\n",
    ]
    print("\n".join(report))

run_via_ipyparallel(f, n=4)

100%|██████████| 4/4 [00:06<00:00,  1.50s/engine]
Backend: cpu
Global devices: [CpuDevice(id=0), CpuDevice(id=131072), CpuDevice(id=262144), CpuDevice(id=393216)]
Local devices: [CpuDevice(id=0)]

Backend: cpu
Global devices: [CpuDevice(id=0), CpuDevice(id=131072), CpuDevice(id=262144), CpuDevice(id=393216)]
Local devices: [CpuDevice(id=131072)]

Backend: cpu
Global devices: [CpuDevice(id=0), CpuDevice(id=131072), CpuDevice(id=262144), CpuDevice(id=393216)]
Local devices: [CpuDevice(id=262144)]

Backend: cpu
Global devices: [CpuDevice(id=0), CpuDevice(id=131072), CpuDevice(id=262144), CpuDevice(id=393216)]
Local devices: [CpuDevice(id=393216)]




In [23]:
def basic_work_with_distributed_arrays():
    """Report JAX devices and array placement WITHOUT `jax.distributed.initialize()`.

    Prints the backend, the global and local device lists, and the set of
    devices holding a small freshly created array.
    """
    import jax
    import jax.numpy as jnp

    for label, value in (
        ("Backend", jax.default_backend()),
        ("Global devices", jax.devices()),
    ):
        print(f"{label}: {value}")
    print(f"Local devices: {jax.local_devices()}\n")

    a = jnp.array([1.0, 2.0, 3.0])
    placement = a.devices()
    print(f"Devices: a = {placement}")
run_via_ipyparallel(basic_work_with_distributed_arrays, n=4)

100%|██████████| 4/4 [00:05<00:00,  1.49s/engine]
Backend: cpu
Global devices: [CpuDevice(id=0)]
Local devices: [CpuDevice(id=0)]

Devices: a = {CpuDevice(id=0)}
Backend: cpu
Global devices: [CpuDevice(id=0)]
Local devices: [CpuDevice(id=0)]

Devices: a = {CpuDevice(id=0)}
Backend: cpu
Global devices: [CpuDevice(id=0)]
Local devices: [CpuDevice(id=0)]

Devices: a = {CpuDevice(id=0)}
Backend: cpu
Global devices: [CpuDevice(id=0)]
Local devices: [CpuDevice(id=0)]

Devices: a = {CpuDevice(id=0)}



In [None]:
def basic_work_with_distributed_arrays():
    """Report JAX devices and array placement AFTER `jax.distributed.initialize()`.

    Same probe as the variant without initialisation, except that the
    distributed runtime is started first. Prints the backend, the global and
    local device lists, and the devices holding a small freshly created array.
    """
    import jax
    import jax.numpy as jnp

    jax.distributed.initialize()

    backend = jax.default_backend()
    global_devices = jax.devices()
    local_devices = jax.local_devices()
    print(f"Backend: {backend}")
    print(f"Global devices: {global_devices}")
    print(f"Local devices: {local_devices}\n")

    values = jnp.array([1.0, 2.0, 3.0])
    print(f"Devices: a = {values.devices()}")
run_via_ipyparallel(basic_work_with_distributed_arrays, n=4)

100%|██████████| 4/4 [00:05<00:00,  1.48s/engine]
Backend: cpu
Global devices: [CpuDevice(id=0), CpuDevice(id=131072), CpuDevice(id=262144), CpuDevice(id=393216)]
Local devices: [CpuDevice(id=0)]

Devices: a = {CpuDevice(id=0)}
Backend: cpu
Global devices: [CpuDevice(id=0), CpuDevice(id=131072), CpuDevice(id=262144), CpuDevice(id=393216)]
Local devices: [CpuDevice(id=131072)]

Devices: a = {CpuDevice(id=131072)}
Backend: cpu
Global devices: [CpuDevice(id=0), CpuDevice(id=131072), CpuDevice(id=262144), CpuDevice(id=393216)]
Local devices: [CpuDevice(id=262144)]

Devices: a = {CpuDevice(id=262144)}
Backend: cpu
Global devices: [CpuDevice(id=0), CpuDevice(id=131072), CpuDevice(id=262144), CpuDevice(id=393216)]
Local devices: [CpuDevice(id=393216)]

Devices: a = {CpuDevice(id=393216)}



In [28]:
def basic_work_with_distributed_arrays():
    """Place one array on four forced host devices inside a single process.

    `XLA_FLAGS` is assigned before `jax` is imported, then a 1-D mesh of four
    devices named ``'x'`` is built and a small array is put on it with an
    empty `PartitionSpec` (no array axis is mapped to a mesh axis). Prints the
    backend, the device lists and the devices holding the array.
    """
    import os
    os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=4'
    import jax
    import jax.numpy as jnp
    from jax.sharding import NamedSharding, PartitionSpec as P

    print(
        f"Backend: {jax.default_backend()}",
        f"Global devices: {jax.devices()}",
        f"Local devices: {jax.local_devices()}\n",
        sep="\n",
    )

    device_mesh = jax.make_mesh((4,), ('x',))
    a = jax.device_put(jnp.array([1.0, 2.0, 3.0]), NamedSharding(device_mesh, P()))
    print(f"Devices: a = {a.devices()}")
run_via_ipyparallel(basic_work_with_distributed_arrays, n=1)

100%|██████████| 1/1 [00:06<00:00,  6.01s/engine]
Backend: cpu
Global devices: [CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3)]
Local devices: [CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3)]

Devices: a = {CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3)}



## MPI size == 1 (single process, rank 0)

In [None]:
def data_transfer():
    """Shard a dolfinx quadrature-function array across forced host devices.

    Runs on an ipyparallel engine. It builds a unit-square triangle mesh, a
    scalar quadrature space `Q` of degree 1 and a function on it, then puts the
    function's value array on a 1-D JAX device mesh, partitioned along the mesh
    axis ``'x'``. Sizes and the resulting device placement are printed; nothing
    is returned and nothing is timed.
    """
    import os

    # Number of host (CPU) devices this process exposes to JAX; the same value
    # is used for the XLA flag and for the device-mesh shape below.
    n_devices = 4
    # Set before `jax` is imported so the flag is in place when the backend starts.
    os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={n_devices}"
    from mpi4py import MPI
    from dolfinx import mesh, fem
    import basix
    import jax
    from jax.sharding import PartitionSpec as P

    jax.distributed.initialize()
    print(f"Backend: {jax.default_backend()}")
    print(f"Global devices: {jax.devices()}")
    print(f"Local devices: {jax.local_devices()}\n")

    N = 10
    domain = mesh.create_unit_square(MPI.COMM_WORLD, N, N, mesh.CellType.triangle)
    Q_element = basix.ufl.quadrature_element(domain.topology.cell_name(), degree=1, value_shape=())
    Q = fem.functionspace(domain, Q_element)
    scale_var = fem.Function(Q)

    rank = MPI.COMM_WORLD.rank
    if rank == 0:
        print(f"rank = {rank} Globally: #DoFs(Q): {Q.dofmap.index_map.size_global:6d}\n", flush=True)
    # The label names the space actually queried (`Q`); it previously said `V_alpha`.
    print(f"rank = {rank} Locally: #DoFs(Q): {Q.dofmap.index_map.size_local:6d} scale_var {scale_var.x.array.shape}\n", flush=True)

    device_mesh = jax.make_mesh((n_devices,), ('x',))
    sharding = jax.sharding.NamedSharding(device_mesh, P('x'))
    # Host-to-device transfer of the local array, split along mesh axis 'x'.
    # NOTE(review): presumably requires the array length to be divisible by
    # `n_devices` -- confirm against the JAX sharding docs.
    scale_var_values_jax = jax.device_put(scale_var.x.array, sharding)
    print(f"Devices: scale_var_values_jax = {scale_var_values_jax.devices()}")
run_via_ipyparallel(data_transfer, n=1)

100%|██████████| 1/1 [00:05<00:00,  5.93s/engine]
Backend: cpu
Global devices: [CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3)]
Local devices: [CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3)]

rank = 0 Globally: #DoFs(Q):    200

rank = 0 Locally: #DoFs(V_alpha):    200 scale_var (200,)

Devices: eps_jax = {CpuDevice(id=0), CpuDevice(id=3), CpuDevice(id=1), CpuDevice(id=2)}



## Real application

In [2]:
# Material parameters of the constitutive model used in the cells below.
E = 6778  # [MPa] Young modulus
nu = 0.25  # [-] Poisson ratio
c = 3.45  # [MPa] cohesion
phi = 30 * np.pi / 180  # [rad] friction angle
psi = 30 * np.pi / 180  # [rad] dilatancy angle
theta_T = 26 * np.pi / 180  # [rad] transition angle as defined by Abbo and Sloan
a = 0.26 * c / np.tan(phi)  # [MPa] tension cut-off parameter
stress_dim = 4  # length of the stress/strain vectors used with `dev`, `tr` and `C_elas`

def J3(s):
    """Return ``s[2] * (s[0] * s[1] - s[3]**2 / 2)`` for a 4-component vector `s`.

    Presumably the third invariant of a deviatoric stress stored with a scaled
    shear component -- verify against the stress notation used by the caller.
    """
    s0, s1, s2, s3 = s[0], s[1], s[2], s[3]
    in_plane = s0 * s1 - s3 * s3 / 2.0
    return s2 * in_plane


def J2(s):
    """Return half of the inner product of `s` with itself (``0.5 * vdot(s, s)``)."""
    self_product = jnp.vdot(s, s)
    return 0.5 * self_product


def theta(s):
    """Return ``arcsin(arg) / 3`` with ``arg = -3*sqrt(3)*J3(s) / (2*J2(s)**1.5)``.

    `arg` is clipped to ``[-1, 1]`` before the arcsine is taken. Presumably this
    is the Lode angle of the deviatoric vector `s` -- confirm with the model
    reference. No guard is applied for ``J2(s) == 0``.
    """
    j2 = J2(s)
    numerator = -(3.0 * np.sqrt(3.0) * J3(s))
    denominator = 2.0 * jnp.sqrt(j2 * j2 * j2)
    clipped = jnp.clip(numerator / denominator, -1.0, 1.0)
    return 1.0 / 3.0 * jnp.arcsin(clipped)


def sign(x):
    """Return -1 when ``x < 0`` and 1 otherwise (so zero maps to 1).

    Implemented with `jax.lax.cond` so it can be used inside traced code.
    """

    def _negative(_):
        return -1

    def _non_negative(_):
        return 1

    is_negative = x < 0.0
    return jax.lax.cond(is_negative, _negative, _non_negative, x)


def coeff1(theta, angle):
    """Return ``cos(theta_T) - sin(angle) * sin(theta_T) / sqrt(3)``.

    `theta` is accepted for a uniform signature with `coeff2` but is not used.
    """
    sine_part = (1.0 / np.sqrt(3.0)) * np.sin(angle) * np.sin(theta_T)
    return np.cos(theta_T) - sine_part


def coeff2(theta, angle):
    """Return ``sign(theta) * sin(theta_T) + sin(angle) * cos(theta_T) / sqrt(3)``."""
    signed_sine = sign(theta) * np.sin(theta_T)
    cosine_part = (1.0 / np.sqrt(3.0)) * np.sin(angle) * np.cos(theta_T)
    return signed_sine + cosine_part


# Common denominator of `B` and `C`: 18 * cos^3(3 * theta_T).
coeff3 = 18.0 * np.cos(3.0 * theta_T) * np.cos(3.0 * theta_T) * np.cos(3.0 * theta_T)


def C(theta, angle):
    """Return the coefficient that multiplies ``sin(3*theta)**2`` in `K`.

    Computed as ``(-cos(3*theta_T) * coeff1 - 3 * sign(theta) * sin(3*theta_T)
    * coeff2) / coeff3``.
    """
    cosine_term = -np.cos(3.0 * theta_T) * coeff1(theta, angle)
    sine_term = 3.0 * sign(theta) * np.sin(3.0 * theta_T) * coeff2(theta, angle)
    return (cosine_term - sine_term) / coeff3

def B(theta, angle):
    """Return the coefficient that multiplies ``sin(3*theta)`` in `K`.

    Computed as ``(sign(theta) * sin(6*theta_T) * coeff1 - 6 * cos(6*theta_T)
    * coeff2) / coeff3``.
    """
    sine_term = sign(theta) * np.sin(6.0 * theta_T) * coeff1(theta, angle)
    cosine_term = 6.0 * np.cos(6.0 * theta_T) * coeff2(theta, angle)
    return (sine_term - cosine_term) / coeff3


def A(theta, angle):
    """Return the constant coefficient of the expression used by `K` beyond `theta_T`.

    It combines `B`, `C`, ``sign(theta)`` and trigonometric functions of
    `theta_T` and `angle`; the summation order matches the written formula.
    """
    friction_term = -(1.0 / np.sqrt(3.0)) * np.sin(angle) * sign(theta) * np.sin(theta_T)
    b_term = B(theta, angle) * sign(theta) * np.sin(3 * theta_T)
    c_term = C(theta, angle) * np.sin(3.0 * theta_T) * np.sin(3.0 * theta_T)
    return friction_term - b_term - c_term + np.cos(theta_T)


def K(theta, angle):
    """Return the `theta`-dependent factor that scales ``J2`` in `surface`.

    For ``|theta| > theta_T`` the value is ``A + B*sin(3*theta) +
    C*sin(3*theta)**2``; otherwise it is ``cos(theta) - sin(angle)*sin(theta)
    / sqrt(3)``. The branch is selected with `jax.lax.cond`.
    """

    def _within_transition(th):
        return jnp.cos(th) - (1.0 / np.sqrt(3.0)) * np.sin(angle) * jnp.sin(th)

    def _beyond_transition(th):
        sin_3th = jnp.sin(3.0 * th)
        return A(th, angle) + B(th, angle) * sin_3th + C(th, angle) * sin_3th * sin_3th

    beyond = jnp.abs(theta) > theta_T
    return jax.lax.cond(beyond, _beyond_transition, _within_transition, theta)

def a_g(angle):
    """Return `a` rescaled by ``tan(phi) / tan(angle)``; equals `a` when ``angle == phi``."""
    scaled = a * np.tan(phi)
    return scaled / np.tan(angle)

# Linear map applied to a 4-component stress vector in `surface` (`s = dev @ sigma`):
# it subtracts the mean of the first three components from each of them and
# leaves the fourth component unchanged.
dev = np.array(
    [
        [2.0 / 3.0, -1.0 / 3.0, -1.0 / 3.0, 0.0],
        [-1.0 / 3.0, 2.0 / 3.0, -1.0 / 3.0, 0.0],
        [-1.0 / 3.0, -1.0 / 3.0, 2.0 / 3.0, 0.0],
        [0.0, 0.0, 0.0, 1.0],
    ],
    dtype=PETSc.ScalarType,
)
# Row vector summing the first three components (`I1 = tr @ sigma` in `surface`).
tr = np.array([1.0, 1.0, 1.0, 0.0], dtype=PETSc.ScalarType)


def surface(sigma_local, angle):
    """Evaluate the surface function shared by the yield function and the potential.

    Returns ``I1/3 * sin(angle) + sqrt(J2(s) * K**2 + a_g(angle)**2 *
    sin(angle)**2) - c * cos(angle)``, where ``s = dev @ sigma_local``,
    ``I1 = tr @ sigma_local`` and ``K = K(theta(s), angle)``.

    Args:
        sigma_local: 4-component stress vector.
        angle: Angle in radians; `f` passes `phi`, `g` passes `psi`.
    """
    s = dev @ sigma_local
    I1 = tr @ sigma_local
    k_factor = K(theta(s), angle)
    apex_shift = a_g(angle)
    sin_angle = np.sin(angle)

    mean_term = I1 / 3.0 * sin_angle
    radicand = J2(s) * k_factor * k_factor + apex_shift * apex_shift * sin_angle * sin_angle
    return mean_term + jnp.sqrt(radicand) - c * np.cos(angle)

def f(sigma_local):
    """Yield function: `surface` evaluated with the friction angle `phi`.

    NOTE: an unrelated zero-argument `f` is defined earlier in this notebook;
    this definition must be the one in scope when the model cells run.
    """
    return surface(sigma_local, angle=phi)

def g(sigma_local):
    """Plastic potential: `surface` evaluated with the dilatancy angle `psi`."""
    return surface(sigma_local, angle=psi)


# Forward-mode derivative of the plastic potential with respect to the stress.
dgdsigma = jax.jacfwd(g)

# Lame parameters computed from Young's modulus `E` and Poisson's ratio `nu`.
lmbda = E * nu / ((1.0 + nu) * (1.0 - 2.0 * nu))
mu = E / (2.0 * (1.0 + nu))
# Isotropic elastic stiffness acting on 4-component vectors.
# NOTE(review): the `2 * mu` shear entry suggests a scaled (Mandel-like)
# shear component -- confirm with the stress/strain notation in use.
C_elas = np.array(
    [
        [lmbda + 2 * mu, lmbda, lmbda, 0],
        [lmbda, lmbda + 2 * mu, lmbda, 0],
        [lmbda, lmbda, lmbda + 2 * mu, 0],
        [0, 0, 0, 2 * mu],
    ],
    dtype=PETSc.ScalarType,
)
# Elastic compliance: the inverse of the stiffness matrix.
S_elas = np.linalg.inv(C_elas)
# Zero plastic-strain increment returned by the elastic branch of `deps_p`.
ZERO_VECTOR = np.zeros(stress_dim, dtype=PETSc.ScalarType)

def deps_p(sigma_local, dlambda, deps_local, sigma_n_local):
    """Return the plastic strain increment.

    The branch is chosen from the yield function evaluated at the trial stress
    ``sigma_n_local + C_elas @ deps_local``: when it is non-positive the
    increment is `ZERO_VECTOR`, otherwise it is ``dlambda * dgdsigma(sigma_local)``.
    """
    trial_stress = sigma_n_local + C_elas @ deps_local
    is_elastic = f(trial_stress) <= 0.0

    def _no_plastic_flow(stress, multiplier):
        return ZERO_VECTOR

    def _plastic_flow(stress, multiplier):
        return multiplier * dgdsigma(stress)

    return jax.lax.cond(is_elastic, _no_plastic_flow, _plastic_flow, sigma_local, dlambda)


def r_g(sigma_local, dlambda, deps_local, sigma_n_local):
    """Stress residual: ``sigma - sigma_n - C_elas @ (deps - deps_p)``."""
    plastic_increment = deps_p(sigma_local, dlambda, deps_local, sigma_n_local)
    elastic_increment = deps_local - plastic_increment
    return sigma_local - sigma_n_local - C_elas @ elastic_increment


def r_f(sigma_local, dlambda, deps_local, sigma_n_local):
    """Scalar residual paired with the plastic multiplier.

    When the yield function at the trial stress ``sigma_n_local + C_elas @
    deps_local`` is non-positive the residual is `dlambda` itself; otherwise
    it is the yield function at the current stress, ``f(sigma_local)``.
    """
    trial_stress = sigma_n_local + C_elas @ deps_local
    is_elastic = f(trial_stress) <= 0.0

    def _multiplier_residual(stress, multiplier):
        return multiplier

    def _yield_residual(stress, multiplier):
        return f(stress)

    return jax.lax.cond(is_elastic, _multiplier_residual, _yield_residual, sigma_local, dlambda)


def r(y_local, deps_local, sigma_n_local):
    """Assemble the full residual for the unknown vector `y_local`.

    `y_local` holds the stress in its first `stress_dim` entries and the
    plastic multiplier in its last entry. The result stacks `r_g` (vector)
    and `r_f` (scalar) into one 1-D array.
    """
    sigma_local, dlambda_local = y_local[:stress_dim], y_local[-1]
    residual_args = (sigma_local, dlambda_local, deps_local, sigma_n_local)

    stress_residual = r_g(*residual_args)
    yield_residual = r_f(*residual_args)

    # `jnp.c_` with this directive concatenates an array and a scalar.
    return jnp.c_["0,1,-1", stress_residual, yield_residual]

# Forward-mode Jacobian of the residual `r` with respect to its first argument `y_local`.
drdy = jax.jacfwd(r)

# Newton loop limits used by `return_mapping`: maximum iteration count and
# tolerance on the residual norm relative to the initial residual norm.
Nitermax, tol = 200, 1e-10

# Initial plastic multiplier, kept as a length-1 array so that it can be
# concatenated with the stress vector in `return_mapping`.
ZERO_SCALAR = np.array([0.0])


def return_mapping(deps_local, sigma_n_local):
    """Perform the return-mapping procedure at a single Gauss point.

    The elastoplastic constitutive equations are solved numerically with the
    Newton method; the Newton loop is implemented via `jax.lax.while_loop`.
    The unknown vector starts from ``[sigma_n_local, 0]`` and iterations stop
    once the residual norm relative to the initial one drops to `tol` or
    `Nitermax` iterations have been made.

    The stress is returned twice so that its value can be reused after
    differentiation: applying `jax.jacfwd(return_mapping, has_aux=True)`
    yields a function whose output is
    `(C_tang_local, (sigma_local, niter_total, yielding, norm_res, dlambda))`.

    Args:
        deps_local: Strain increment at the Gauss point.
        sigma_n_local: Stress at the Gauss point from the previous state.

    Returns:
        A pair `(sigma_local, aux)` where `aux` is the tuple
        `(sigma_local, niter_total, yielding, norm_res, dlambda)`:

        sigma_local: The stress at the current Gauss point.
        niter_total: The total number of iterations.
        yielding: The yield function at the trial stress
            `sigma_n_local + C_elas @ deps_local`.
        norm_res: The norm of the residual after the last iteration.
        dlambda: The value of the plastic multiplier.
    """
    y_initial = jnp.concatenate([sigma_n_local, ZERO_SCALAR])
    res_initial = r(y_initial, deps_local, sigma_n_local)
    norm_res0 = jnp.linalg.norm(res_initial)

    def keep_iterating(state):
        norm_res, niter, _ = state
        not_converged = norm_res / norm_res0 > tol
        budget_left = niter < Nitermax
        return jnp.logical_and(not_converged, budget_left)

    def newton_step(state):
        _, niter, (y, deps, sigma_n, res) = state

        jacobian = drdy(y, deps, sigma_n)
        y_next = y + jnp.linalg.solve(jacobian, -res)
        res_next = r(y_next, deps, sigma_n)

        carried = (y_next, deps, sigma_n, res_next)
        return (jnp.linalg.norm(res_next), niter + 1, carried)

    initial_state = (norm_res0, 0, (y_initial, deps_local, sigma_n_local, res_initial))
    norm_res, niter_total, final_carry = jax.lax.while_loop(keep_iterating, newton_step, initial_state)

    # The first entry of the carried tuple is the converged unknown vector.
    y_final = final_carry[0]
    sigma_local = y_final[:stress_dim]
    dlambda = y_final[-1]
    yielding = f(sigma_n_local + C_elas @ deps_local)

    return sigma_local, (sigma_local, niter_total, yielding, norm_res, dlambda)

In [4]:
def constitutive_response(sigma_local, sigma_n_local):
    """Map a stress-like input through the return mapping.

    `sigma_local` is converted to a strain increment with the compliance
    `S_elas`, and `return_mapping` is run from `sigma_n_local`.

    Returns:
        The corrected stress and the yield-function value reported by
        `return_mapping` (third entry of its auxiliary tuple).
    """
    strain_increment = S_elas @ sigma_local
    stress, aux = return_mapping(strain_increment, sigma_n_local)
    _, _, yielding, _, _ = aux
    return stress, yielding


# Batched, jit-compiled version: both arguments are mapped over their leading axis.
constitutive_response_v = jax.jit(
    jax.vmap(constitutive_response, in_axes=(0, 0))
)

In [5]:
# Previous stress state for the single-point check below: all zeros.
sigma_n_local = np.zeros(stress_dim, dtype=PETSc.ScalarType)

In [10]:
# Build a 4-component stress increment of magnitude set by `R` in the
# direction selected by `angle`; the fourth component stays zero.
stress_dim = 4
R = 0.7
angle = 0
dsigma_path = np.zeros(stress_dim)
# formulas for angle \in [-pi/6, pi/6]
dsigma_path[:3] = (R / np.sqrt(2)) * np.array(
    [
        np.cos(angle) + np.sin(angle) / np.sqrt(3),
        -2 * np.sin(angle) / np.sqrt(3),
        np.sin(angle) / np.sqrt(3) - np.cos(angle),
    ]
)
dsigma_path

array([ 0.49497475, -0.        , -0.49497475,  0.        ])

In [8]:
# Single-point check in the notebook kernel: returns (corrected stress, yield value).
constitutive_response(dsigma_path, sigma_n_local)

(Array([ 0.49497475,  0.        , -0.49497475,  0.        ], dtype=float64),
 Array(-2.06667052, dtype=float64))

In [34]:
def solve_constitutive_problem():
    """Evaluate the constitutive response at all local quadrature points of each rank.

    Runs on an ipyparallel engine, so every name it uses is imported here
    (notebook-level globals such as `jnp` are not available on the engines).
    It builds a unit-square triangle mesh distributed over `MPI.COMM_WORLD`,
    a quadrature space with `stress_dim` values per point, and applies the
    vmapped and jitted `constitutive_response` to the local data of this rank.
    Sizes, device placement and the maximum yield value are printed; nothing
    is returned.

    NOTE(review): `constitutive_response` is imported from a module named
    `constitutive_model`, which must be importable on the engines; presumably
    it contains the model defined in this notebook -- confirm.
    """
    import os
    # os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=4'
    from mpi4py import MPI
    from dolfinx import mesh, fem
    import basix
    import jax
    jax.config.update("jax_enable_x64", True)
    # Imported locally: without it the `jnp.max` call below raises NameError on the engine.
    import jax.numpy as jnp
    from constitutive_model import constitutive_response
    import numpy as np

    jax.distributed.initialize()

    N = 10
    domain = mesh.create_unit_square(MPI.COMM_WORLD, N, N, mesh.CellType.triangle)
    stress_dim = 4
    Q_element = basix.ufl.quadrature_element(domain.topology.cell_name(), degree=1, value_shape=(stress_dim,))
    Q = fem.functionspace(domain, Q_element)
    sigma_n = fem.Function(Q)
    sigma = fem.Function(Q)

    R = 0.7
    dsigma_path = np.zeros(stress_dim)
    angle = 0
    # formulas for angle \in [-pi/6, pi/6]
    dsigma_path[0] = (R / np.sqrt(2)) * (np.cos(angle) + np.sin(angle) / np.sqrt(3))
    dsigma_path[1] = (R / np.sqrt(2)) * (-2 * np.sin(angle) / np.sqrt(3))
    dsigma_path[2] = (R / np.sqrt(2)) * (np.sin(angle) / np.sqrt(3) - np.cos(angle))

    # Input data: one row of `stress_dim` values per entry of the local array
    # (the recorded row counts exceed `size_local`, so the array presumably
    # includes ghost points -- confirm against dolfinx).
    local_size = sigma_n.x.array.shape[0] // stress_dim
    dsigma_path_np = np.tile(dsigma_path, (local_size, 1))
    sigma_n_np = sigma_n.x.array.reshape((-1, stress_dim))

    rank = MPI.COMM_WORLD.rank
    if rank == 0:
        print(f"Backend: {jax.default_backend()}")
        print(f"Global devices: {jax.devices()}")
        print(f"Globally: #DoFs(Q): {Q.dofmap.index_map.size_global:6d}\n", flush=True)
    print(f"rank = {rank} Locally: #DoFs(Q): {Q.dofmap.index_map.size_local:6d} sigma_n_np shape {sigma_n_np.shape}", flush=True)
    print(f"rank = {rank} Local devices: {jax.local_devices()}", flush=True)

    # Map the single-point model over the leading (point) axis of both inputs.
    constitutive_response_v = jax.jit(jax.vmap(constitutive_response, in_axes=(0, 0)))
    sigma_corrected, yielding = constitutive_response_v(dsigma_path_np, sigma_n_np)
    print(f"rank = {rank} sigma_corrected {sigma_corrected.devices()} yielding max = {float(jnp.max(yielding))} \n", flush=True)
run_via_ipyparallel(solve_constitutive_problem, n=4)

100%|██████████| 4/4 [00:05<00:00,  1.50s/engine]
Backend: cpu
Global devices: [CpuDevice(id=0), CpuDevice(id=131072), CpuDevice(id=262144), CpuDevice(id=393216)]
Globally: #DoFs(Q):    200

rank = 0 Locally: #DoFs(Q):     50 sigma_n_np shape (58, 4)
rank = 0 Local devices: [CpuDevice(id=0)]
rank = 0 sigma_corrected {CpuDevice(id=0)} yielding (58,) 

rank = 1 Locally: #DoFs(Q):     49 sigma_n_np shape (61, 4)
rank = 1 Local devices: [CpuDevice(id=131072)]
rank = 1 sigma_corrected {CpuDevice(id=131072)} yielding (61,) 

rank = 2 Locally: #DoFs(Q):     51 sigma_n_np shape (63, 4)
rank = 2 Local devices: [CpuDevice(id=262144)]
rank = 2 sigma_corrected {CpuDevice(id=262144)} yielding (63,) 

rank = 3 Locally: #DoFs(Q):     50 sigma_n_np shape (58, 4)
rank = 3 Local devices: [CpuDevice(id=393216)]
rank = 3 sigma_corrected {CpuDevice(id=393216)} yielding (58,) 


