From 9a307cb9d0fc06942e13056d0e2d6cf2826d6661 Mon Sep 17 00:00:00 2001 From: EiffL Date: Sun, 7 Jul 2024 19:53:54 -0400 Subject: [PATCH 1/9] simplify example prototype --- examples/growth.py | 593 ------------------------------------- examples/jaxdecomp_lpt.py | 299 ------------------- examples/lpt_nbody_demo.py | 273 +++++++++++++++++ examples/scatter.py | 155 ++++++++++ examples/utils.py | 98 ------ 5 files changed, 428 insertions(+), 990 deletions(-) delete mode 100644 examples/growth.py delete mode 100644 examples/jaxdecomp_lpt.py create mode 100644 examples/lpt_nbody_demo.py create mode 100644 examples/scatter.py delete mode 100644 examples/utils.py diff --git a/examples/growth.py b/examples/growth.py deleted file mode 100644 index eadf160..0000000 --- a/examples/growth.py +++ /dev/null @@ -1,593 +0,0 @@ -import jax.numpy as np -from jax_cosmo.background import * -from jax_cosmo.scipy.interpolate import interp -from jax_cosmo.scipy.ode import odeint - -# Taken from jaxpm github/DifferentiableUniverseInitiative/JaxPM - - -def E(cosmo, a): - r"""Scale factor dependent factor E(a) in the Hubble - parameter. - Parameters - ---------- - a : array_like - Scale factor - Returns - ------- - E : ndarray, or float if input scalar - Square of the scaling of the Hubble constant as a function of - scale factor - Notes - ----- - The Hubble parameter at scale factor `a` is given by - :math:`H^2(a) = E^2(a) H_o^2` where :math:`E^2` is obtained through - Friedman's Equation (see :cite:`2005:Percival`) : - .. math:: - E^2(a) = \Omega_m a^{-3} + \Omega_k a^{-2} + \Omega_{de} a^{f(a)} - where :math:`f(a)` is the Dark Energy evolution parameter computed - by :py:meth:`.f_de`. - """ - return np.power(Esqr(cosmo, a), 0.5) - - -def df_de(cosmo, a, epsilon=1e-5): - r"""Derivative of the evolution parameter for the Dark Energy density - f(a) with respect to the scale factor. - Parameters - ---------- - cosmo: Cosmology - Cosmological parameters structure - a : array_like - Scale factor - epsilon: float value - Small number to make sure we are not dividing by 0 and avoid a singularity - Returns - ------- - df(a)/da : ndarray, or float if input scalar - Derivative of the evolution parameter for the Dark Energy density - with respect to the scale factor. - Notes - ----- - The expression for :math:`\frac{df(a)}{da}` is: - .. math:: - \frac{df}{da}(a) = =\frac{3w_a \left( \ln(a-\epsilon)- - \frac{a-1}{a-\epsilon}\right)}{\ln^2(a-\epsilon)} - """ - return (3 * cosmo.wa * (np.log(a - epsilon) - (a - 1) / (a - epsilon)) / - np.power(np.log(a - epsilon), 2)) - - -def dEa(cosmo, a): - r"""Derivative of the scale factor dependent factor E(a) in the Hubble - parameter. - Parameters - ---------- - a : array_like - Scale factor - Returns - ------- - dE(a)/da : ndarray, or float if input scalar - Derivative of the scale factor dependent factor in the Hubble - parameter with respect to the scale factor. - Notes - ----- - The expression for :math:`\frac{dE}{da}` is: - .. math:: - \frac{dE(a)}{da}=\frac{-3a^{-4}\Omega_{0m} - -2a^{-3}\Omega_{0k} - +f'_{de}\Omega_{0de}a^{f_{de}(a)}}{2E(a)} - Notes - ----- - The Hubble parameter at scale factor `a` is given by - :math:`H^2(a) = E^2(a) H_o^2` where :math:`E^2` is obtained through - Friedman's Equation (see :cite:`2005:Percival`) : - .. math:: - E^2(a) = \Omega_m a^{-3} + \Omega_k a^{-2} + \Omega_{de} a^{f(a)} - where :math:`f(a)` is the Dark Energy evolution parameter computed - by :py:meth:`.f_de`. - """ - return (0.5 * - (-3 * cosmo.Omega_m * np.power(a, -4) - - 2 * cosmo.Omega_k * np.power(a, -3) + - df_de(cosmo, a) * cosmo.Omega_de * np.power(a, f_de(cosmo, a))) / - np.power(Esqr(cosmo, a), 0.5)) - - -def growth_factor(cosmo, a): - """Compute linear growth factor D(a) at a given scale factor, - normalized such that D(a=1) = 1. - - Parameters - ---------- - cosmo: `Cosmology` - Cosmology object - - a: array_like - Scale factor - - Returns - ------- - D: ndarray, or float if input scalar - Growth factor computed at requested scale factor - - Notes - ----- - The growth computation will depend on the cosmology parametrization, for - instance if the $\gamma$ parameter is defined, the growth will be computed - assuming the $f = \Omega^\gamma$ growth rate, otherwise the usual ODE for - growth will be solved. - """ - if cosmo._flags["gamma_growth"]: - print(f"here") - return _growth_factor_gamma(cosmo, a) - else: - print(f"THERE") - return _growth_factor_ODE(cosmo, a) - - -def growth_factor_second(cosmo, a): - """Compute second order growth factor D2(a) at a given scale factor, - normalized such that D(a=1) = 1. - - Parameters - ---------- - cosmo: `Cosmology` - Cosmology object - - a: array_like - Scale factor - - Returns - ------- - D2: ndarray, or float if input scalar - Growth factor computed at requested scale factor - - Notes - ----- - The growth computation will depend on the cosmology parametrization, - as for the linear growth. Currently the second order growth - factor is not implemented with $\gamma$ parameter. - """ - if cosmo._flags["gamma_growth"]: - raise NotImplementedError( - "Gamma growth rate is not implemented for second order growth!") - return None - else: - return _growth_factor_second_ODE(cosmo, a) - - -def growth_rate(cosmo, a): - """Compute growth rate dD/dlna at a given scale factor. - - Parameters - ---------- - cosmo: `Cosmology` - Cosmology object - - a: array_like - Scale factor - - Returns - ------- - f: ndarray, or float if input scalar - Growth rate computed at requested scale factor - - Notes - ----- - The growth computation will depend on the cosmology parametrization, for - instance if the $\gamma$ parameter is defined, the growth will be computed - assuming the $f = \Omega^\gamma$ growth rate, otherwise the usual ODE for - growth will be solved. - - The LCDM approximation to the growth rate :math:`f_{\gamma}(a)` is given by: - - .. math:: - - f_{\gamma}(a) = \Omega_m^{\gamma} (a) - - with :math: `\gamma` in LCDM, given approximately by: - .. math:: - - \gamma = 0.55 - - see :cite:`2019:Euclid Preparation VII, eqn.32` - """ - if cosmo._flags["gamma_growth"]: - return _growth_rate_gamma(cosmo, a) - else: - return _growth_rate_ODE(cosmo, a) - - -def growth_rate_second(cosmo, a): - """Compute second order growth rate dD2/dlna at a given scale factor. - - Parameters - ---------- - cosmo: `Cosmology` - Cosmology object - - a: array_like - Scale factor - - Returns - ------- - f2: ndarray, or float if input scalar - Second order growth rate computed at requested scale factor - - Notes - ----- - The growth computation will depend on the cosmology parametrization, - as for the linear growth rate. Currently the second order growth - rate is not implemented with $\gamma$ parameter. - """ - if cosmo._flags["gamma_growth"]: - raise NotImplementedError( - "Gamma growth factor is not implemented for second order growth!") - return None - else: - return _growth_rate_second_ODE(cosmo, a) - - -def _growth_factor_ODE(cosmo, a, log10_amin=-3, steps=128, eps=1e-4): - """Compute linear growth factor D(a) at a given scale factor, - normalised such that D(a=1) = 1. - - Parameters - ---------- - a: array_like - Scale factor - - amin: float - Mininum scale factor, default 1e-3 - - Returns - ------- - D: ndarray, or float if input scalar - Growth factor computed at requested scale factor - """ - # Check if growth has already been computed - if not "background.growth_factor" in cosmo._workspace.keys(): - # Compute tabulated array - atab = np.logspace(log10_amin, 0.0, steps) - - def D_derivs(y, x): - q = (2.0 - 0.5 * (Omega_m_a(cosmo, x) + - (1.0 + 3.0 * w(cosmo, x)) * Omega_de_a(cosmo, x))) / x - r = 1.5 * Omega_m_a(cosmo, x) / x / x - - g1, g2 = y[0] - f1, f2 = y[1] - dy1da = [f1, -q * f1 + r * g1] - dy2da = [f2, -q * f2 + r * g2 - r * g1**2] - return np.array([[dy1da[0], dy2da[0]], [dy1da[1], dy2da[1]]]) - - y0 = np.array([[atab[0], -3.0 / 7 * atab[0]**2], [1.0, -6.0 / 7 * atab[0]]]) - y = odeint(D_derivs, y0, atab) - - # compute second order derivatives growth - dyda2 = D_derivs(np.transpose(y, (1, 2, 0)), atab) - dyda2 = np.transpose(dyda2, (2, 0, 1)) - - # Normalize results - y1 = y[:, 0, 0] - gtab = y1 / y1[-1] - y2 = y[:, 0, 1] - g2tab = y2 / y2[-1] - # To transform from dD/da to dlnD/dlna: dlnD/dlna = a / D dD/da - ftab = y[:, 1, 0] / y1[-1] * atab / gtab - f2tab = y[:, 1, 1] / y2[-1] * atab / g2tab - # Similarly for second order derivatives - # Note: these factors are not accessible as parent functions yet - # since it is unclear what to refer to them with. - htab = dyda2[:, 1, 0] / y1[-1] * atab / gtab - h2tab = dyda2[:, 1, 1] / y2[-1] * atab / g2tab - - cache = { - "a": atab, - "g": gtab, - "f": ftab, - "h": htab, - "g2": g2tab, - "f2": f2tab, - "h2": h2tab, - } - cosmo._workspace["background.growth_factor"] = cache - else: - cache = cosmo._workspace["background.growth_factor"] - return np.clip(interp(a, cache["a"], cache["g"]), 0.0, 1.0) - - -def _growth_rate_ODE(cosmo, a): - """Compute growth rate dD/dlna at a given scale factor by solving the linear - growth ODE. - - Parameters - ---------- - cosmo: `Cosmology` - Cosmology object - - a: array_like - Scale factor - - Returns - ------- - f: ndarray, or float if input scalar - Growth rate computed at requested scale factor - """ - # Check if growth has already been computed, if not, compute it - if not "background.growth_factor" in cosmo._workspace.keys(): - _growth_factor_ODE(cosmo, np.atleast_1d(1.0)) - cache = cosmo._workspace["background.growth_factor"] - return interp(a, cache["a"], cache["f"]) - - -def _growth_factor_second_ODE(cosmo, a): - """Compute second order growth factor D2(a) at a given scale factor, - normalised such that D(a=1) = 1. - - Parameters - ---------- - a: array_like - Scale factor - - amin: float - Mininum scale factor, default 1e-3 - - Returns - ------- - D2: ndarray, or float if input scalar - Second order growth factor computed at requested scale factor - """ - # Check if growth has already been computed, if not, compute it - if not "background.growth_factor" in cosmo._workspace.keys(): - _growth_factor_ODE(cosmo, np.atleast_1d(1.0)) - cache = cosmo._workspace["background.growth_factor"] - return interp(a, cache["a"], cache["g2"]) - - -def _growth_rate_ODE(cosmo, a): - """Compute growth rate dD/dlna at a given scale factor by solving the linear - growth ODE. - - Parameters - ---------- - cosmo: `Cosmology` - Cosmology object - - a: array_like - Scale factor - - Returns - ------- - f: ndarray, or float if input scalar - Second order growth rate computed at requested scale factor - """ - # Check if growth has already been computed, if not, compute it - if not "background.growth_factor" in cosmo._workspace.keys(): - _growth_factor_ODE(cosmo, np.atleast_1d(1.0)) - cache = cosmo._workspace["background.growth_factor"] - return interp(a, cache["a"], cache["f"]) - - -def _growth_rate_second_ODE(cosmo, a): - """Compute second order growth rate dD2/dlna at a given scale factor by solving the linear - growth ODE. - - Parameters - ---------- - cosmo: `Cosmology` - Cosmology object - - a: array_like - Scale factor - - Returns - ------- - f2: ndarray, or float if input scalar - Second order growth rate computed at requested scale factor - """ - # Check if growth has already been computed, if not, compute it - if not "background.growth_factor" in cosmo._workspace.keys(): - _growth_factor_ODE(cosmo, np.atleast_1d(1.0)) - cache = cosmo._workspace["background.growth_factor"] - return interp(a, cache["a"], cache["f2"]) - - -def _growth_factor_gamma(cosmo, a, log10_amin=-3, steps=128): - r"""Computes growth factor by integrating the growth rate provided by the - \gamma parametrization. Normalized such that D( a=1) =1 - - Parameters - ---------- - a: array_like - Scale factor - - amin: float - Mininum scale factor, default 1e-3 - - Returns - ------- - D: ndarray, or float if input scalar - Growth factor computed at requested scale factor - - """ - # Check if growth has already been computed, if not, compute it - if not "background.growth_factor" in cosmo._workspace.keys(): - # Compute tabulated array - atab = np.logspace(log10_amin, 0.0, steps) - - def integrand(y, loga): - xa = np.exp(loga) - return _growth_rate_gamma(cosmo, xa) - - gtab = np.exp(odeint(integrand, np.log(atab[0]), np.log(atab))) - gtab = gtab / gtab[-1] # Normalize to a=1. - cache = {"a": atab, "g": gtab} - cosmo._workspace["background.growth_factor"] = cache - else: - cache = cosmo._workspace["background.growth_factor"] - return np.clip(interp(a, cache["a"], cache["g"]), 0.0, 1.0) - - -def _growth_rate_gamma(cosmo, a): - r"""Growth rate approximation at scale factor `a`. - - Parameters - ---------- - cosmo: `Cosmology` - Cosmology object - - a : array_like - Scale factor - - Returns - ------- - f_gamma : ndarray, or float if input scalar - Growth rate approximation at the requested scale factor - - Notes - ----- - The LCDM approximation to the growth rate :math:`f_{\gamma}(a)` is given by: - - .. math:: - - f_{\gamma}(a) = \Omega_m^{\gamma} (a) - - with :math: `\gamma` in LCDM, given approximately by: - .. math:: - - \gamma = 0.55 - - see :cite:`2019:Euclid Preparation VII, eqn.32` - """ - return Omega_m_a(cosmo, a)**cosmo.gamma - - -def Gf(cosmo, a): - r""" - FastPM growth factor function - - Parameters - ---------- - cosmo: dict - Cosmology dictionary. - - a : array_like - Scale factor. - - Returns - ------- - Scalar float Tensor : FastPM growth factor function. - - Notes - ----- - - The expression for :math:`Gf(a)` is: - - .. math:: - Gf(a)=D'_{1norm}*a**3*E(a) - """ - f1 = growth_rate(cosmo, a) - g1 = growth_factor(cosmo, a) - D1f = f1 * g1 / a - return D1f * np.power(a, 3) * np.power(Esqr(cosmo, a), 0.5) - - -def Gf2(cosmo, a): - r""" FastPM second order growth factor function - - Parameters - ---------- - cosmo: dict - Cosmology dictionary. - - a : array_like - Scale factor. - - Returns - ------- - Scalar float Tensor : FastPM second order growth factor function. - - Notes - ----- - - The expression for :math:`Gf_2(a)` is: - - .. math:: - Gf_2(a)=D'_{2norm}*a**3*E(a) - """ - f2 = growth_rate_second(cosmo, a) - g2 = growth_factor_second(cosmo, a) - D2f = f2 * g2 / a - return D2f * np.power(a, 3) * np.power(Esqr(cosmo, a), 0.5) - - -def dGfa(cosmo, a): - r""" Derivative of Gf against a - - Parameters - ---------- - cosmo: dict - Cosmology dictionary. - - a : array_like - Scale factor. - - Returns - ------- - Scalar float Tensor : the derivative of Gf against a. - - Notes - ----- - - The expression for :math:`gf(a)` is: - - .. math:: - gf(a)=\frac{dGF}{da}= D^{''}_1 * a ** 3 *E(a) +D'_{1norm}*a ** 3 * E'(a) - + 3 * a ** 2 * E(a)*D'_{1norm} - - """ - f1 = growth_rate(cosmo, a) - g1 = growth_factor(cosmo, a) - D1f = f1 * g1 / a - cache = cosmo._workspace['background.growth_factor'] - f1p = cache['h'] / cache['a'] * cache['g'] - f1p = interp(np.log(a), np.log(cache['a']), f1p) - Ea = E(cosmo, a) - return (f1p * a**3 * Ea + D1f * a**3 * dEa(cosmo, a) + 3 * a**2 * Ea * D1f) - - -def dGf2a(cosmo, a): - r""" Derivative of Gf2 against a - - Parameters - ---------- - cosmo: dict - Cosmology dictionary. - - a : array_like - Scale factor. - - Returns - ------- - Scalar float Tensor : the derivative of Gf2 against a. - - Notes - ----- - - The expression for :math:`gf2(a)` is: - - .. math:: - gf_2(a)=\frac{dGF_2}{da}= D^{''}_2 * a ** 3 *E(a) +D'_{2norm}*a ** 3 * E'(a) - + 3 * a ** 2 * E(a)*D'_{2norm} - - """ - f2 = growth_rate_second(cosmo, a) - g2 = growth_factor_second(cosmo, a) - D2f = f2 * g2 / a - cache = cosmo._workspace['background.growth_factor'] - f2p = cache['h2'] / cache['a'] * cache['g2'] - f2p = interp(np.log(a), np.log(cache['a']), f2p) - E = E(cosmo, a) - return (f2p * a**3 * E + D2f * a**3 * dEa(cosmo, a) + 3 * a**2 * E * D2f) diff --git a/examples/jaxdecomp_lpt.py b/examples/jaxdecomp_lpt.py deleted file mode 100644 index 46a472e..0000000 --- a/examples/jaxdecomp_lpt.py +++ /dev/null @@ -1,299 +0,0 @@ -import jax - -import jaxdecomp - -jax.distributed.initialize() -rank = jax.process_index() -size = jax.process_count() - -print(f"Started process {rank} of {size}") - -import argparse -import os -import time -from functools import partial - -import jax.lax as lax -import jax.numpy as jnp -import jax_cosmo as jc -import numpy as np -from growth import dGfa, growth_factor, growth_rate -from jax.experimental import mesh_utils -from jax.experimental.shard_map import shard_map -from jax.sharding import Mesh, NamedSharding -from jax.sharding import PartitionSpec as P -from utils import * - -if __name__ == '__main__': - - jax.config.update('jax_enable_x64', False) - - parser = argparse.ArgumentParser() - - parser.add_argument('-s', '--size', type=int, default=64) - parser.add_argument('-p', '--pdims', type=str, default='1x1') - parser.add_argument('-b', '--box_size', type=int, default=200) - parser.add_argument('-hs', '--halo_size', type=int, default=32) - parser.add_argument('-o', '--output', type=str, default='out') - - args = parser.parse_args() - - print(f"Running with arguments {args}") - - # ********************************* - # Setup - # ********************************* - master_key = jax.random.PRNGKey(42) - key = jax.random.split(master_key, size)[rank] - # Read parameters - pdims = tuple(map(int, args.pdims.split('x'))) - mesh_shape = (args.size, args.size, args.size) - box_size = [float(args.box_size), float(args.box_size), float(args.box_size)] - halo_size = args.halo_size - - output_dir = args.output - # Create output directory recursively - os.makedirs(output_dir, exist_ok=True) - # Create computing mesh - devices = mesh_utils.create_device_mesh(pdims) - mesh = Mesh(devices, axis_names=('y', 'z')) - sharding = jax.sharding.NamedSharding(mesh, P('z', 'y')) - replicate = jax.sharding.NamedSharding(mesh, P()) - - ### Create all initial distributed tensors ### - local_mesh_shape = [ - mesh_shape[0] // pdims[1], mesh_shape[1] // pdims[0], mesh_shape[2] - ] - # Correction for positions relative to the local slice - correct_y = -local_mesh_shape[1] * (rank // pdims[0]) - correct_z = -local_mesh_shape[0] * (rank % pdims[0]) - - # Create gaussian field distributed across the mesh - z = generate_random_field(mesh_shape, sharding, key, local_mesh_shape) - - kvec = fttk(mesh_shape, mesh) - pos = generate_initial_positions(mesh_shape, sharding) - painting_mesh = jnp.zeros_like(z, device=sharding) - - print(f"Local mesh shape {local_mesh_shape}") - print(f"Saving on folder {output_dir}") - print(f"Created initial field {z.shape} and sharding {z.sharding}") - print(f"Created painting mesh with shape {painting_mesh.shape}") - print(f"And sharding {painting_mesh.sharding}") - print(f"Created positions {pos.shape} and sharding {pos.sharding}") - print("Corrected positions for rank {rank} --> ") - print(f" \tare Y: {correct_y} Z: {correct_z}") - print( - f"pos shape {pos.shape} pos sharding = {pos.sharding} shape of local {pos.addressable_data(0).shape}" - ) - - @partial( - shard_map, - mesh=mesh, - in_specs=(P('z', 'y'), P('z', 'y')), - out_specs=P('z', 'y')) - def cic_paint_sharded(mesh: jnp.ndarray, - positions: jnp.ndarray) -> jnp.ndarray: - """ - Distributed part of the CIC painting f - - Parameters - ---------- - mesh : jnp.ndarray with shape ( X, Y, Z) - The mesh onto which mass is painted. - positions : jnp.ndarray with shape (X , Y, Z, 3) - Positions of particles. - - Returns - ------- - jnp.ndarray - The mesh with painted mass. - """ - # Get positions relative to the start of each slice - positions = positions.at[:, :, :, 1].add(correct_y) - positions = positions.at[:, :, :, 0].add(correct_z) - positions = positions.reshape([-1, 3]) - - mesh = jnp.pad(mesh, [(halo_size, halo_size), (halo_size, halo_size), - (0, 0)]) - positions += jnp.array([halo_size, halo_size, 0]).reshape([-1, 3]) - - positions = jnp.expand_dims(positions, 1) - floor = jnp.floor(positions) - - connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1], - [1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]]) - - neighboor_coords = floor + connection - kernel = 1. - jnp.abs(positions - neighboor_coords) - kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2] - - neighboor_coords_mod = jnp.mod( - neighboor_coords.reshape([-1, 8, 3]).astype('int32'), - jnp.array(mesh.shape)) - - dnums = jax.lax.ScatterDimensionNumbers( - update_window_dims=(), - inserted_window_dims=(0, 1, 2), - scatter_dims_to_operand_dims=(0, 1, 2)) - mesh = lax.scatter_add(mesh, neighboor_coords_mod, kernel.reshape([-1, 8]), - dnums) - - return mesh - - @jax.jit - def cic_paint(mesh: jnp.ndarray, positions: jnp.ndarray) -> jnp.ndarray: - """ - Wrapper function to paint mass onto a mesh using CIC method and - perform halo exchange. - - Parameters - ---------- - mesh : jnp.ndarray with shape ( X, Y, Z) - The mesh onto which mass is painted. - positions : jnp.ndarray with shape (X , Y, Z, 3) - Positions of particles. - - Returns - ------- - jnp.ndarray - The mesh with painted mass after halo exchange. - """ - field = cic_paint_sharded(mesh, positions) - - field = jaxdecomp.halo_exchange( - field, - halo_extents=(halo_size // 2, halo_size // 2, 0), - halo_periods=(True, True, True), - reduce_halo=True) - # Removing the padding - field = jaxdecomp.slice_unpad(field, ((halo_size, halo_size), - (halo_size, halo_size), (0, 0)), - pdims) - - return field - - @partial( - shard_map, - mesh=mesh, - in_specs=(P('z', 'y'), P('z', 'y')), - out_specs=P('z', 'y')) - def interpolate(kfield: jnp.ndarray, kk: jnp.ndarray) -> jnp.ndarray: - """ - Interpolates the power spectrum onto the k-space field. - - Parameters - ---------- - kfield : jnp.ndarray with shape ( X, Y, Z) - The k-space field. - kk : jnp.ndarray - Magnitude of k-vectors. - - Returns - ------- - jnp.ndarray with shape ( X, Y, Z) - The interpolated k-space field. - """ - k = jnp.logspace(-4, 2, 256) - pk = jc.power.linear_matter_power(jc.Planck15(), k) - pk = pk * (mesh_shape[0] / box_size[0]) * (mesh_shape[1] / box_size[1]) * ( - mesh_shape[2] / box_size[2]) - delta_k = kfield * jc.scipy.interpolate.interp(kk.flatten(), k, pk** - 0.5).reshape(kfield.shape) - - return delta_k - - @jax.jit - def forward_fn(z: jnp.ndarray, kvec: tuple, pos: jnp.ndarray, - painting_mesh: jnp.ndarray, a: float) -> tuple: - """ - Computes initial conditions and density field using Lagrangian perturbation theory (LPT). - - Parameters - ---------- - z : jnp.ndarray - The initial Gaussian random field. - kvec : tuple - K-vectors for Fourier Transform. - pos : jnp.ndarray - Initial positions of particles. - painting_mesh : jnp.ndarray - The mesh for mass painting. - a : float - Scale factor. - - Returns - ------- - tuple - Initial conditions and the density field. - """ - kfield = jaxdecomp.fft.pfft3d(z.astype(jnp.complex64)) - - ky, kz, kx = kvec - kk = jnp.sqrt((kx / box_size[0] * mesh_shape[0])**2 + - (ky / box_size[1] * mesh_shape[1])**2 + - (kz / box_size[1] * mesh_shape[1])**2) - - delta_k = interpolate(kfield, kk) - - # Inverse Fourier transform to generate the initial conditions - initial_conditions = jaxdecomp.fft.pifft3d(delta_k).real - - ### Compute LPT displacement - cosmo = jc.Planck15() - a = jnp.atleast_1d(a) - - kernel_lap = jnp.where(kk == 0, 1., 1. / -(kx**2 + ky**2 + kz**2)) - - pot_k = delta_k * kernel_lap - # Forces have to be a Z pencil because they are going to be IFFT back to X pencil - forces_k = -jnp.stack([ - pot_k * 1j / 6.0 * - (8 * jnp.sin(kx) - jnp.sin(2 * kx)), pot_k * 1j / 6.0 * - (8 * jnp.sin(ky) - jnp.sin(2 * ky)), pot_k * 1j / 6.0 * - (8 * jnp.sin(kz) - jnp.sin(2 * kz)) - ], - axis=-1) - - init_force = jnp.stack( - [jaxdecomp.fft.pifft3d(forces_k[..., i]).real for i in range(3)], - axis=-1) - - dx = growth_factor(cosmo, a) * init_force - - p = a**2 * growth_rate(cosmo, a) * jnp.sqrt(jc.background.Esqr(cosmo, - a)) * dx - f = a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a)) * dGfa(cosmo, - a) * init_force - - field = cic_paint(painting_mesh, (pos + dx)) - - return initial_conditions, field - - with mesh: - jit_start = time.perf_counter() - initial_conds, field = forward_fn(z, kvec, pos, painting_mesh, a=1.) - field.block_until_ready() - jit_end = time.perf_counter() - - print(f"JIT done in {jit_end - jit_start}") - - start = time.perf_counter() - initial_conds, field = forward_fn(z, kvec, pos, painting_mesh, a=1.) - field.block_until_ready() - end = time.perf_counter() - - print(f"Execution done in {end - start}") - - with open(f"{output_dir}/log_{rank}.log", 'w') as log_file: - log_file.write(f"JIT time: {jit_end - jit_start}\n") - log_file.write(f"Execution time: {end - start}\n") - - # Saving results - np.save(f'{output_dir}/initial_conditions_{rank}.npy', - initial_conds.addressable_data(0)) - np.save(f'{output_dir}/field_{rank}.npy', field.addressable_data(0)) - - print(f"Finished saved to {output_dir}") - -jax.distributed.shutdown() diff --git a/examples/lpt_nbody_demo.py b/examples/lpt_nbody_demo.py new file mode 100644 index 0000000..28574cb --- /dev/null +++ b/examples/lpt_nbody_demo.py @@ -0,0 +1,273 @@ +import jax +import jaxdecomp + +jax.distributed.initialize() +rank = jax.process_index() +size = jax.process_count() + +print(f"Started process {rank} of {size}") + +import argparse +import os +import time +from functools import partial + +import jax.lax as lax +import jax.numpy as jnp +import jax_cosmo as jc +import numpy as np +from growth import dGfa, growth_factor, growth_rate +from jax.experimental import mesh_utils +from jax.experimental.shard_map import shard_map +from jax.sharding import Mesh, NamedSharding +from jax.sharding import PartitionSpec as P + +def _global_to_local_size(mesh_shape, sharding): + """ Utility function to compute the expected local size of a mesh + given the global size and the sharding strategy. + """ + return mesh_shape # TODO: sort out how to get the information from sharding + +def fttk(mesh_shape: Tuple[int, int, int], sharding) -> list: + """ + Generate Fourier transform wave numbers for a given mesh. + + Parameters + ---------- + mesh_shape : tuple of int + Shape of the mesh grid. + sharding : Any + Sharding strategy for the array. + + Returns + ------- + list + List of wave number arrays for each dimension. + """ + kd = np.fft.fftfreq(mesh_shape[0]) * 2 * np.pi + return [ + jax.make_array_from_callback( + (mesh_shape[0], 1, 1), + sharding=jax.sharding.NamedSharding(sharding.mesh, P('z')), + data_callback=lambda x: kd.reshape([-1, 1, 1])[x]), + jax.make_array_from_callback( + (1, mesh_shape[1], 1), + sharding=jax.sharding.NamedSharding(sharding.mesh, P(None, 'y')), + data_callback=lambda x: kd.reshape([1, -1, 1])[x]), + kd.reshape([1, 1, -1]) + ] + +def gravity_kernel(kvec): + """ Fourier kernel to compute gravitational forces from a Fourier space density field. + + Parameters + ---------- + kvec : tuple of float + Wave vector in Fourier space. + + Returns + ------- + jnp.ndarray + Gravitational kernel. + """ + kx, ky, kz = kvec + kk = jnp.sqrt(kx**2 + ky**2 + kz**2) + laplace_kernel = jnp.where(kk == 0, 1., 1. / (kx**2 + ky**2 + kz**2)) + grav_kernel = [laplace_kernel * 1j / 6.0 * (8 * jnp.sin(kx) - jnp.sin(2 * kx)), + laplace_kernel * 1j / 6.0 * (8 * jnp.sin(ky) - jnp.sin(2 * ky)), + laplace_kernel * 1j / 6.0 * (8 * jnp.sin(kz) - jnp.sin(2 * kz))] + return grav_kernel + +def gaussian_field_and_forces(mesh_shape, box_size, power_spectrum, seed, sharding): + """ + Generate a Gaussian field with a given power spectrum, along with gravitational forces. + + Parameters + ---------- + mesh_shape : tuple of int + Shape of the mesh. + box_size : float + Size of the box. + power_spectrum : callable + Power spectrum function. + seed : int + Seed for the random number generator. + sharding : Any + Sharding strategy for the array. + + Returns + ------- + delta, forces : tuple of jnp.ndarray + The generated Gaussian field and the gravitational forces. + """ + local_mesh_shape = _global_to_local_size(mesh_shape, sharding) + + # Create a distributed field drawn from a Gaussian distribution in real space + delta = jax.make_array_from_single_device_arrays(shape=mesh_shape, + sharding=sharding, + arrays=[jax.random.normal(seed, local_mesh_shape, dtype='float32')]) + + # Compute the Fourier transform of the field + delta_k = jaxdecomp.fft.pfft3d(delta.astype(jnp.complex64)) + + # Compute the Fourier wavenumbers of the field + kx, ky, kz = fttk(mesh_shape, sharding) + kk = jnp.sqrt((kx / box_size * mesh_shape[0])**2 + + (ky / box_size * mesh_shape[1])**2 + + (kz / box_size * mesh_shape[2])**2) + + # Apply power spectrum to Fourier modes + delta_k *= power_spectrum(kk)**0.5 * jnp.prod(mesh_shape) / jnp.prod(box_size) + + # Compute inverse Fourier transform to recover the initial conditions in real space + delta = jaxdecomp.fft.pifft3d(delta_k).real + + # Compute gravitational forces associated with this field + grav_kernel = gravity_kernel([kx, ky, kz]) + forces_k = [g * delta_k for g in grav_kernel] + + # Retrieve the forces in real space by inverse Fourier transforming + forces = jnp.stack([jaxdecomp.fft.pifft3d(f).real for f in forces_k], axis=-1) + + return delta, forces + + +def cic_paint(displacement, halo_size=32): + original_shape = displacement.shape + + mesh = jnp.zeros(original_shape[:-1], dtype='float32') + + # Padding the output array along the two first dimensions + mesh = jnp.pad(mesh, [[halo_size, halo_size], [halo_size, halo_size], [0, 0]]) + + a,b,c = jnp.meshgrid(jnp.arange(local_mesh_shape[0]), + jnp.arange(local_mesh_shape[1]), + jnp.arange(local_mesh_shape[2])) + # adding an offset of size halo size + pmid = jnp.stack([b+halo_size,a+halo_size,c], axis=-1) + pmid = pmid.reshape([-1,3]) + + painted_field = scatter(pmid, displacement.reshape([-1,3]), mesh) + + # Perform halo exchange to get the correct values at the boundaries + painted_field = jaxdecomp.halo_exchange(field, + halo_extents=(halo_size//2, halo_size//2, 0), + halo_periods=(True, True, True), reduce_halo=False) + + # unpadding the output array + field = unpad(painted_field) + return field + +def simulation_fn(cosmology, mesh_shape, box_size, seed, a, sharding): + """ + Run a simulation to generate initial conditions and density field using LPT. + + Parameters + ---------- + mesh_shape : tuple of int + Shape of the mesh. + box_size : float + Size of the box. + power_spectrum : callable + Power spectrum function. + seed : int + Seed for the random number generator. + a : float + Scale factor. + sharding : Any + Sharding strategy for the array. + + Returns + ------- + initial_conditions, field : tuple of jnp.ndarray + Initial conditions and the density field. + """ + # Define the power spectrum + power_spectrum = lambda k: jc.power.linear_matter_power(cosmology, k) + + # Generate a Gaussian field and gravitational forces from a power spectrum + intial_conditions, initial_forces = gaussian_field_and_forces(mesh_shape, box_size, power_spectrum, seed, sharding) + + # Compute the LPT displacement of that particles initialy placed on a regular grid + # would experience at scale factor a, by simple Zeldovich approximation + initial_displacement = jc.background.growth_factor(cosmology, a) * initial_forces + + # Paints the displaced particles on a mesh to obtain the density field + final_field = cic_paint(initial_displacement, halo_size=32) + + return intial_conditions, final_field + +if __name__ == '__main__': + + jax.config.update('jax_enable_x64', False) + + parser = argparse.ArgumentParser() + + parser.add_argument('-s', '--size', type=int, default=64) + parser.add_argument('-p', '--pdims', type=str, default='1x1') + parser.add_argument('-b', '--box_size', type=int, default=200) + parser.add_argument('-hs', '--halo_size', type=int, default=32) + parser.add_argument('-o', '--output', type=str, default='out') + + args = parser.parse_args() + + print(f"Running with arguments {args}") + + # ********************************* + # Setup + # ********************************* + master_key = jax.random.PRNGKey(42) + key = jax.random.split(master_key, size)[rank] + # Read parameters + pdims = tuple(map(int, args.pdims.split('x'))) + mesh_shape = (args.size, args.size, args.size) + box_size = [float(args.box_size), float(args.box_size), float(args.box_size)] + halo_size = args.halo_size + + output_dir = args.output + # Create output directory recursively + os.makedirs(output_dir, exist_ok=True) + + # Create computing mesh + devices = mesh_utils.create_device_mesh(pdims) + mesh = Mesh(devices, axis_names=('y', 'z')) + sharding = jax.sharding.NamedSharding(mesh, P('z', 'y')) + replicate = jax.sharding.NamedSharding(mesh, P()) + + print(f"Saving on folder {output_dir}") + print(f"Created initial field {z.shape} and sharding {z.sharding}") + print(f"Created painting mesh with shape {painting_mesh.shape}") + print(f"And sharding {painting_mesh.sharding}") + print(f"Created positions {pos.shape} and sharding {pos.sharding}") + print("Corrected positions for rank {rank} --> ") + print( + f"pos shape {pos.shape} pos sharding = {pos.sharding} shape of local {pos.addressable_data(0).shape}" + ) + + with mesh: + jit_start = time.perf_counter() + initial_conds, field = intial_conditions(z, kvec, pos, painting_mesh, a=1.) + field.block_until_ready() + jit_end = time.perf_counter() + + print(f"JIT done in {jit_end - jit_start}") + + start = time.perf_counter() + initial_conds, field = forward_fn(z, kvec, pos, painting_mesh, a=1.) + field.block_until_ready() + end = time.perf_counter() + + print(f"Execution done in {end - start}") + + with open(f"{output_dir}/log_{rank}.log", 'w') as log_file: + log_file.write(f"JIT time: {jit_end - jit_start}\n") + log_file.write(f"Execution time: {end - start}\n") + + # Saving results + np.save(f'{output_dir}/initial_conditions_{rank}.npy', + initial_conds.addressable_data(0)) + np.save(f'{output_dir}/field_{rank}.npy', field.addressable_data(0)) + + print(f"Finished saved to {output_dir}") + +jax.distributed.shutdown() diff --git a/examples/scatter.py b/examples/scatter.py new file mode 100644 index 0000000..6172351 --- /dev/null +++ b/examples/scatter.py @@ -0,0 +1,155 @@ +# This file is adapted from the scatter implementation of the pmwd library +# https://github.com/eelregit/pmwd/blob/master/pmwd/scatter.py +# It provides a simple way to perform a scatter operation by chunks and saves +# memory compared to a native jax.lax.scatter. +# Below is the orginal license of the pmwd library: +############################################################################### +# BSD 3-Clause License +# +# Copyright (c) 2021, the pmwd developers +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +import jax +import jax.numpy as jnp +import jax.lax as lax +from jax.lax import scan + +def _chunk_split(ptcl_num, chunk_size, *arrays): + """Split and reshape particle arrays into chunks and remainders, with the remainders + preceding the chunks. 0D ones are duplicated as full arrays in the chunks.""" + chunk_size = ptcl_num if chunk_size is None else min(chunk_size, ptcl_num) + remainder_size = ptcl_num % chunk_size + chunk_num = ptcl_num // chunk_size + + remainder = None + chunks = arrays + if remainder_size: + remainder = [x[:remainder_size] if x.ndim != 0 else x for x in arrays] + chunks = [x[remainder_size:] if x.ndim != 0 else x for x in arrays] + + # `scan` triggers errors in scatter and gather without the `full` + chunks = [x.reshape(chunk_num, chunk_size, *x.shape[1:]) if x.ndim != 0 + else jnp.full(chunk_num, x) for x in chunks] + + return remainder, chunks + +def enmesh(i1, d1, a1, s1, b12, a2, s2): + """Multilinear enmeshing.""" + i1 = jnp.asarray(i1) + d1 = jnp.asarray(d1) + a1 = jnp.float64(a1) if a2 is not None else jnp.array(a1, dtype=d1.dtype) + if s1 is not None: + s1 = jnp.array(s1, dtype=i1.dtype) + b12 = jnp.float64(b12) + if a2 is not None: + a2 = jnp.float64(a2) + if s2 is not None: + s2 = jnp.array(s2, dtype=i1.dtype) + + dim = i1.shape[1] + neighbors = (jnp.arange(2**dim, dtype=i1.dtype)[:, jnp.newaxis] + >> jnp.arange(dim, dtype=i1.dtype) + ) & 1 + + if a2 is not None: + P = i1 * a1 + d1 - b12 + P = P[:, jnp.newaxis] # insert neighbor axis + i2 = P + neighbors * a2 # multilinear + + if s1 is not None: + L = s1 * a1 + i2 %= L + + i2 //= a2 + d2 = P - i2 * a2 + + if s1 is not None: + d2 -= jnp.rint(d2 / L) * L # also abs(d2) < a2 is expected + + i2 = i2.astype(i1.dtype) + d2 = d2.astype(d1.dtype) + a2 = a2.astype(d1.dtype) + + d2 /= a2 + else: + i12, d12 = jnp.divmod(b12, a1) + i1 -= i12.astype(i1.dtype) + d1 -= d12.astype(d1.dtype) + + # insert neighbor axis + i1 = i1[:, jnp.newaxis] + d1 = d1[:, jnp.newaxis] + + # multilinear + d1 /= a1 + i2 = jnp.floor(d1).astype(i1.dtype) + i2 += neighbors + d2 = d1 - i2 + i2 += i1 + + if s1 is not None: + i2 %= s1 + + f2 = 1 - jnp.abs(d2) + + if s1 is None and s2 is not None: # all i2 >= 0 if s1 is not None + i2 = jnp.where(i2 < 0, s2, i2) + + f2 = f2.prod(axis=-1) + + return i2, f2 + +def scatter(pmid, disp, mesh, chunk_size=2**24, val=1., offset=0, cell_size=1.): + ptcl_num, spatial_ndim = pmid.shape + val = jnp.asarray(val) + mesh = jnp.asarray(mesh) + + remainder, chunks = _chunk_split(ptcl_num, chunk_size, pmid, disp, val) + + carry = mesh, offset, cell_size + if remainder is not None: + carry = _scatter_chunk(carry, remainder)[0] + carry = scan(_scatter_chunk, carry, chunks)[0] + mesh = carry[0] + + return mesh + +def _scatter_chunk(carry, chunk): + mesh, offset, cell_size = carry + pmid, disp, val = chunk + spatial_ndim = pmid.shape[1] + spatial_shape = mesh.shape + + # multilinear mesh indices and fractions + ind, frac = enmesh(pmid, disp, cell_size, spatial_shape, + offset, cell_size, spatial_shape) + # scatter + ind = tuple(ind[..., i] for i in range(spatial_ndim)) + mesh = mesh.at[ind].add(val * frac) + + carry = mesh, offset, cell_size + return carry, None \ No newline at end of file diff --git a/examples/utils.py b/examples/utils.py deleted file mode 100644 index 48a8779..0000000 --- a/examples/utils.py +++ /dev/null @@ -1,98 +0,0 @@ -from functools import partial -from typing import Any, Optional, Tuple - -import jax -import jax.numpy as jnp -import numpy as np -from jax.lax import scan -from jax.sharding import NamedSharding -from jax.sharding import PartitionSpec as P - - -def generate_random_field( - mesh_shape: Tuple[int, int, int], sharding: NamedSharding, - key: jax.random.PRNGKey, local_mesh_shape: Tuple[int, int, - int]) -> jnp.ndarray: - """ - Generate a random field using a normal distribution. - - Parameters - ---------- - mesh_shape : tuple of int - Shape of the full mesh. - sharding : Any - Sharding strategy for the array. - key : jax.random.PRNGKey - Random key for generating the noise. - local_mesh_shape : tuple of int - Shape of the local mesh. - - Returns - ------- - jnp.ndarray - Generated random field. - """ - return jax.make_array_from_single_device_arrays( - shape=mesh_shape, - sharding=sharding, - arrays=[jax.random.normal(key, local_mesh_shape, dtype='float32')]) - - -def generate_initial_positions(mesh_shape: Tuple[int, int, int], - sharding: NamedSharding) -> jnp.ndarray: - """ - Generate initial positions for particles on a mesh grid. - - Parameters - ---------- - mesh_shape : tuple of int - Shape of the mesh grid. - sharding : Any - Sharding strategy for the array. - - Returns - ------- - jnp.ndarray - Initial positions on the mesh grid. - """ - pos = jax.make_array_from_callback( - shape=tuple([*mesh_shape, 3]), - sharding=sharding, - data_callback=lambda x: jnp.stack( - jnp.meshgrid( - jnp.arange(mesh_shape[0])[x[0]], - jnp.arange(mesh_shape[1])[x[1]], - jnp.arange(mesh_shape[2]), - indexing='ij'), - axis=-1)) - return pos - - -def fttk(mesh_shape: Tuple[int, int, int], mesh: jax.sharding.Mesh) -> list: - """ - Generate Fourier transform wave numbers for a given mesh. - - Parameters - ---------- - mesh_shape : tuple of int - Shape of the mesh grid. - mesh : Any - Mesh object for sharding. - - Returns - ------- - list - List of wave number arrays for each dimension. - """ - kd = np.fft.fftfreq(mesh_shape[0]).astype('float32') * 2 * np.pi - return [ - jax.make_array_from_callback( - (mesh_shape[0], 1, 1), - sharding=jax.sharding.NamedSharding(mesh, P('z')), - data_callback=lambda x: kd.reshape([-1, 1, 1])[x]), - jax.make_array_from_callback( - (1, mesh_shape[1], 1), - sharding=jax.sharding.NamedSharding(mesh, P(None, 'y')), - data_callback=lambda x: kd.reshape([1, -1, 1])[x]), - kd.reshape([1, 1, -1]) - ] From e465af35de80800f36b77571dd2a49840153adad Mon Sep 17 00:00:00 2001 From: EiffL Date: Mon, 8 Jul 2024 02:04:57 -0400 Subject: [PATCH 2/9] updating implementation --- .pre-commit-config.yaml | 7 - examples/lpt_nbody_demo.py | 404 +++++++++++++++++++---------------- examples/scatter.py | 183 ++++++++-------- examples/submit_rusty.sbatch | 20 ++ 4 files changed, 337 insertions(+), 277 deletions(-) create mode 100644 examples/submit_rusty.sbatch diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c7141cd..f44eaca 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -15,10 +15,3 @@ repos: hooks: - id: isort name: isort (python) -- repo: https://github.com/pre-commit/mirrors-clang-format - rev: v18.1.4 - hooks: - - id: clang-format - files: '\.(c|cc|cpp|h|hpp|cxx|hh|cu|cuh)$' - exclude: '^third_party/|/pybind11/' - name: clang-format diff --git a/examples/lpt_nbody_demo.py b/examples/lpt_nbody_demo.py index 28574cb..07044e7 100644 --- a/examples/lpt_nbody_demo.py +++ b/examples/lpt_nbody_demo.py @@ -1,40 +1,39 @@ -import jax -import jaxdecomp - -jax.distributed.initialize() -rank = jax.process_index() -size = jax.process_count() - -print(f"Started process {rank} of {size}") - import argparse import os -import time from functools import partial +from typing import Tuple + +import jax + +jax.config.update('jax_enable_x64', False) -import jax.lax as lax import jax.numpy as jnp import jax_cosmo as jc import numpy as np -from growth import dGfa, growth_factor, growth_rate from jax.experimental import mesh_utils from jax.experimental.shard_map import shard_map from jax.sharding import Mesh, NamedSharding from jax.sharding import PartitionSpec as P +from scatter import scatter + +import jaxdecomp + def _global_to_local_size(mesh_shape, sharding): - """ Utility function to compute the expected local size of a mesh - given the global size and the sharding strategy. - """ - return mesh_shape # TODO: sort out how to get the information from sharding + """ Utility function to compute the expected local size of a mesh + given the global size and the sharding strategy. + """ + pdims = sharding.mesh.devices.shape + return [mesh_shape[0] // pdims[1], mesh_shape[1] // pdims[0], mesh_shape[2]] -def fttk(mesh_shape: Tuple[int, int, int], sharding) -> list: - """ + +def fttk(nc: int, sharding) -> list: + """ Generate Fourier transform wave numbers for a given mesh. Parameters ---------- - mesh_shape : tuple of int + mesh_shape : int Shape of the mesh grid. sharding : Any Sharding strategy for the array. @@ -44,230 +43,273 @@ def fttk(mesh_shape: Tuple[int, int, int], sharding) -> list: list List of wave number arrays for each dimension. """ - kd = np.fft.fftfreq(mesh_shape[0]) * 2 * np.pi - return [ + kd = np.fft.fftfreq(nc) * 2 * np.pi + return [ jax.make_array_from_callback( - (mesh_shape[0], 1, 1), - sharding=jax.sharding.NamedSharding(sharding.mesh, P('z')), + (nc, 1, 1), + sharding=NamedSharding(sharding.mesh, P('z')), data_callback=lambda x: kd.reshape([-1, 1, 1])[x]), jax.make_array_from_callback( - (1, mesh_shape[1], 1), - sharding=jax.sharding.NamedSharding(sharding.mesh, P(None, 'y')), + (1, nc, 1), + sharding=NamedSharding(sharding.mesh, P(None, 'y')), data_callback=lambda x: kd.reshape([1, -1, 1])[x]), kd.reshape([1, 1, -1]) - ] + ] + def gravity_kernel(kvec): - """ Fourier kernel to compute gravitational forces from a Fourier space density field. + """ Fourier kernel to compute gravitational forces from a Fourier space density field. Parameters ---------- kvec : tuple of float Wave vector in Fourier space. - + Returns ------- jnp.ndarray Gravitational kernel. """ - kx, ky, kz = kvec - kk = jnp.sqrt(kx**2 + ky**2 + kz**2) - laplace_kernel = jnp.where(kk == 0, 1., 1. / (kx**2 + ky**2 + kz**2)) - grav_kernel = [laplace_kernel * 1j / 6.0 * (8 * jnp.sin(kx) - jnp.sin(2 * kx)), - laplace_kernel * 1j / 6.0 * (8 * jnp.sin(ky) - jnp.sin(2 * ky)), - laplace_kernel * 1j / 6.0 * (8 * jnp.sin(kz) - jnp.sin(2 * kz))] - return grav_kernel - -def gaussian_field_and_forces(mesh_shape, box_size, power_spectrum, seed, sharding): - """ + kx, ky, kz = kvec + kk = kx**2 + ky**2 + kz**2 + laplace_kernel = jnp.where(kk == 0, 1., 1. / kk) + grav_kernel = [ + laplace_kernel * 1j / 6.0 * (8 * jnp.sin(kx) - jnp.sin(2 * kx)), + laplace_kernel * 1j / 6.0 * (8 * jnp.sin(ky) - jnp.sin(2 * ky)), + laplace_kernel * 1j / 6.0 * (8 * jnp.sin(kz) - jnp.sin(2 * kz)) + ] + return grav_kernel + + +def gaussian_field_and_forces(key, nc, box_size, power_spectrum, sharding): + """ Generate a Gaussian field with a given power spectrum, along with gravitational forces. - + Parameters ---------- - mesh_shape : tuple of int - Shape of the mesh. + key : int + key for the random number generator. + nc : int + Number of cells in the mesh. box_size : float Size of the box. power_spectrum : callable Power spectrum function. - seed : int - Seed for the random number generator. sharding : Any Sharding strategy for the array. - + Returns ------- delta, forces : tuple of jnp.ndarray The generated Gaussian field and the gravitational forces. """ - local_mesh_shape = _global_to_local_size(mesh_shape, sharding) - - # Create a distributed field drawn from a Gaussian distribution in real space - delta = jax.make_array_from_single_device_arrays(shape=mesh_shape, - sharding=sharding, - arrays=[jax.random.normal(seed, local_mesh_shape, dtype='float32')]) - - # Compute the Fourier transform of the field - delta_k = jaxdecomp.fft.pfft3d(delta.astype(jnp.complex64)) - - # Compute the Fourier wavenumbers of the field - kx, ky, kz = fttk(mesh_shape, sharding) - kk = jnp.sqrt((kx / box_size * mesh_shape[0])**2 + - (ky / box_size * mesh_shape[1])**2 + - (kz / box_size * mesh_shape[2])**2) - - # Apply power spectrum to Fourier modes - delta_k *= power_spectrum(kk)**0.5 * jnp.prod(mesh_shape) / jnp.prod(box_size) - - # Compute inverse Fourier transform to recover the initial conditions in real space - delta = jaxdecomp.fft.pifft3d(delta_k).real - - # Compute gravitational forces associated with this field - grav_kernel = gravity_kernel([kx, ky, kz]) - forces_k = [g * delta_k for g in grav_kernel] - - # Retrieve the forces in real space by inverse Fourier transforming - forces = jnp.stack([jaxdecomp.fft.pifft3d(f).real for f in forces_k], axis=-1) - - return delta, forces - - -def cic_paint(displacement, halo_size=32): - original_shape = displacement.shape - - mesh = jnp.zeros(original_shape[:-1], dtype='float32') - + mesh_shape = (nc,) * 3 + local_mesh_shape = _global_to_local_size(mesh_shape, sharding) + + # Create a distributed field drawn from a Gaussian distribution in real space + delta = jax.make_array_from_single_device_arrays( + shape=mesh_shape, + sharding=sharding, + arrays=[jax.random.normal(key, local_mesh_shape, dtype='float32')]) + + # Compute the Fourier transform of the field + delta_k = jaxdecomp.fft.pfft3d(delta.astype(jnp.complex64)) + + # Compute the Fourier wavenumbers of the field + kx, ky, kz = fttk(nc, sharding) + kk = jnp.sqrt(kx**2 + ky**2 + kz**2) * (nc / box_size)**3 + + # Apply power spectrum to Fourier modes + delta_k *= (power_spectrum(kk) * (nc / box_size)**3)**0.5 + + # Compute inverse Fourier transform to recover the initial conditions in real space + delta = jaxdecomp.fft.pifft3d(delta_k).real + + # Compute gravitational forces associated with this field + grav_kernel = gravity_kernel([kx, ky, kz]) + forces_k = [g * delta_k for g in grav_kernel] + + # Retrieve the forces in real space by inverse Fourier transforming + forces = jnp.stack([jaxdecomp.fft.pifft3d(f).real for f in forces_k], axis=-1) + + return delta, forces + + +def cic_paint(displacement, sharding, halo_size): + """ Paints particles on a mesh using Cloud-In-Cell interpolation. + + Parameters + ---------- + displacement : jnp.ndarray + Displacement field of particles. + sharding : Any + Sharding strategy for the array. + halo_size : int + Halo size for painting. + + Returns + ------- + jnp.ndarray + Density field. + """ + local_mesh_shape = _global_to_local_size(displacement.shape, sharding) + + @partial( + shard_map, + mesh=sharding.mesh, + in_specs=(P('z', 'y'),), + out_specs=P('z', 'y')) + def cic_op(disp): + """ CiC operation on each local slice of the mesh.""" + # Create a mesh to paint the particles on for the local slice + mesh = jnp.zeros(disp.shape[:-1], dtype='float32') + # Padding the output array along the two first dimensions - mesh = jnp.pad(mesh, [[halo_size, halo_size], [halo_size, halo_size], [0, 0]]) + mesh = jnp.pad(mesh, + [[halo_size, halo_size], [halo_size, halo_size], [0, 0]]) + + a, b, c = jnp.meshgrid( + jnp.arange(local_mesh_shape[0]), jnp.arange(local_mesh_shape[1]), + jnp.arange(local_mesh_shape[2])) - a,b,c = jnp.meshgrid(jnp.arange(local_mesh_shape[0]), - jnp.arange(local_mesh_shape[1]), - jnp.arange(local_mesh_shape[2])) # adding an offset of size halo size - pmid = jnp.stack([b+halo_size,a+halo_size,c], axis=-1) - pmid = pmid.reshape([-1,3]) - - painted_field = scatter(pmid, displacement.reshape([-1,3]), mesh) - - # Perform halo exchange to get the correct values at the boundaries - painted_field = jaxdecomp.halo_exchange(field, - halo_extents=(halo_size//2, halo_size//2, 0), - halo_periods=(True, True, True), reduce_halo=False) - - # unpadding the output array - field = unpad(painted_field) - return field - -def simulation_fn(cosmology, mesh_shape, box_size, seed, a, sharding): - """ + pmid = jnp.stack([b + halo_size, a + halo_size, c], axis=-1) + return scatter(pmid.reshape([-1, 3]), disp.reshape([-1, 3]), mesh) + + # Performs painting on padded mesh + field = cic_op(displacement) + + # Run halo exchange to get the correct values at the boundaries + field = jaxdecomp.halo_exchange( + field, + halo_extents=(halo_size // 2, halo_size // 2, 0), + halo_periods=(True, True, True)) + + @partial( + shard_map, + mesh=sharding.mesh, + in_specs=(P('z', 'y'),), + out_specs=P('z', 'y')) + def unpad(x): + """ Unpad the output array. """ + x = x.at[halo_size:halo_size + halo_size // 2].add(x[:halo_size // 2]) + x = x.at[-(halo_size + halo_size // 2):-halo_size].add(x[-halo_size // 2:]) + x = x.at[:, halo_size:halo_size + halo_size // 2].add(x[:, :halo_size // 2]) + x = x.at[:, + -(halo_size + halo_size // 2):-halo_size].add(x[:, + -halo_size // 2:]) + return x[halo_size:-halo_size, halo_size:-halo_size, :] + + # Unpad the output array + field = unpad(field) + return field + + +def simulation_fn(key, nc, box_size, sharding, halo_size, a=1.0): + """ Run a simulation to generate initial conditions and density field using LPT. - + Parameters ---------- - mesh_shape : tuple of int - Shape of the mesh. + key : list of int + Jax random key for the random number generator. + nc : int + Size of the mesh grid. box_size : float Size of the box. - power_spectrum : callable - Power spectrum function. - seed : int - Seed for the random number generator. - a : float - Scale factor. sharding : Any - Sharding strategy for the array. - + Sharding strategy for the simulation. + halo_size: int + Halo size for painting. + a : float + Scale factor of final field. + Returns ------- initial_conditions, field : tuple of jnp.ndarray Initial conditions and the density field. - """ - # Define the power spectrum - power_spectrum = lambda k: jc.power.linear_matter_power(cosmology, k) + """ + # Build a default cosmology + cosmology = jc.Planck15() - # Generate a Gaussian field and gravitational forces from a power spectrum - intial_conditions, initial_forces = gaussian_field_and_forces(mesh_shape, box_size, power_spectrum, seed, sharding) + # Create a small function to generate the linear matter power spectrum at arbitrary k + k = jnp.logspace(-4, 1, 128) + pk = jc.power.linear_matter_power(cosmology, k) + pk_fn = jax.jit(lambda x: jc.scipy.interpolate.interp(x.reshape([-1]), k, pk). + reshape(x.shape)) - # Compute the LPT displacement of that particles initialy placed on a regular grid - # would experience at scale factor a, by simple Zeldovich approximation - initial_displacement = jc.background.growth_factor(cosmology, a) * initial_forces - - # Paints the displaced particles on a mesh to obtain the density field - final_field = cic_paint(initial_displacement, halo_size=32) + # Generate a Gaussian field and gravitational forces from a power spectrum + intial_conditions, initial_forces = gaussian_field_and_forces( + key=key, + nc=nc, + box_size=box_size, + power_spectrum=pk_fn, + sharding=sharding) - return intial_conditions, final_field + # Compute the LPT displacement of that particles initialy placed on a regular grid + # would experience at scale factor a, by simple Zeldovich approximation + initial_displacement = jc.background.growth_factor( + cosmology, jnp.atleast_1d(a)) * initial_forces -if __name__ == '__main__': + # Paints the displaced particles on a mesh to obtain the density field + final_field = cic_paint(initial_displacement, sharding, halo_size) - jax.config.update('jax_enable_x64', False) + return intial_conditions, final_field - parser = argparse.ArgumentParser() - - parser.add_argument('-s', '--size', type=int, default=64) - parser.add_argument('-p', '--pdims', type=str, default='1x1') - parser.add_argument('-b', '--box_size', type=int, default=200) - parser.add_argument('-hs', '--halo_size', type=int, default=32) - parser.add_argument('-o', '--output', type=str, default='out') - - args = parser.parse_args() +def main(args): print(f"Running with arguments {args}") - # ********************************* - # Setup - # ********************************* + # Setting up distributed jax + jax.distributed.initialize() + rank = jax.process_index() + size = jax.process_count() + + # Setting up distributed random numbers master_key = jax.random.PRNGKey(42) key = jax.random.split(master_key, size)[rank] - # Read parameters - pdims = tuple(map(int, args.pdims.split('x'))) - mesh_shape = (args.size, args.size, args.size) - box_size = [float(args.box_size), float(args.box_size), float(args.box_size)] - halo_size = args.halo_size - output_dir = args.output - # Create output directory recursively - os.makedirs(output_dir, exist_ok=True) - - # Create computing mesh + # Create computing mesh and sharding information + pdims = tuple(map(int, args.pdims.split('x'))) devices = mesh_utils.create_device_mesh(pdims) mesh = Mesh(devices, axis_names=('y', 'z')) sharding = jax.sharding.NamedSharding(mesh, P('z', 'y')) - replicate = jax.sharding.NamedSharding(mesh, P()) - - print(f"Saving on folder {output_dir}") - print(f"Created initial field {z.shape} and sharding {z.sharding}") - print(f"Created painting mesh with shape {painting_mesh.shape}") - print(f"And sharding {painting_mesh.sharding}") - print(f"Created positions {pos.shape} and sharding {pos.sharding}") - print("Corrected positions for rank {rank} --> ") - print( - f"pos shape {pos.shape} pos sharding = {pos.sharding} shape of local {pos.addressable_data(0).shape}" - ) with mesh: - jit_start = time.perf_counter() - initial_conds, field = intial_conditions(z, kvec, pos, painting_mesh, a=1.) - field.block_until_ready() - jit_end = time.perf_counter() - - print(f"JIT done in {jit_end - jit_start}") - - start = time.perf_counter() - initial_conds, field = forward_fn(z, kvec, pos, painting_mesh, a=1.) - field.block_until_ready() - end = time.perf_counter() - - print(f"Execution done in {end - start}") - - with open(f"{output_dir}/log_{rank}.log", 'w') as log_file: - log_file.write(f"JIT time: {jit_end - jit_start}\n") - log_file.write(f"Execution time: {end - start}\n") - - # Saving results + # Run the simulation on the compute mesh + initial_conds, final_field = simulation_fn( + key=key, + nc=args.nc, + box_size=args.box_size, + sharding=sharding, + halo_size=args.halo_size) + + # Create output directory to save the results + output_dir = args.output + os.makedirs(output_dir, exist_ok=True) np.save(f'{output_dir}/initial_conditions_{rank}.npy', initial_conds.addressable_data(0)) - np.save(f'{output_dir}/field_{rank}.npy', field.addressable_data(0)) - + np.save(f'{output_dir}/field_{rank}.npy', final_field.addressable_data(0)) print(f"Finished saved to {output_dir}") -jax.distributed.shutdown() + # Closing distributed jax + jax.distributed.shutdown() + + +if __name__ == '__main__': + parser = argparse.ArgumentParser("LPT N-body simulation with JAX.") + parser.add_argument( + '--pdims', type=str, default='1x1', help="Processor grid dimensions") + parser.add_argument( + '--nc', type=int, default=256, help="Number of cells in the mesh") + parser.add_argument( + '--box_size', + type=float, + default=256., + help="Size of the simulation box in Mpc/h") + parser.add_argument( + '--halo_size', type=int, default=32, help="Halo size for painting") + parser.add_argument('--output', type=str, default='out') + args = parser.parse_args() + + main(args) diff --git a/examples/scatter.py b/examples/scatter.py index 6172351..888940c 100644 --- a/examples/scatter.py +++ b/examples/scatter.py @@ -34,122 +34,127 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import jax -import jax.numpy as jnp import jax.lax as lax +import jax.numpy as jnp from jax.lax import scan + def _chunk_split(ptcl_num, chunk_size, *arrays): - """Split and reshape particle arrays into chunks and remainders, with the remainders + """Split and reshape particle arrays into chunks and remainders, with the remainders preceding the chunks. 0D ones are duplicated as full arrays in the chunks.""" - chunk_size = ptcl_num if chunk_size is None else min(chunk_size, ptcl_num) - remainder_size = ptcl_num % chunk_size - chunk_num = ptcl_num // chunk_size + chunk_size = ptcl_num if chunk_size is None else min(chunk_size, ptcl_num) + remainder_size = ptcl_num % chunk_size + chunk_num = ptcl_num // chunk_size - remainder = None - chunks = arrays - if remainder_size: - remainder = [x[:remainder_size] if x.ndim != 0 else x for x in arrays] - chunks = [x[remainder_size:] if x.ndim != 0 else x for x in arrays] + remainder = None + chunks = arrays + if remainder_size: + remainder = [x[:remainder_size] if x.ndim != 0 else x for x in arrays] + chunks = [x[remainder_size:] if x.ndim != 0 else x for x in arrays] - # `scan` triggers errors in scatter and gather without the `full` - chunks = [x.reshape(chunk_num, chunk_size, *x.shape[1:]) if x.ndim != 0 - else jnp.full(chunk_num, x) for x in chunks] + # `scan` triggers errors in scatter and gather without the `full` + chunks = [ + x.reshape(chunk_num, chunk_size, *x.shape[1:]) + if x.ndim != 0 else jnp.full(chunk_num, x) for x in chunks + ] - return remainder, chunks + return remainder, chunks -def enmesh(i1, d1, a1, s1, b12, a2, s2): - """Multilinear enmeshing.""" - i1 = jnp.asarray(i1) - d1 = jnp.asarray(d1) - a1 = jnp.float64(a1) if a2 is not None else jnp.array(a1, dtype=d1.dtype) - if s1 is not None: - s1 = jnp.array(s1, dtype=i1.dtype) - b12 = jnp.float64(b12) - if a2 is not None: - a2 = jnp.float64(a2) - if s2 is not None: - s2 = jnp.array(s2, dtype=i1.dtype) - dim = i1.shape[1] - neighbors = (jnp.arange(2**dim, dtype=i1.dtype)[:, jnp.newaxis] - >> jnp.arange(dim, dtype=i1.dtype) - ) & 1 +def enmesh(i1, d1, a1, s1, b12, a2, s2): + """Multilinear enmeshing.""" + i1 = jnp.asarray(i1) + d1 = jnp.asarray(d1) + a1 = jnp.float64(a1) if a2 is not None else jnp.array(a1, dtype=d1.dtype) + if s1 is not None: + s1 = jnp.array(s1, dtype=i1.dtype) + b12 = jnp.float64(b12) + if a2 is not None: + a2 = jnp.float64(a2) + if s2 is not None: + s2 = jnp.array(s2, dtype=i1.dtype) + + dim = i1.shape[1] + neighbors = (jnp.arange(2**dim, dtype=i1.dtype)[:, jnp.newaxis] >> jnp.arange( + dim, dtype=i1.dtype)) & 1 + + if a2 is not None: + P = i1 * a1 + d1 - b12 + P = P[:, jnp.newaxis] # insert neighbor axis + i2 = P + neighbors * a2 # multilinear - if a2 is not None: - P = i1 * a1 + d1 - b12 - P = P[:, jnp.newaxis] # insert neighbor axis - i2 = P + neighbors * a2 # multilinear + if s1 is not None: + L = s1 * a1 + i2 %= L - if s1 is not None: - L = s1 * a1 - i2 %= L + i2 //= a2 + d2 = P - i2 * a2 - i2 //= a2 - d2 = P - i2 * a2 + if s1 is not None: + d2 -= jnp.rint(d2 / L) * L # also abs(d2) < a2 is expected - if s1 is not None: - d2 -= jnp.rint(d2 / L) * L # also abs(d2) < a2 is expected + i2 = i2.astype(i1.dtype) + d2 = d2.astype(d1.dtype) + a2 = a2.astype(d1.dtype) - i2 = i2.astype(i1.dtype) - d2 = d2.astype(d1.dtype) - a2 = a2.astype(d1.dtype) + d2 /= a2 + else: + i12, d12 = jnp.divmod(b12, a1) + i1 -= i12.astype(i1.dtype) + d1 -= d12.astype(d1.dtype) - d2 /= a2 - else: - i12, d12 = jnp.divmod(b12, a1) - i1 -= i12.astype(i1.dtype) - d1 -= d12.astype(d1.dtype) + # insert neighbor axis + i1 = i1[:, jnp.newaxis] + d1 = d1[:, jnp.newaxis] - # insert neighbor axis - i1 = i1[:, jnp.newaxis] - d1 = d1[:, jnp.newaxis] + # multilinear + d1 /= a1 + i2 = jnp.floor(d1).astype(i1.dtype) + i2 += neighbors + d2 = d1 - i2 + i2 += i1 - # multilinear - d1 /= a1 - i2 = jnp.floor(d1).astype(i1.dtype) - i2 += neighbors - d2 = d1 - i2 - i2 += i1 + if s1 is not None: + i2 %= s1 - if s1 is not None: - i2 %= s1 + f2 = 1 - jnp.abs(d2) - f2 = 1 - jnp.abs(d2) + if s1 is None and s2 is not None: # all i2 >= 0 if s1 is not None + i2 = jnp.where(i2 < 0, s2, i2) - if s1 is None and s2 is not None: # all i2 >= 0 if s1 is not None - i2 = jnp.where(i2 < 0, s2, i2) + f2 = f2.prod(axis=-1) - f2 = f2.prod(axis=-1) + return i2, f2 - return i2, f2 def scatter(pmid, disp, mesh, chunk_size=2**24, val=1., offset=0, cell_size=1.): - ptcl_num, spatial_ndim = pmid.shape - val = jnp.asarray(val) - mesh = jnp.asarray(mesh) + ptcl_num, spatial_ndim = pmid.shape + val = jnp.asarray(val) + mesh = jnp.asarray(mesh) + + remainder, chunks = _chunk_split(ptcl_num, chunk_size, pmid, disp, val) - remainder, chunks = _chunk_split(ptcl_num, chunk_size, pmid, disp, val) + carry = mesh, offset, cell_size + if remainder is not None: + carry = _scatter_chunk(carry, remainder)[0] + carry = scan(_scatter_chunk, carry, chunks)[0] + mesh = carry[0] - carry = mesh, offset, cell_size - if remainder is not None: - carry = _scatter_chunk(carry, remainder)[0] - carry = scan(_scatter_chunk, carry, chunks)[0] - mesh = carry[0] + return mesh - return mesh def _scatter_chunk(carry, chunk): - mesh, offset, cell_size = carry - pmid, disp, val = chunk - spatial_ndim = pmid.shape[1] - spatial_shape = mesh.shape - - # multilinear mesh indices and fractions - ind, frac = enmesh(pmid, disp, cell_size, spatial_shape, - offset, cell_size, spatial_shape) - # scatter - ind = tuple(ind[..., i] for i in range(spatial_ndim)) - mesh = mesh.at[ind].add(val * frac) - - carry = mesh, offset, cell_size - return carry, None \ No newline at end of file + mesh, offset, cell_size = carry + pmid, disp, val = chunk + spatial_ndim = pmid.shape[1] + spatial_shape = mesh.shape + + # multilinear mesh indices and fractions + ind, frac = enmesh(pmid, disp, cell_size, spatial_shape, offset, cell_size, + spatial_shape) + # scatter + ind = tuple(ind[..., i] for i in range(spatial_ndim)) + mesh = mesh.at[ind].add(val * frac) + + carry = mesh, offset, cell_size + return carry, None diff --git a/examples/submit_rusty.sbatch b/examples/submit_rusty.sbatch new file mode 100644 index 0000000..ec061ac --- /dev/null +++ b/examples/submit_rusty.sbatch @@ -0,0 +1,20 @@ +#!/bin/bash -l +#SBATCH -p gpu +#SBATCH -t 0:10:00 +#SBATCH -C a100-80gb&ib-a100 +#SBATCH -N 1 +#SBATCH --ntasks-per-node=4 +#SBATCH --cpus-per-task=16 +#SBATCH --gpus-per-task=1 + +module load modules/2.3 +module load gcc nvhpc python + +source /mnt/home/flanusse/venvs/jaxdecomp724/bin/activate +export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK +export LD_LIBRARY_PATH=$NVHPC_ROOT/Linux_x86_64/24.3/cuda/12.3/extras/CUPTI/lib64:$LD_LIBRARY_PATH +export LD_LIBRARY_PATH=$NVHPC_ROOT/Linux_x86_64/24.3/cuda/12.3/lib64:$LD_LIBRARY_PATH +export LD_LIBRARY_PATH=$NVHPC_ROOT/Linux_x86_64/24.3/comm_libs/nccl/lib:$LD_LIBRARY_PATH +export LD_LIBRARY_PATH=$NVHPC_ROOT/Linux_x86_64/24.3/math_libs/lib64:$LD_LIBRARY_PATH + +mpirun python3 lpt_nbody_demo.py --pdims 2x2 From fe60caae066adf94000fb69070d52759a9a25803 Mon Sep 17 00:00:00 2001 From: EiffL Date: Mon, 8 Jul 2024 02:06:11 -0400 Subject: [PATCH 3/9] updating implementation --- examples/scatter.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/examples/scatter.py b/examples/scatter.py index 888940c..1b4ce2f 100644 --- a/examples/scatter.py +++ b/examples/scatter.py @@ -33,8 +33,6 @@ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -import jax -import jax.lax as lax import jax.numpy as jnp from jax.lax import scan From 608cb406d85090402968a79e178252f4708a2d6c Mon Sep 17 00:00:00 2001 From: Francois Lanusse Date: Mon, 8 Jul 2024 14:33:13 -0400 Subject: [PATCH 4/9] Update examples/lpt_nbody_demo.py Co-authored-by: Wassim KABALAN --- examples/lpt_nbody_demo.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/lpt_nbody_demo.py b/examples/lpt_nbody_demo.py index 07044e7..3f40460 100644 --- a/examples/lpt_nbody_demo.py +++ b/examples/lpt_nbody_demo.py @@ -170,10 +170,10 @@ def cic_op(disp): a, b, c = jnp.meshgrid( jnp.arange(local_mesh_shape[0]), jnp.arange(local_mesh_shape[1]), - jnp.arange(local_mesh_shape[2])) + jnp.arange(local_mesh_shape[2]) , indexing='ij') # adding an offset of size halo size - pmid = jnp.stack([b + halo_size, a + halo_size, c], axis=-1) + pmid = jnp.stack([a + halo_size, b + halo_size, c], axis=-1) return scatter(pmid.reshape([-1, 3]), disp.reshape([-1, 3]), mesh) # Performs painting on padded mesh From d7e5c280c314212660b7dcc2c3aba5bb44fab48e Mon Sep 17 00:00:00 2001 From: EiffL Date: Mon, 8 Jul 2024 18:29:07 -0400 Subject: [PATCH 5/9] updating example --- examples/lpt_nbody_demo.py | 106 +++++++++++++++++++------------------ 1 file changed, 54 insertions(+), 52 deletions(-) diff --git a/examples/lpt_nbody_demo.py b/examples/lpt_nbody_demo.py index 3f40460..3da694d 100644 --- a/examples/lpt_nbody_demo.py +++ b/examples/lpt_nbody_demo.py @@ -1,7 +1,12 @@ import argparse import os from functools import partial -from typing import Tuple +from typing import Any, Callable, Hashable, Tuple + +from jax._src import mesh as mesh_lib + +Specs = Any +AxisName = Hashable import jax @@ -12,22 +17,33 @@ import numpy as np from jax.experimental import mesh_utils from jax.experimental.shard_map import shard_map -from jax.sharding import Mesh, NamedSharding from jax.sharding import PartitionSpec as P from scatter import scatter import jaxdecomp -def _global_to_local_size(mesh_shape, sharding): +def shmap(f: Callable, + in_specs: Specs, + out_specs: Specs, + check_rep: bool = True, + auto: frozenset[AxisName] = frozenset()): + """Helper function to create a shard_map function that extracts the mesh from the + context.""" + # Extracts the mesh from the context + mesh = mesh_lib.thread_resources.env.physical_mesh + return shard_map(f, mesh, in_specs, out_specs, check_rep, auto) + + +def _global_to_local_size(mesh_shape): """ Utility function to compute the expected local size of a mesh given the global size and the sharding strategy. """ - pdims = sharding.mesh.devices.shape + pdims = mesh_lib.thread_resources.env.physical_mesh.devices.shape return [mesh_shape[0] // pdims[1], mesh_shape[1] // pdims[0], mesh_shape[2]] -def fttk(nc: int, sharding) -> list: +def fttk(nc: int) -> list: """ Generate Fourier transform wave numbers for a given mesh. @@ -42,19 +58,18 @@ def fttk(nc: int, sharding) -> list: ------- list List of wave number arrays for each dimension. - """ + """ kd = np.fft.fftfreq(nc) * 2 * np.pi - return [ - jax.make_array_from_callback( - (nc, 1, 1), - sharding=NamedSharding(sharding.mesh, P('z')), - data_callback=lambda x: kd.reshape([-1, 1, 1])[x]), - jax.make_array_from_callback( - (1, nc, 1), - sharding=NamedSharding(sharding.mesh, P(None, 'y')), - data_callback=lambda x: kd.reshape([1, -1, 1])[x]), - kd.reshape([1, 1, -1]) - ] + + @partial( + shmap, + in_specs=(P('z'), P('y'), P(None)), + out_specs=(P('z'), P(None, 'y'), P(None))) + def get_kvec(kx, ky, kz): + return (kx.reshape([-1, 1, 1]), ky.reshape([1, -1, + 1]), kz.reshape([1, 1, -1])) + + return get_kvec(kd, kd, kd) def gravity_kernel(kvec): @@ -69,7 +84,7 @@ def gravity_kernel(kvec): ------- jnp.ndarray Gravitational kernel. - """ + """ kx, ky, kz = kvec kk = kx**2 + ky**2 + kz**2 laplace_kernel = jnp.where(kk == 0, 1., 1. / kk) @@ -81,7 +96,7 @@ def gravity_kernel(kvec): return grav_kernel -def gaussian_field_and_forces(key, nc, box_size, power_spectrum, sharding): +def gaussian_field_and_forces(key, nc, box_size, power_spectrum): """ Generate a Gaussian field with a given power spectrum, along with gravitational forces. @@ -104,19 +119,20 @@ def gaussian_field_and_forces(key, nc, box_size, power_spectrum, sharding): The generated Gaussian field and the gravitational forces. """ mesh_shape = (nc,) * 3 - local_mesh_shape = _global_to_local_size(mesh_shape, sharding) + local_mesh_shape = _global_to_local_size(mesh_shape) # Create a distributed field drawn from a Gaussian distribution in real space - delta = jax.make_array_from_single_device_arrays( - shape=mesh_shape, - sharding=sharding, - arrays=[jax.random.normal(key, local_mesh_shape, dtype='float32')]) + @partial(shmap, in_specs=(), out_specs=P('z', 'y')) + def _sample_gaussian(): + return jax.random.normal(key, local_mesh_shape, dtype='float32') + + delta = _sample_gaussian() # Compute the Fourier transform of the field delta_k = jaxdecomp.fft.pfft3d(delta.astype(jnp.complex64)) # Compute the Fourier wavenumbers of the field - kx, ky, kz = fttk(nc, sharding) + kx, ky, kz = fttk(nc) kk = jnp.sqrt(kx**2 + ky**2 + kz**2) * (nc / box_size)**3 # Apply power spectrum to Fourier modes @@ -135,7 +151,7 @@ def gaussian_field_and_forces(key, nc, box_size, power_spectrum, sharding): return delta, forces -def cic_paint(displacement, sharding, halo_size): +def cic_paint(displacement, halo_size): """ Paints particles on a mesh using Cloud-In-Cell interpolation. Parameters @@ -152,13 +168,9 @@ def cic_paint(displacement, sharding, halo_size): jnp.ndarray Density field. """ - local_mesh_shape = _global_to_local_size(displacement.shape, sharding) + local_mesh_shape = _global_to_local_size(displacement.shape) - @partial( - shard_map, - mesh=sharding.mesh, - in_specs=(P('z', 'y'),), - out_specs=P('z', 'y')) + @partial(shmap, in_specs=(P('z', 'y'),), out_specs=P('z', 'y')) def cic_op(disp): """ CiC operation on each local slice of the mesh.""" # Create a mesh to paint the particles on for the local slice @@ -169,8 +181,10 @@ def cic_op(disp): [[halo_size, halo_size], [halo_size, halo_size], [0, 0]]) a, b, c = jnp.meshgrid( - jnp.arange(local_mesh_shape[0]), jnp.arange(local_mesh_shape[1]), - jnp.arange(local_mesh_shape[2]) , indexing='ij') + jnp.arange(local_mesh_shape[0]), + jnp.arange(local_mesh_shape[1]), + jnp.arange(local_mesh_shape[2]), + indexing='ij') # adding an offset of size halo size pmid = jnp.stack([a + halo_size, b + halo_size, c], axis=-1) @@ -185,11 +199,7 @@ def cic_op(disp): halo_extents=(halo_size // 2, halo_size // 2, 0), halo_periods=(True, True, True)) - @partial( - shard_map, - mesh=sharding.mesh, - in_specs=(P('z', 'y'),), - out_specs=P('z', 'y')) + @partial(shmap, in_specs=(P('z', 'y'),), out_specs=P('z', 'y')) def unpad(x): """ Unpad the output array. """ x = x.at[halo_size:halo_size + halo_size // 2].add(x[:halo_size // 2]) @@ -205,7 +215,8 @@ def unpad(x): return field -def simulation_fn(key, nc, box_size, sharding, halo_size, a=1.0): +@partial(jax.jit, static_argnames=('nc', 'box_size', 'halo_size')) +def simulation_fn(key, nc, box_size, halo_size, a=1.0): """ Run a simulation to generate initial conditions and density field using LPT. @@ -240,11 +251,7 @@ def simulation_fn(key, nc, box_size, sharding, halo_size, a=1.0): # Generate a Gaussian field and gravitational forces from a power spectrum intial_conditions, initial_forces = gaussian_field_and_forces( - key=key, - nc=nc, - box_size=box_size, - power_spectrum=pk_fn, - sharding=sharding) + key=key, nc=nc, box_size=box_size, power_spectrum=pk_fn) # Compute the LPT displacement of that particles initialy placed on a regular grid # would experience at scale factor a, by simple Zeldovich approximation @@ -252,7 +259,7 @@ def simulation_fn(key, nc, box_size, sharding, halo_size, a=1.0): cosmology, jnp.atleast_1d(a)) * initial_forces # Paints the displaced particles on a mesh to obtain the density field - final_field = cic_paint(initial_displacement, sharding, halo_size) + final_field = cic_paint(initial_displacement, halo_size) return intial_conditions, final_field @@ -273,16 +280,11 @@ def main(args): pdims = tuple(map(int, args.pdims.split('x'))) devices = mesh_utils.create_device_mesh(pdims) mesh = Mesh(devices, axis_names=('y', 'z')) - sharding = jax.sharding.NamedSharding(mesh, P('z', 'y')) with mesh: # Run the simulation on the compute mesh initial_conds, final_field = simulation_fn( - key=key, - nc=args.nc, - box_size=args.box_size, - sharding=sharding, - halo_size=args.halo_size) + key=key, nc=args.nc, box_size=args.box_size, halo_size=args.halo_size) # Create output directory to save the results output_dir = args.output From 9ad2f701365354cbf0148f5dca2cb565d6b89f9c Mon Sep 17 00:00:00 2001 From: EiffL Date: Mon, 8 Jul 2024 18:43:58 -0400 Subject: [PATCH 6/9] add back missing import --- examples/lpt_nbody_demo.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/lpt_nbody_demo.py b/examples/lpt_nbody_demo.py index 3da694d..b81c8e7 100644 --- a/examples/lpt_nbody_demo.py +++ b/examples/lpt_nbody_demo.py @@ -17,6 +17,7 @@ import numpy as np from jax.experimental import mesh_utils from jax.experimental.shard_map import shard_map +from jax.sharding import Mesh from jax.sharding import PartitionSpec as P from scatter import scatter @@ -279,7 +280,7 @@ def main(args): # Create computing mesh and sharding information pdims = tuple(map(int, args.pdims.split('x'))) devices = mesh_utils.create_device_mesh(pdims) - mesh = Mesh(devices, axis_names=('y', 'z')) + mesh = Mesh(devices, axis_names=('y', 'x')) with mesh: # Run the simulation on the compute mesh From b1b6b73bae1925e5dcd05097c2b2c822905a9215 Mon Sep 17 00:00:00 2001 From: EiffL Date: Mon, 8 Jul 2024 21:12:36 -0400 Subject: [PATCH 7/9] clean up demo --- examples/lpt_nbody_demo.py | 211 ++++++++++++++--------------------- examples/submit_rusty.sbatch | 4 +- 2 files changed, 88 insertions(+), 127 deletions(-) diff --git a/examples/lpt_nbody_demo.py b/examples/lpt_nbody_demo.py index b81c8e7..aeb7dc8 100644 --- a/examples/lpt_nbody_demo.py +++ b/examples/lpt_nbody_demo.py @@ -1,9 +1,7 @@ import argparse import os from functools import partial -from typing import Any, Callable, Hashable, Tuple - -from jax._src import mesh as mesh_lib +from typing import Any, Callable, Hashable Specs = Any AxisName = Hashable @@ -15,6 +13,7 @@ import jax.numpy as jnp import jax_cosmo as jc import numpy as np +from jax._src import mesh as mesh_lib from jax.experimental import mesh_utils from jax.experimental.shard_map import shard_map from jax.sharding import Mesh @@ -31,69 +30,59 @@ def shmap(f: Callable, auto: frozenset[AxisName] = frozenset()): """Helper function to create a shard_map function that extracts the mesh from the context.""" - # Extracts the mesh from the context mesh = mesh_lib.thread_resources.env.physical_mesh return shard_map(f, mesh, in_specs, out_specs, check_rep, auto) -def _global_to_local_size(mesh_shape): - """ Utility function to compute the expected local size of a mesh - given the global size and the sharding strategy. +def _global_to_local_size(nc: int): + """ Helper function to get the local size of a mesh given the global size. """ pdims = mesh_lib.thread_resources.env.physical_mesh.devices.shape - return [mesh_shape[0] // pdims[1], mesh_shape[1] // pdims[0], mesh_shape[2]] + return [nc // pdims[0], nc // pdims[1], nc] def fttk(nc: int) -> list: """ Generate Fourier transform wave numbers for a given mesh. - Parameters - ---------- - mesh_shape : int - Shape of the mesh grid. - sharding : Any - Sharding strategy for the array. - - Returns - ------- - list - List of wave number arrays for each dimension. + Args: + nc (int): Shape of the mesh grid. + + Returns: + list: List of wave number arrays for each dimension. """ kd = np.fft.fftfreq(nc) * 2 * np.pi @partial( shmap, - in_specs=(P('z'), P('y'), P(None)), - out_specs=(P('z'), P(None, 'y'), P(None))) + in_specs=(P('x'), P('y'), P(None)), + out_specs=(P('x'), P(None, 'y'), P(None))) def get_kvec(kx, ky, kz): - return (kx.reshape([-1, 1, 1]), ky.reshape([1, -1, - 1]), kz.reshape([1, 1, -1])) + return (kx.reshape([-1, 1, 1]), + ky.reshape([1, -1, 1]), + kz.reshape([1, 1, -1])) # yapf: disable return get_kvec(kd, kd, kd) def gravity_kernel(kvec): - """ Fourier kernel to compute gravitational forces from a Fourier space density field. + """ Computes a Fourier kernel combining laplace and derivative + operators to compute gravitational forces. - Parameters - ---------- - kvec : tuple of float - Wave vector in Fourier space. + Args: + kvec (tuple of float): Wave numbers in Fourier space. - Returns - ------- - jnp.ndarray - Gravitational kernel. + Returns: + tuple of jnp.ndarray: kernels for each dimension. """ kx, ky, kz = kvec kk = kx**2 + ky**2 + kz**2 laplace_kernel = jnp.where(kk == 0, 1., 1. / kk) - grav_kernel = [ - laplace_kernel * 1j / 6.0 * (8 * jnp.sin(kx) - jnp.sin(2 * kx)), - laplace_kernel * 1j / 6.0 * (8 * jnp.sin(ky) - jnp.sin(2 * ky)), - laplace_kernel * 1j / 6.0 * (8 * jnp.sin(kz) - jnp.sin(2 * kz)) - ] + # Note that we return frequency arrays in the transposed order [z, x, y] + # corresponding to the transposed FFT output + grav_kernel = (laplace_kernel * 1j * kz, + laplace_kernel * 1j * kx, + laplace_kernel * 1j * ky) # yapf: disable return grav_kernel @@ -101,40 +90,29 @@ def gaussian_field_and_forces(key, nc, box_size, power_spectrum): """ Generate a Gaussian field with a given power spectrum, along with gravitational forces. - Parameters - ---------- - key : int - key for the random number generator. - nc : int - Number of cells in the mesh. - box_size : float - Size of the box. - power_spectrum : callable - Power spectrum function. - sharding : Any - Sharding strategy for the array. - - Returns - ------- - delta, forces : tuple of jnp.ndarray - The generated Gaussian field and the gravitational forces. - """ - mesh_shape = (nc,) * 3 - local_mesh_shape = _global_to_local_size(mesh_shape) + Args: + key (int): Key for the random number generator. + nc (int): Number of cells in the mesh. + box_size (float): Size of the box. + power_spectrum (callable): Power spectrum function. - # Create a distributed field drawn from a Gaussian distribution in real space - @partial(shmap, in_specs=(), out_specs=P('z', 'y')) - def _sample_gaussian(): - return jax.random.normal(key, local_mesh_shape, dtype='float32') + Returns: + tuple of jnp.ndarray: The generated Gaussian field and the gravitational forces. + """ + local_mesh_shape = _global_to_local_size(nc) - delta = _sample_gaussian() + # Create a distributed field drawn from a Gaussian distribution in real space + delta = shmap( + partial(jax.random.normal, shape=local_mesh_shape, dtype='float32'), + in_specs=P(None), + out_specs=P('x', 'y'))(key) # yapf: disable # Compute the Fourier transform of the field delta_k = jaxdecomp.fft.pfft3d(delta.astype(jnp.complex64)) # Compute the Fourier wavenumbers of the field kx, ky, kz = fttk(nc) - kk = jnp.sqrt(kx**2 + ky**2 + kz**2) * (nc / box_size)**3 + kk = jnp.sqrt(kx**2 + ky**2 + kz**2) * (nc / box_size) # Apply power spectrum to Fourier modes delta_k *= (power_spectrum(kk) * (nc / box_size)**3)**0.5 @@ -155,61 +133,57 @@ def _sample_gaussian(): def cic_paint(displacement, halo_size): """ Paints particles on a mesh using Cloud-In-Cell interpolation. - Parameters - ---------- - displacement : jnp.ndarray - Displacement field of particles. - sharding : Any - Sharding strategy for the array. - halo_size : int - Halo size for painting. - - Returns - ------- - jnp.ndarray - Density field. + Args: + displacement (jnp.ndarray): Displacement of each particle. + halo_size (int): Halo size for painting. + + Returns: + jnp.ndarray: Density field. """ - local_mesh_shape = _global_to_local_size(displacement.shape) + local_mesh_shape = _global_to_local_size(displacement.shape[0]) + hs = halo_size - @partial(shmap, in_specs=(P('z', 'y'),), out_specs=P('z', 'y')) + @partial(shmap, in_specs=(P('x', 'y'),), out_specs=P('x', 'y')) def cic_op(disp): """ CiC operation on each local slice of the mesh.""" # Create a mesh to paint the particles on for the local slice mesh = jnp.zeros(disp.shape[:-1], dtype='float32') - # Padding the output array along the two first dimensions - mesh = jnp.pad(mesh, - [[halo_size, halo_size], [halo_size, halo_size], [0, 0]]) + # Padding the mesh along the two first dimensions + mesh = jnp.pad(mesh, [[hs, hs], [hs, hs], [0, 0]]) - a, b, c = jnp.meshgrid( + # Compute the position of the particles on a regular grid + pos_x, pos_y, pos_z = jnp.meshgrid( jnp.arange(local_mesh_shape[0]), jnp.arange(local_mesh_shape[1]), jnp.arange(local_mesh_shape[2]), indexing='ij') # adding an offset of size halo size - pmid = jnp.stack([a + halo_size, b + halo_size, c], axis=-1) - return scatter(pmid.reshape([-1, 3]), disp.reshape([-1, 3]), mesh) + pos = jnp.stack([pos_x + hs, pos_y + hs, pos_z], axis=-1) + + # Apply scatter operation to paint the particles on the local mesh + field = scatter(pos.reshape([-1, 3]), disp.reshape([-1, 3]), mesh) - # Performs painting on padded mesh + return field + + # Performs painting on a padded mesh, with halos on the two first dimensions field = cic_op(displacement) # Run halo exchange to get the correct values at the boundaries field = jaxdecomp.halo_exchange( field, - halo_extents=(halo_size // 2, halo_size // 2, 0), + halo_extents=(hs // 2, hs // 2, 0), halo_periods=(True, True, True)) - @partial(shmap, in_specs=(P('z', 'y'),), out_specs=P('z', 'y')) + @partial(shmap, in_specs=(P('x', 'y'),), out_specs=P('x', 'y')) def unpad(x): - """ Unpad the output array. """ - x = x.at[halo_size:halo_size + halo_size // 2].add(x[:halo_size // 2]) - x = x.at[-(halo_size + halo_size // 2):-halo_size].add(x[-halo_size // 2:]) - x = x.at[:, halo_size:halo_size + halo_size // 2].add(x[:, :halo_size // 2]) - x = x.at[:, - -(halo_size + halo_size // 2):-halo_size].add(x[:, - -halo_size // 2:]) - return x[halo_size:-halo_size, halo_size:-halo_size, :] + """ Removes the padding and reduce the halo regions""" + x = x.at[hs:hs + hs // 2].add(x[:hs // 2]) + x = x.at[-(hs + hs // 2):-hs].add(x[-hs // 2:]) + x = x.at[:, hs:hs + hs // 2].add(x[:, :hs // 2]) + x = x.at[:, -(hs + hs // 2):-hs].add(x[:, -hs // 2:]) + return x[hs:-hs, hs:-hs, :] # Unpad the output array field = unpad(field) @@ -221,32 +195,22 @@ def simulation_fn(key, nc, box_size, halo_size, a=1.0): """ Run a simulation to generate initial conditions and density field using LPT. - Parameters - ---------- - key : list of int - Jax random key for the random number generator. - nc : int - Size of the mesh grid. - box_size : float - Size of the box. - sharding : Any - Sharding strategy for the simulation. - halo_size: int - Halo size for painting. - a : float - Scale factor of final field. - - Returns - ------- - initial_conditions, field : tuple of jnp.ndarray - Initial conditions and the density field. + Args: + key (list of int): Jax random key for the random number generator. + nc (int): Size of the mesh grid. + box_size (float): Size of the box. + halo_size (int): Halo size for painting. + a (float): Scale factor of final field. + + Returns: + tuple of jnp.ndarray: Initial conditions and final density field. """ # Build a default cosmology - cosmology = jc.Planck15() + cosmo = jc.Planck15() # Create a small function to generate the linear matter power spectrum at arbitrary k k = jnp.logspace(-4, 1, 128) - pk = jc.power.linear_matter_power(cosmology, k) + pk = jc.power.linear_matter_power(cosmo, k) pk_fn = jax.jit(lambda x: jc.scipy.interpolate.interp(x.reshape([-1]), k, pk). reshape(x.shape)) @@ -254,10 +218,10 @@ def simulation_fn(key, nc, box_size, halo_size, a=1.0): intial_conditions, initial_forces = gaussian_field_and_forces( key=key, nc=nc, box_size=box_size, power_spectrum=pk_fn) - # Compute the LPT displacement of that particles initialy placed on a regular grid + # Compute the LPT displacement that particles initialy placed on a regular grid # would experience at scale factor a, by simple Zeldovich approximation initial_displacement = jc.background.growth_factor( - cosmology, jnp.atleast_1d(a)) * initial_forces + cosmo, jnp.atleast_1d(a)) * initial_forces # Paints the displaced particles on a mesh to obtain the density field final_field = cic_paint(initial_displacement, halo_size) @@ -280,10 +244,10 @@ def main(args): # Create computing mesh and sharding information pdims = tuple(map(int, args.pdims.split('x'))) devices = mesh_utils.create_device_mesh(pdims) - mesh = Mesh(devices, axis_names=('y', 'x')) + mesh = Mesh(devices.T, axis_names=('x', 'y')) + # Run the simulation on the compute mesh with mesh: - # Run the simulation on the compute mesh initial_conds, final_field = simulation_fn( key=key, nc=args.nc, box_size=args.box_size, halo_size=args.halo_size) @@ -300,16 +264,13 @@ def main(args): if __name__ == '__main__': - parser = argparse.ArgumentParser("LPT N-body simulation with JAX.") + parser = argparse.ArgumentParser("Distributed LPT N-body simulation.") parser.add_argument( '--pdims', type=str, default='1x1', help="Processor grid dimensions") parser.add_argument( '--nc', type=int, default=256, help="Number of cells in the mesh") parser.add_argument( - '--box_size', - type=float, - default=256., - help="Size of the simulation box in Mpc/h") + '--box_size', type=float, default=512., help="Box size in Mpc/h") parser.add_argument( '--halo_size', type=int, default=32, help="Halo size for painting") parser.add_argument('--output', type=str, default='out') diff --git a/examples/submit_rusty.sbatch b/examples/submit_rusty.sbatch index ec061ac..127bb30 100644 --- a/examples/submit_rusty.sbatch +++ b/examples/submit_rusty.sbatch @@ -1,7 +1,7 @@ #!/bin/bash -l #SBATCH -p gpu #SBATCH -t 0:10:00 -#SBATCH -C a100-80gb&ib-a100 +#SBATCH -C a100 #SBATCH -N 1 #SBATCH --ntasks-per-node=4 #SBATCH --cpus-per-task=16 @@ -10,7 +10,7 @@ module load modules/2.3 module load gcc nvhpc python -source /mnt/home/flanusse/venvs/jaxdecomp724/bin/activate +source ~/venvs/jaxdecomp724/bin/activate export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK export LD_LIBRARY_PATH=$NVHPC_ROOT/Linux_x86_64/24.3/cuda/12.3/extras/CUPTI/lib64:$LD_LIBRARY_PATH export LD_LIBRARY_PATH=$NVHPC_ROOT/Linux_x86_64/24.3/cuda/12.3/lib64:$LD_LIBRARY_PATH From e238a1dea1538c7a35a92824ed0baa35657c87b4 Mon Sep 17 00:00:00 2001 From: EiffL Date: Mon, 8 Jul 2024 21:23:02 -0400 Subject: [PATCH 8/9] update the notebook and adding a readme --- examples/README.md | 23 +++++++++++++++++++++++ examples/visualizer.ipynb | 4 ++-- 2 files changed, 25 insertions(+), 2 deletions(-) create mode 100644 examples/README.md diff --git a/examples/README.md b/examples/README.md new file mode 100644 index 0000000..225c183 --- /dev/null +++ b/examples/README.md @@ -0,0 +1,23 @@ +# Use-Case Examples + +This directory contains examples of how to use the jaxDecomp library on a few use cases. + +## Distributed LPT Cosmological Simulation + +This example demonstrates the use of the 3D distributed FFT and halo exchange functions in the `jaxDecomp` library to implement a distributed LPT cosmological simulation. We provide a notebook to visualize the results of the simulation in [visualizer.ipynb](visualizer.ipynb). + +To run the demo, some additional dependencies are required. You can install them by running: + +```bash +pip install jax-cosmo +``` + +Then, you can run the example by executing the following command: +```bash +mpirun -n 4 python lpt_nbody_demo.py --nc 256 --box_size 256 --pdims 4x4 --halo_size 32 --output out +``` + +We also include an example of a slurm script in [submit_rusty.sbatch](submit_rusty.sbatch) that can be used to run the example on a slurm cluster with: +```bash +sbatch submit_rusty.sbatch +``` diff --git a/examples/visualizer.ipynb b/examples/visualizer.ipynb index 85b7434..122e286 100644 --- a/examples/visualizer.ipynb +++ b/examples/visualizer.ipynb @@ -29,7 +29,7 @@ "\n", "\n", "```sh\n", - "mpirun -n 16 python jaxdecomp_lpt.py -s 2048 -b 2048 -p 4x4 -hs 256 -o out\n", + "mpirun -n 16 python lpt_nbody_demo.py --nc 2048 --box_size 2048 --pdims 4x4 --halo_size 256 --output out\n", "```\n", "\n", "Or if you have a slurm cluster, you can use the following command:\n", @@ -37,7 +37,7 @@ "\n", "```sh\n", "salloc --nodes=2 --ntasks=16 --cpus-per-task=1 --gres=gpu:8 --time=00:30:00\n", - "srun python jaxdecomp_lpt.py -s 2048 -b 2048 -p 4x4 -hs 256 -o out\n", + "srun python lpt_nbody_demo.py --nc 2048 --box_size 2048 --pdims 4x4 --halo_size 256 --output out\n", "``\n" ] }, From 699fdb8d5a647e821a6f7b3e8afc364a573d2399 Mon Sep 17 00:00:00 2001 From: EiffL Date: Tue, 9 Jul 2024 12:50:11 -0400 Subject: [PATCH 9/9] apply cleanup of axis ordering and naming --- examples/lpt_nbody_demo.py | 31 ++++++++++++++++--------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/examples/lpt_nbody_demo.py b/examples/lpt_nbody_demo.py index aeb7dc8..bc15730 100644 --- a/examples/lpt_nbody_demo.py +++ b/examples/lpt_nbody_demo.py @@ -49,7 +49,8 @@ def fttk(nc: int) -> list: nc (int): Shape of the mesh grid. Returns: - list: List of wave number arrays for each dimension. + list: List of wave number arrays for each dimension in + the order [kx, ky, kz]. """ kd = np.fft.fftfreq(nc) * 2 * np.pi @@ -57,15 +58,17 @@ def fttk(nc: int) -> list: shmap, in_specs=(P('x'), P('y'), P(None)), out_specs=(P('x'), P(None, 'y'), P(None))) - def get_kvec(kx, ky, kz): - return (kx.reshape([-1, 1, 1]), - ky.reshape([1, -1, 1]), - kz.reshape([1, 1, -1])) # yapf: disable + def get_kvec(ky, kz, kx): + return (ky.reshape([-1, 1, 1]), + kz.reshape([1, -1, 1]), + kx.reshape([1, 1, -1])) # yapf: disable + ky, kz, kx = get_kvec(kd, kd, kd) # The order of the output + # corresponds to the order of dimensions in the transposed FFT + # output + return kx, ky, kz - return get_kvec(kd, kd, kd) - -def gravity_kernel(kvec): +def gravity_kernel(kx, ky, kz): """ Computes a Fourier kernel combining laplace and derivative operators to compute gravitational forces. @@ -75,14 +78,12 @@ def gravity_kernel(kvec): Returns: tuple of jnp.ndarray: kernels for each dimension. """ - kx, ky, kz = kvec kk = kx**2 + ky**2 + kz**2 laplace_kernel = jnp.where(kk == 0, 1., 1. / kk) - # Note that we return frequency arrays in the transposed order [z, x, y] - # corresponding to the transposed FFT output - grav_kernel = (laplace_kernel * 1j * kz, - laplace_kernel * 1j * kx, - laplace_kernel * 1j * ky) # yapf: disable + + grav_kernel = (laplace_kernel * 1j * kx, + laplace_kernel * 1j * ky, + laplace_kernel * 1j * kz) # yapf: disable return grav_kernel @@ -121,7 +122,7 @@ def gaussian_field_and_forces(key, nc, box_size, power_spectrum): delta = jaxdecomp.fft.pifft3d(delta_k).real # Compute gravitational forces associated with this field - grav_kernel = gravity_kernel([kx, ky, kz]) + grav_kernel = gravity_kernel(kx, ky, kz) forces_k = [g * delta_k for g in grav_kernel] # Retrieve the forces in real space by inverse Fourier transforming