
Proposal for simplified demo #24

Merged · 11 commits into joss-paper · Jul 9, 2024
Conversation

@EiffL (Member) commented Jul 8, 2024

I'd like to propose a simplification of the demo script. I removed a bunch of stuff that wasn't strictly necessary and tried to reorganize it around logical functions.
So for instance, this:

    kfield = jaxdecomp.fft.pfft3d(z.astype(jnp.complex64))

    ky, kz, kx = kvec
    kk = jnp.sqrt((kx / box_size[0] * mesh_shape[0])**2 +
                  (ky / box_size[1] * mesh_shape[1])**2 +
                  (kz / box_size[1] * mesh_shape[1])**2)

    delta_k = interpolate(kfield, kk)

    # Inverse Fourier transform to generate the initial conditions
    initial_conditions = jaxdecomp.fft.pifft3d(delta_k).real

    ###  Compute LPT displacement
    cosmo = jc.Planck15()
    a = jnp.atleast_1d(a)

    kernel_lap = jnp.where(kk == 0, 1., 1. / -(kx**2 + ky**2 + kz**2))

    pot_k = delta_k * kernel_lap
    # Forces have to be a Z pencil because they are going to be IFFT back to X pencil
    forces_k = -jnp.stack([
        pot_k * 1j / 6.0 *
        (8 * jnp.sin(kx) - jnp.sin(2 * kx)), pot_k * 1j / 6.0 *
        (8 * jnp.sin(ky) - jnp.sin(2 * ky)), pot_k * 1j / 6.0 *
        (8 * jnp.sin(kz) - jnp.sin(2 * kz))
    ],
                          axis=-1)

    init_force = jnp.stack(
        [jaxdecomp.fft.pifft3d(forces_k[..., i]).real for i in range(3)],
        axis=-1)

    dx = growth_factor(cosmo, a) * init_force

    p = a**2 * growth_rate(cosmo, a) * jnp.sqrt(jc.background.Esqr(cosmo,
                                                                   a)) * dx
    f = a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a)) * dGfa(cosmo,
                                                             a) * init_force

becomes:

def gaussian_field_and_forces(key, nc, box_size, power_spectrum, sharding):
  mesh_shape = (nc,) * 3
  local_mesh_shape = _global_to_local_size(mesh_shape, sharding)

  # Create a distributed field drawn from a Gaussian distribution in real space
  delta = jax.make_array_from_single_device_arrays(
      shape=mesh_shape,
      sharding=sharding,
      arrays=[jax.random.normal(key, local_mesh_shape, dtype='float32')])

  # Compute the Fourier transform of the field
  delta_k = jaxdecomp.fft.pfft3d(delta.astype(jnp.complex64))

  # Compute the Fourier wavenumbers of the field
  kx, ky, kz = fftk(nc, sharding)
  kk = jnp.sqrt(kx**2 + ky**2 + kz**2) * (nc / box_size)**3

  # Apply power spectrum to Fourier modes
  delta_k *= (power_spectrum(kk) * (nc / box_size)**3)**0.5

  # Compute inverse Fourier transform to recover the initial conditions in real space
  delta = jaxdecomp.fft.pifft3d(delta_k).real

  # Compute gravitational forces associated with this field
  grav_kernel = gravity_kernel([kx, ky, kz])
  forces_k = [g * delta_k for g in grav_kernel]

  # Retrieve the forces in real space by inverse Fourier transforming
  forces = jnp.stack([jaxdecomp.fft.pifft3d(f).real for f in forces_k], axis=-1)
  
  return delta, forces

and then the main simulation function becomes:

def simulation_fn(key, nc, box_size, sharding, halo_size, a=1.0):
  # Build a default cosmology
  cosmology = jc.Planck15()

  # Create a small function to generate the linear matter power spectrum at arbitrary k
  k = jnp.logspace(-4, 1, 128)
  pk = jc.power.linear_matter_power(cosmology, k)
  pk_fn = jax.jit(lambda x: jc.scipy.interpolate.interp(x.reshape([-1]), k, pk).
                  reshape(x.shape))

  # Generate a Gaussian field and gravitational forces from a power spectrum
  initial_conditions, initial_forces = gaussian_field_and_forces(
      key=key,
      nc=nc,
      box_size=box_size,
      power_spectrum=pk_fn,
      sharding=sharding)

  # Compute the LPT displacement that particles initially placed on a regular grid
  # would experience at scale factor a, by simple Zel'dovich approximation
  initial_displacement = jc.background.growth_factor(
      cosmology, jnp.atleast_1d(a)) * initial_forces

  # Paints the displaced particles on a mesh to obtain the density field
  final_field = cic_paint(initial_displacement, sharding, halo_size)

  return initial_conditions, final_field

I'd be happy to streamline it even further, to make it as simple as possible, in particular around the cic_paint function, which is not easy to understand. Also, this recoded example doesn't actually give a good result, and I'm not sure why.

But take a look and let me know what you think.

@EiffL requested a review from @ASKabalan · July 8, 2024 06:22
@ASKabalan (Collaborator) commented Jul 8, 2024

I fixed the issue in your example.
But there is a bigger issue: your function cannot be jitted.

1. You used NamedSharding as an argument; NamedSharding cannot be traced.
2. You construct shard_mapped functions using tracers (not allowed).
3. You called make_array_from_callback inside what is supposed to be a jitted function (not allowed: it closes over global non-addressable arrays, see google/jax#22218).

I think in the latest version of JAX, the JAX team allows doing non-jitted operations on non-addressable arrays inside a mesh context (basically jax.config.spmd_mode('allow_all') is activated by default).
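For reference, a minimal sketch of flipping that switch explicitly on older JAX versions, assuming the jax_spmd_mode config flag:

import jax

# Allow eager (non-jitted) operations on non-addressable, distributed arrays
jax.config.update('jax_spmd_mode', 'allow_all')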

Since this is just an example, I think it is ok.
Is that ok for you, @EiffL?

For JaxPM, I solved these issues.

@ASKabalan (Collaborator):

Also, in my opinion your version of cic_paint is more straightforward than the pmwd one.
Since we don't care about optimization and memory, maybe it makes more sense to use your version.
But I don't mind using the pmwd version.

@EiffL (Member, Author) commented Jul 8, 2024

Cool, thanks for the fixes!

Yeah... I know the limitations of this approach, but I'd like to keep things in terms of functions.

What would you think of a mechanism like what we do in jaxpm, where we have a non-jitted function that returns a function that can be jitted? Like:

sim_fn = make_simulation_fn(mesh_shape, ..., sharding)

with mesh:
  delta = jax.jit(sim_fn)(cosmology)

@ASKabalan (Collaborator):

I can't think of a place where this happens in JaxPM.
All functions that return functions (make_ode for example) are jittable.

All non-jittable functions return non-addressable arrays that have to be passed as arguments, so they are lowered to ir.RankedTensorType and not ir.Constant.
I think that making a non-jittable function that returns a jittable function like that is ok.

As long as non-addressable arrays pass through the arguments (see the sketch below):
- Closing over them is not allowed.
- Creating them within a jitted function is not allowed.
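A minimal sketch of these rules (make_distributed_field is a hypothetical helper standing in for whatever creates the sharded array):

import jax

# OK: the non-addressable (distributed) array enters through the function
# arguments, so it is lowered to a tensor input (ir.RankedTensorType).
@jax.jit
def step(delta):
  return delta * 2.0

# delta = make_distributed_field(...)  # hypothetical: created OUTSIDE jit
# out = step(delta)                    # allowed

# NOT OK: closing over the distributed array bakes it in as a constant,
# and creating one inside a jitted function is likewise disallowed.
# @jax.jit
# def bad_step():
#   return global_delta * 2.0          # closes over a non-addressable array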

@ASKabalan (Collaborator) commented Jul 8, 2024

My idea was more like this:

def generate_input(mesh_shape):
    kvec = fftk(mesh_shape)
    initial_cond = generate_ic(mesh_shape)

    return kvec, initial_cond

@jax.jit
def simulation(cosmo, kvec, initial_cond):
    solver = FastPM()

    state = solver.init_state(kvec, initial_cond, cosmo)

    state = solver.lpt(state)

    return solver.nbody(state)


# runs on 1 GPU
kvec, initial_cond = generate_input(mesh_shape)
final_state = simulation(cosmo, kvec, initial_cond)

# runs on multiple GPUs
with SPMDConfig(sharding):
    kvec, initial_cond = generate_input(mesh_shape)
    final_state = simulation(cosmo, kvec, initial_cond)

The fixed input (the part you won't differentiate) is generated once, and simulation can then be called in a NUTS or HMC loop, no problem.

@EiffL (Member, Author) commented Jul 8, 2024

Hmm yeah, but ideally we don't want to expose, for instance, kvec. It's an internal variable only useful for computing Fourier filtering operations; the user should never have to see it.

initial_cond is a slightly different situation, but still, ideally it could be generated from within the jitted function.

Ok, let me try something. I'll make an update and push it here.

@ASKabalan (Collaborator):

If we cannot expose kvec, I will have to rethink things.

The problem is that it is interpolated with a distributed array.
I can make a hack where I create it in the __enter__ of the mesh context and send it to the functions using partial.

I know that this is a hack, but I can't see a way around it.

@aboucaud commented Jul 8, 2024

I'll let you explore the sharding details while I try to deal with my vacation emails, but FWIW I find the proposed demo code very clean and easy to follow. Good job!

@EiffL (Member, Author) commented Jul 8, 2024

Ok, here is my updated proposal :-)

I managed to keep the same structure and get rid of all references to mesh and sharding within the code by accessing the mesh information from the context. I think this is more or less the hack you mentioned, but I think it's fair. Here is the magic function:

# Assumed imports (not shown in the original snippet); Specs and AxisName
# are type aliases from JAX's shard_map internals:
from typing import Callable
from jax.experimental.shard_map import shard_map
from jax._src import mesh as mesh_lib

def shmap(f: Callable,
          in_specs: Specs,
          out_specs: Specs,
          check_rep: bool = True,
          auto: frozenset[AxisName] = frozenset()):
  """Helper function to create a shard_map function that extracts the mesh
  from the context."""
  # Extract the mesh from the context
  mesh = mesh_lib.thread_resources.env.physical_mesh
  return shard_map(f, mesh, in_specs, out_specs, check_rep, auto)

It uses the context to access the mesh. In a more evolved version of this, we could even enable or disable shard_map based on whether the user is running the code within a mesh context or not.
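A minimal sketch of that idea, reusing the imports above and assuming Mesh.empty reports whether a mesh context is active:

def maybe_shmap(f, in_specs, out_specs):
  # Apply shard_map only when running inside a `with mesh:` context
  mesh = mesh_lib.thread_resources.env.physical_mesh
  if mesh.empty:
    # No mesh context: return f unchanged for single-device execution
    return f
  return shard_map(f, mesh, in_specs, out_specs)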

With this trick, I can jit-compile the entire simulation function, which becomes:

@partial(jax.jit, static_argnames=('nc', 'box_size', 'halo_size'))
def simulation_fn(key, nc, box_size, halo_size, a=1.0):

  # Build a default cosmology
  cosmology = jc.Planck15()

  # Create a small function to generate the linear matter power spectrum
  k = jnp.logspace(-4, 1, 128)
  pk = jc.power.linear_matter_power(cosmology, k)
  pk_fn = jax.jit(lambda x: jc.scipy.interpolate.interp(x.reshape([-1]), k, pk).
                  reshape(x.shape))

  # Generate a Gaussian field and gravitational forces from a power spectrum
  initial_conditions, initial_forces = gaussian_field_and_forces(
      key=key, nc=nc, box_size=box_size, power_spectrum=pk_fn)

  # Compute the LPT displacement 
  initial_displacement = jc.background.growth_factor(cosmology, jnp.atleast_1d(a)) * initial_forces

  # Paints the displaced particles on a mesh to obtain the density field
  final_field = cic_paint(initial_displacement, halo_size)

  return initial_conditions, final_field

It can be called within the mesh context like so:

  # Create computing mesh and sharding information
  devices = mesh_utils.create_device_mesh(pdims)
  mesh = Mesh(devices, axis_names=('y', 'z'))

  with mesh:
    # Run the simulation on the compute mesh
    initial_conds, final_field = simulation_fn(
        key=key, nc=args.nc, box_size=args.box_size, halo_size=args.halo_size)

And the rest of the API has the following signatures, which mention neither mesh nor sharding:

cic_paint(displacement, halo_size)
gaussian_field_and_forces(key, nc, box_size, power_spectrum)
gravity_kernel(kvec)
fftk(nc: int)

@EiffL (Member, Author) commented Jul 8, 2024

So I think it's pretty cool: we can make the user-level API completely transparent to the distribution information, and all the user needs to do is run their code within the with mesh: context.

@EiffL marked this pull request as ready for review · July 8, 2024 22:44
@EiffL requested a review from @ASKabalan · July 8, 2024 22:44
@ASKabalan (Collaborator):

Oh ok, I see.

I think that this is a good way to do it.
However, this forces the user to use the 'z' and 'y' axis names.
I did a different (more hacky) thing to have a JAX API between shard_map and custom partitioning.

@ASKabalan (Collaborator) commented on examples/lpt_nbody_demo.py:
OK, I had not thought of using shard_map as a way to distribute arrays.
Come to think of it, we can even send just the mesh size as input with a replicated in_spec (P(None)) and get back the right sharding.

@EiffL (Member, Author):

Nah, you need to have the input spec like this so that it correctly slices the input vector for each dimension.

@ASKabalan (Collaborator):

But I mean that we can use a sharded fftk function that takes a mesh_shape and outputs a sharded array.
This would solve so many problems for me in JaxPM.
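A sketch of what that could look like, using the shmap helper above (the 'x'/'y' axis names and the transposed [kz, kx, ky] output order follow the later discussion in this thread; treat both as assumptions, not the final API):

from functools import partial
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P

def fftk(nc: int):
  # The 1D frequency vector is tiny, so build the full copy on every process
  # and let shard_map hand each device its slice according to in_specs.
  kd = jnp.fft.fftfreq(nc) * 2 * jnp.pi

  @partial(shmap,
           in_specs=(P('x'), P('y'), P(None)),
           out_specs=(P('x'), P(None, 'y'), P(None)))
  def get_kvec(kz, kx, ky):
    return (kz.reshape([-1, 1, 1]),   # sharded along 'x'
            kx.reshape([1, -1, 1]),   # sharded along 'y'
            ky.reshape([1, 1, -1]))   # replicated

  return get_kvec(kd, kd, kd)  # transposed order [kz, kx, ky]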

@EiffL (Member, Author):

sure

@EiffL (Member, Author) commented Jul 8, 2024

Well, the user does not have to know about 'z' and 'y' beyond knowing that they need to define a mesh with these names. The user never sees or touches a shard map themselves.

@ASKabalan (Collaborator):

No, but if they name their axes ('first', 'second'), this will no longer work, no?

@ASKabalan (Collaborator):

There is a way to get the names dynamically:

https://github.com/google/jax/blob/main/jax/_src/mesh.py#L203

@EiffL (Member, Author) commented Jul 8, 2024

No, it won't work anymore, but it's not unrealistic to expect the user to create a mesh with given axis names. We can also provide a utility function that does it so that they don't need to name anything themselves.

@EiffL (Member, Author) commented Jul 8, 2024

Ah yeah, why not. We could use that in the shard map wrapper, but then we would need to use axis indices instead of names; I'm not sure that's easier.

@ASKabalan (Collaborator) commented Jul 8, 2024

Yes.
I will show you how I solved this (with a lot more boilerplate code).
For the axis indices it is not difficult; I already thought about it, then decided to do something different.
And if we have, for example, 2x2, it doesn't matter if we switch "z" and "y".

Anyway, I think this falls more into a JaxPM discussion.
I propose to merge this PR and ask for a JOSS review.
Shall we continue this on this branch?

@EiffL (Member, Author) commented Jul 8, 2024

Yeah, agreed. One small question for here: I still get the wrong result out. What is the shape of the FFT result if the input is of shape [x, y, z]?
Did you add an extra transpose so that the output is always [x, y, z], or is the FFT output transposed?

@ASKabalan (Collaborator):

No, it is transposed accordingly.
Where are you getting the wrong results?
https://github.com/DifferentiableUniverseInitiative/jaxDecomp/blob/main/jaxdecomp/_src/fft.py#L96

The transpose is (1, 2, 0) for the FFT (double shift left)
and (2, 0, 1) for the IFFT (double shift right).
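To make the convention concrete, here is a tiny shape-only sketch (nothing jaxdecomp-specific) showing that the two permutations are inverses of each other:

import jax.numpy as jnp

x = jnp.zeros((2, 4, 8))         # axes in [X, Y, Z] order
y = jnp.transpose(x, (1, 2, 0))  # FFT-side permutation
print(y.shape)                   # (4, 8, 2)
z = jnp.transpose(y, (2, 0, 1))  # IFFT-side permutation, the inverse
print(z.shape)                   # (2, 4, 8), back to the original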

@ASKabalan (Collaborator) commented Jul 8, 2024

FFT([X, Y, Z]) is [Z, X, Y] (Z pencil).

But for the frequencies you should use ky, kz, kx.

@EiffL (Member, Author) commented Jul 8, 2024

let me try this

@EiffL (Member, Author) commented Jul 9, 2024

Ok, solved, the correct order of dimensions was kz, kx, ky.

For clarity I also renamed the dimensions to 'x', 'y' so that the following works:

  devices = mesh_utils.create_device_mesh(pdims)
  mesh = Mesh(devices.T, axis_names=('x', 'y'))

  # Create a distributed field drawn from a Gaussian distribution in real space
  delta = shmap(
      partial(jax.random.normal, shape=local_mesh_shape, dtype='float32'),
      in_specs=P(None),
      out_specs=P('x', 'y'))(key) 

This way it follows what a user would naively expect the dimension names to be, i.e. x, y, z.

Should be ready for another round of reviews, @ASKabalan. I'm gonna take a look at the draft now.

@EiffL requested a review from @ASKabalan · July 9, 2024 01:27

Comment on lines 78 to 85, where the gravity kernel changed from the finite-difference form:

kx, ky, kz = kvec
kk = kx**2 + ky**2 + kz**2
laplace_kernel = jnp.where(kk == 0, 1., 1. / kk)
grav_kernel = [
    laplace_kernel * 1j / 6.0 * (8 * jnp.sin(kx) - jnp.sin(2 * kx)),
    laplace_kernel * 1j / 6.0 * (8 * jnp.sin(ky) - jnp.sin(2 * ky)),
    laplace_kernel * 1j / 6.0 * (8 * jnp.sin(kz) - jnp.sin(2 * kz))
]

to the direct form:

# Note that we return frequency arrays in the transposed order [z, x, y]
# corresponding to the transposed FFT output
grav_kernel = (laplace_kernel * 1j * kz,
               laplace_kernel * 1j * kx,
               laplace_kernel * 1j * ky)  # yapf: disable
@ASKabalan (Collaborator):

It would be simpler to write, and easier to read:

kz, kx, ky = kvec

and then use the more intuitive ordering in the gradient kernel.
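Applied to the diff above, the suggestion would read roughly like this (a sketch only, assuming the [z, x, y] frequency ordering noted in the comment above):

import jax.numpy as jnp

def gravity_kernel(kvec):
  # Frequencies arrive in the transposed order [z, x, y], matching the FFT output
  kz, kx, ky = kvec
  kk = kx**2 + ky**2 + kz**2
  laplace_kernel = jnp.where(kk == 0, 1., 1. / kk)
  # Gradient kernels, returned in the same transposed order
  return (laplace_kernel * 1j * kz,
          laplace_kernel * 1j * kx,
          laplace_kernel * 1j * ky)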

Comment on the README lines:

mpirun -n 4 python lpt_nbody_demo.py --nc 256 --box_size 256 --pdims 4x4 --halo_size 32 --output out

We also include an example of a slurm script in [submit_rusty.sbatch](submit_rusty.sbatch) that can be used to run the example on a slurm cluster with:
@ASKabalan (Collaborator):

Isn't your sbatch specific to one cluster?
It won't work for JZ (Jean Zay), for example.

@EiffL (Member, Author):

Yeah, but it gives an example. Feel free to modify it and/or add a slurm script for Jean Zay.

Comment on the code around halo_periods=(True, True, True)), where

@partial(shmap, in_specs=(P('z', 'y'),), out_specs=P('z', 'y'))

became

@partial(shmap, in_specs=(P('x', 'y'),), out_specs=P('x', 'y'))
@ASKabalan (Collaborator):

The axes were called Z and Y following cuDecomp.
The fastest axis is called X, and it is laid out like this: ZYX (X pencil).
It is just a naming convention, so not important.

@EiffL (Member, Author):

Yes, but here I'm thinking about user expectations. A naive user would want to call the dimensions "x", "y", "z", in that order.

In general, it's always better to design the code so that it aligns with what a user who has not read the documentation would expect, and to avoid surprising the user.

So even if this goes against the nomenclature in cuDecomp, I do think it's a better choice for the end user.

@ASKabalan (Collaborator):

I am gonna merge this into my branch then.

@EiffL, ok?

@EiffL (Member, Author) commented Jul 9, 2024

I changed the way I'm handling kx, ky, kz following your comments. I think it makes reasonable sense now.

I'm happy with this; we can merge once you approve.

@ASKabalan merged commit dbac275 into joss-paper · Jul 9, 2024