# Mesh calculations

This package includes helper routines to handle mesh-based calculations. FFTs and halo padding are performed with [jaxdecomp](https://github.com/DifferentiableUniverseInitiative/jaxDecomp).

## Mesh attributes (`MeshAttrs`)

In [1]:
# To define a mesh, let's start with MeshAttrs
import jax
from jax import numpy as jnp
from jax import random
from jaxpower import MeshAttrs, RealMeshField, ComplexMeshField

# This specifies the mesh size (128), the box size (100., in physical length units, e.g. Mpc/h),
# and the center of the box w.r.t. the observer (0.).
# Scalar inputs are expanded to one value per dimension (3 here), as shown by the printed MeshAttrs.
attrs = MeshAttrs(meshsize=128, boxsize=100., boxcenter=0.)
print(attrs)
attrs.meshsize  # mesh size (array of length 3)
attrs.boxsize   # box size (array of length 3)
attrs.boxcenter # box center (array of length 3)
attrs.cellsize  # cell size
attrs.knyq  # Nyquist frequency (array of length 3), jnp.pi / attrs.cellsize
attrs.kfun  # fundamental frequency (array of length 3), 2 * jnp.pi / attrs.boxsize

attrs.rcoords(kind='position', sparse=None)  # return real-space mesh coordinates (list of 3 1D arrays)
attrs.kcoords(kind='wavenumber', sparse=True);  # return Fourier-space coordinates (list of 3 broadcastable arrays)

MeshAttrs(meshsize=staticarray([128, 128, 128], dtype=int32), boxsize=Array([100., 100., 100.], dtype=float32), boxcenter=Array([0., 0., 0.], dtype=float32), dtype=dtype('float32'), fft_engine='jaxdecomp')


In [2]:
# As a default, mesh is 3D, but it can be any dimension, as long as it is specified in one of the input arrays

attrs2 = MeshAttrs(meshsize=(128, 100), boxsize=100.)
print(attrs2)
# Report the dimension of the 2D attrs2 just created (not of the 3D attrs defined earlier)
print('dimension is {:d}'.format(attrs2.ndim))

MeshAttrs(meshsize=staticarray([128, 100], dtype=int32), boxsize=Array([100., 100.], dtype=float32), boxcenter=Array([0., 0.], dtype=float32), dtype=dtype('float32'), fft_engine='jax')
dimension is 2


## Create a mesh

In [3]:
# To create an (empty) real mesh from the mesh attributes
rmesh = attrs.create(kind='real')
print(type(rmesh))
# To create a complex mesh filled with zeros
cmesh = attrs.create(kind='complex', fill=0.)
print(type(cmesh))

<class 'jaxpower.mesh.RealMeshField'>
<class 'jaxpower.mesh.ComplexMeshField'>


In [4]:
# If you already have an array (here uniform random numbers, with the same shape as the mesh),
# pass it as the fill value
key = random.key(42)
array = random.uniform(key, shape=(128,) * 3)
rmesh = attrs.create(kind='real', fill=array)
# Or construct the mesh field directly from the array and the mesh attributes
rmesh = RealMeshField(array, attrs=attrs)

In [5]:
# The MeshAttrs attached to a mesh field can be accessed with
attrs = rmesh.attrs

In [6]:
# Coordinates follow the kind of mesh: real-space for a real mesh, Fourier-space for a complex mesh
rmesh.coords()  # equivalent to rmesh.attrs.rcoords
cmesh.coords()  # equivalent to cmesh.attrs.kcoords
rmesh = rmesh.clone(value=2. * rmesh.value)  # update mesh value (here doubled); the result is rebound to rmesh

## FFT

In [7]:
# As a default, 3D FFTs are performed with jaxdecomp if installed

# Real mesh -> complex mesh
cmesh = rmesh.r2c()
# Complex mesh -> real mesh
rmesh2 = cmesh.c2r()

# The round trip should recover the original mesh, up to numerical tolerance
assert jnp.allclose(rmesh2.value, rmesh.value, rtol=1e-4, atol=1e-4)

## Kernels

In [8]:
# Apply a custom kernel: a function of the mesh value and of the wavenumber vector (here a Gaussian damping, exp(-k^2))
cmesh = cmesh.apply(lambda value, kvec: value * jnp.exp(-sum(kk**2 for kk in kvec)), kind='wavenumber')
# Also, some pre-registered kernels
from jaxpower import kernels
cmesh = cmesh.apply(kernels.gradient(axis=0))
cmesh = cmesh.apply(kernels.invlaplace())

## Reading

In [9]:
# Uniformly-distributed positions within the box
# Draw them from a fresh subkey: `key` was already used to generate the mesh array,
# and reusing a JAX key would correlate the positions with the mesh values being read
_, positions_key = random.split(key)
positions = attrs.boxsize * random.uniform(positions_key, (int(1e3), attrs.ndim)) - attrs.boxsize / 2. + attrs.boxcenter

# resampler is 'ngp', 'cic', 'tsc', 'pcs'
# compensate=True to apply compensation kernel (in Fourier space) before reading
values = rmesh.read(positions, resampler='tsc', compensate=True)

## Painting

In [10]:
from jaxpower import ParticleField

# Particles at the given positions, each with unit weight
particles = ParticleField(positions, weights=jnp.ones(positions.shape[0]), attrs=attrs)
# Return painted real mesh, with interlacing and kernel compensation
rmesh = particles.paint(resampler='tsc', interlacing=2, compensate=True, out='real')

## Distributed calculation

In [11]:
# Initialize JAX distributed environment
#jax.distributed.initialize()

# Let's simulate a distributed calculation
from jaxpower import create_sharding_mesh, create_sharded_random, exchange_particles

with create_sharding_mesh() as sharding_mesh:  # specify how to spatially distribute particles / mesh
    print('Sharding mesh {}.'.format(sharding_mesh))
    # To create a distributed Gaussian field
    rmesh = attrs.create(kind='real', fill=create_sharded_random(random.normal, random.key(42), shape=attrs.meshsize))
    # FFT (3D-only!)
    cmesh = rmesh.r2c()
    # kernels are automatically distributed
    cmesh.coords()  # output coords (k here) are sharded, so they can be readily used
    # Now the hard part is painting / reading particles
    # Let's assume one can create particles

    # Sampler of uniformly-distributed positions within the box; one coordinate per box dimension is appended to `shape`
    def sample(key, shape):
        return attrs.boxsize * random.uniform(key, shape + (len(attrs.boxsize),), dtype=attrs.dtype) - attrs.boxsize / 2. + attrs.boxcenter
    # This will create a sample of randomly generated particles, as a sharded array
    # NOTE(review): same seed (42) as the Gaussian field above; use a different key if independent draws are needed
    positions = create_sharded_random(sample, random.key(42), shape=1000, out_specs=0)
    
    # These particles must be redistributed ("exchanged"), such that each portion of the particles spatially corresponds to the local portion of the mesh
    positions, exchange = exchange_particles(attrs, positions=positions, return_inverse=False)
    # exchange can be used to exchange additional arrays, such as weights: weights = exchange(weights)

    # Now we can paint the particles
    particles = ParticleField(positions, attrs=attrs)
    # Note: a shortcut to exchange_particles(...) and ParticleField(...) is
    # particles = ParticleField(positions, attrs=attrs, exchange=True)
    rmesh = particles.paint(resampler='tsc', interlacing=2, compensate=True, out='real')
    # rmesh is sharded
    # Same to read
    values = rmesh.read(positions, resampler='tsc', compensate=True)

    # For paint and read above, the halo size (the regions of the mesh that are exchanged) was just the size of the painting / reading kernel
    # This is because particles were exchanged to exactly match the local portion of the mesh
    # In practice however, for differentiable LPT or PM schemes, you probably want to fix once and for all the halo size to the maximum distance
    # travelled by the particles, typically of the order of a few Mpc/h in cosmology. So one would do:
    halo_size = int(jnp.ceil(20 / attrs.cellsize[0]))  # number of cells corresponding to 20 Mpc/h
    rmesh = particles.paint(resampler='tsc', interlacing=2, compensate=True, out='real', halo_size=halo_size)
    values = rmesh.read(positions, resampler='tsc', compensate=True, halo_size=halo_size)

# Close JAX distributed environment
#jax.distributed.shutdown()

Sharding mesh Mesh('x': 1, 'y': 1).
