Proposal for simplified demo #24
Conversation
I fixed the issue in your example.

1 - You used NamedSharding as an argument, but NamedSharding cannot be traced. I think in the latest version of JAX, the JAX team allows doing non-jitted operations on non-addressable arrays inside a mesh context. Since this is just an example, I think it is ok. For JaxPM, I solved these issues.
Also, in my opinion, your version of
Co-authored-by: Wassim KABALAN <wastondev@gmail.com>
Cool, thanks for the fixes! Yeah... I know the limitations of this approach, but I'd like to keep things in terms of functions. What would you think of a mechanism like what we do in JaxPM, where we have a non-jitted function that returns a function that can be jitted? Like:

```python
sim_fn = make_simulation_fn(mesh_shape, ...., sharding)

with mesh:
    delta = jax.jit(sim_fn)(cosmology)
```
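A minimal sketch of that factory pattern, in plain Python (every name here, including `make_simulation_fn` and the placeholder bodies, is hypothetical and only illustrates the closure structure, not the actual JaxPM API):

```python
# Hypothetical sketch: the non-jitted builder does the setup once and
# returns a closure whose signature only contains traced arguments.
def make_simulation_fn(mesh_shape, sharding=None):
    # Non-jittable setup (e.g. fftk, initial conditions) would happen here;
    # the results are captured by the closure.
    kvec = [float(n) for n in mesh_shape]  # stand-in for fftk(mesh_shape)

    def simulation_fn(cosmology):
        # kvec is closed over, so jax.jit(simulation_fn) treats it as a
        # constant rather than a traced input.
        return sum(kvec) * cosmology  # stand-in for the real simulation

    return simulation_fn

sim_fn = make_simulation_fn((4, 4, 4))
delta = sim_fn(2.0)  # → 24.0
```

The design point is that the builder runs eagerly (so it may use non-traceable objects), while the returned function is safe to wrap in `jax.jit`.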
I can't think of a place where this happens in JaxPM. All non-jittable functions return non-addressable arrays that have to be passed as arguments, so everything is lowered correctly as long as the non-addressable arrays pass through the arguments.
My idea was more like this:

```python
def generate_input(mesh_shape):
    kvec = fftk(mesh_shape)
    initial_cond = generate_ic(mesh_shape)
    return kvec, initial_cond

@jax.jit
def simulation(cosmo, kvec, initial_cond):
    solver = FastPM()
    state = solver.init_state(kvec, initial_cond, cosmo)
    state = solver.lpt(state)
    return solver.nbody(state)

# runs on 1 GPU
kvec, initial_cond = generate_input(mesh_shape)
final_state = simulation(cosmo, kvec, initial_cond)

# runs on multiple GPUs
with SPMDConfig(sharding):
    kvec, initial_cond = generate_input(mesh_shape)
    final_state = simulation(cosmo, kvec, initial_cond)
```

The fixed input (the one that you won't differentiate) is fixed, so it will be generated once.
hum yeah, but ideally we don't want to expose, for instance,

Ok, let me try something, I'll make an update and push it here.
If we cannot expose kvec, I will have to rethink things. The problem is that it is interpolated with a distributed array. I know that this is a hack, but I can't see a way around it.
I'll let you explore the sharding details while I try to deal with my vacation emails, but FWIW I find the proposed demo code very clean and easy to follow. Good job!
Ok, here is my updated proposal :-) I managed to keep the same structure and get rid of all references to mesh and sharding within the code, by accessing the mesh information from the context. I think this is kind of the hack you mentioned, but I think it's fair. Here is the magic function:

```python
def shmap(f: Callable,
          in_specs: Specs,
          out_specs: Specs,
          check_rep: bool = True,
          auto: frozenset[AxisName] = frozenset()):
    """Helper function to create a shard_map function that extracts the mesh
    from the context."""
    # Extracts the mesh from the context
    mesh = mesh_lib.thread_resources.env.physical_mesh
    return shard_map(f, mesh, in_specs, out_specs, check_rep, auto)
```

It uses the context to access the mesh. In a more evolved version of this, we can even enable or disable shard_map based on whether the user is running the code within a mesh context or not. With this trick, I can jit-compile the entire simulation function, which becomes:

```python
@partial(jax.jit, static_argnames=('nc', 'box_size', 'halo_size'))
def simulation_fn(key, nc, box_size, halo_size, a=1.0):
    # Build a default cosmology
    cosmology = jc.Planck15()
    # Create a small function to generate the linear matter power spectrum
    k = jnp.logspace(-4, 1, 128)
    pk = jc.power.linear_matter_power(cosmology, k)
    pk_fn = jax.jit(
        lambda x: jc.scipy.interpolate.interp(x.reshape([-1]), k, pk).reshape(x.shape))
    # Generate a Gaussian field and gravitational forces from a power spectrum
    initial_conditions, initial_forces = gaussian_field_and_forces(
        key=key, nc=nc, box_size=box_size, power_spectrum=pk_fn)
    # Compute the LPT displacement
    initial_displacement = jc.background.growth_factor(
        cosmology, jnp.atleast_1d(a)) * initial_forces
    # Paint the displaced particles on a mesh to obtain the density field
    final_field = cic_paint(initial_displacement, halo_size)
    return initial_conditions, final_field
```

It can be called within the mesh context like so:

```python
# Create computing mesh and sharding information
devices = mesh_utils.create_device_mesh(pdims)
mesh = Mesh(devices, axis_names=('y', 'z'))

with mesh:
    # Run the simulation on the compute mesh
    initial_conds, final_field = simulation_fn(
        key=key, nc=args.nc, box_size=args.box_size, halo_size=args.halo_size)
```

And the rest of the API has the following signatures, which mention neither mesh nor sharding:

```python
cic_paint(displacement, halo_size)
gaussian_field_and_forces(key, nc, box_size, power_spectrum)
gravity_kernel(kvec)
fftk(nc: int)
```

So I think it's pretty cool: we can make the user-level API completely transparent to the distribution information, and all that the user needs to do is run their code within the mesh context.
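The "enable or disable shard_map" idea mentioned above could look something like the following sketch (hypothetical, not part of the PR; it relies on the same internal `mesh_lib.thread_resources` lookup as the `shmap` helper, and that internal JAX API may change between versions):

```python
from jax._src import mesh as mesh_lib  # same internal module used by shmap


def maybe_shmap(f, in_specs, out_specs):
    # Hypothetical variant of shmap: only wrap f in shard_map when the code
    # actually runs inside a `with mesh:` context.
    mesh = mesh_lib.thread_resources.env.physical_mesh
    if mesh.empty:
        return f  # no mesh context active: plain single-device execution
    # Imported lazily so the single-device path never touches shard_map
    from jax.experimental.shard_map import shard_map
    return shard_map(f, mesh, in_specs, out_specs)
```

With this, the same user code would run un-sharded outside a mesh context and sharded inside one, without any signature change.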
Oh ok I see. I think that this is a good way to do it.
```python
@partial(
    shmap,
```
OK, I had not thought of using shard_map as a way to distribute arrays.
Come to think of it, we can even send just the mesh size as input with a replicated in_spec (`P(None)`) and get back the right sharding.
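As a rough sketch of that idea (illustrative only; `local_freqs` and the sizing are made up, and the mesh is shaped to run even on a single CPU device, whereas a real run would span several GPUs): each shard computes only its local slice, and `out_specs=P('x')` glues the pieces into one global array that already carries the right sharding.

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map

ndev = jax.device_count()
mesh = Mesh(mesh_utils.create_device_mesh((ndev,)), axis_names=('x',))

nc = 8 * ndev  # global size, divisible by the number of devices


def local_freqs():
    # Each device produces only its own chunk of the frequency axis
    idx = jax.lax.axis_index('x')
    return (idx * (nc // ndev) + jnp.arange(nc // ndev)).astype(jnp.float32)


# The result is a global array of shape (nc,) sharded along the 'x' axis
freqs = shard_map(local_freqs, mesh=mesh, in_specs=(), out_specs=P('x'))()
```

This is the same trick a sharded `fftk` could use: no sharded input needs to be passed in, yet the output comes back correctly distributed.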
nah, you need to have the input spec like this so that it correctly slices the input vector along each dimension.
But I mean that we can use a sharded fftk function that takes a mesh_shape and outputs a sharded array.
This solves so many problems for me in JaxPM.
sure
Well, the user does not have to know about 'z', 'y' beyond knowing that they need to define a mesh with these names. The user never sees or touches a shard_map themselves.
No, but what if they name their axes ('first', 'second')?
There is a way to get the names dynamically: https://github.com/google/jax/blob/main/jax/_src/mesh.py#L203
No, it won't work anymore, but it's not unrealistic to expect the user to create a mesh with given axis names. We can also provide a utility function that would do it, so that they don't need to name anything themselves.
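Such a utility could be as small as the following sketch (hypothetical; `create_compute_mesh` is not part of the PR, and the `devices.T` plus `('x', 'y')` naming simply mirrors the convention settled on later in this thread):

```python
from jax.sharding import Mesh
from jax.experimental import mesh_utils


def create_compute_mesh(pdims):
    # Hypothetical helper: build the mesh with the axis names the library
    # expects, so users never have to pick names themselves.
    devices = mesh_utils.create_device_mesh(pdims)
    return Mesh(devices.T, axis_names=('x', 'y'))


mesh = create_compute_mesh((1, 1))  # e.g. (4, 4) on a 16-GPU cluster
```

The user then only writes `with create_compute_mesh(pdims):` and the axis names are guaranteed to match what the shard_map wrapper looks up.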
ah yeah, why not... we could use that in the shard_map wrapper, but then we need to use axis indices instead of names; not sure that's easier.
Yes. Anyway, I think this falls more into a JaxPM discussion.
Yeah agreed. One small question for here: I still get the wrong result out. What is the shape of the FFT result if the input is of shape [x, y, z]?
No, it is transposed accordingly. The transpose is (1 2 0) for the FFT (double shift left): FFT([X Y Z]) is [Z X Y] (Z pencil). But for frequencies you should use
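To make the axis bookkeeping concrete, here is a tiny NumPy sketch (illustrative only; the real transpose happens inside the distributed FFT). A "double left shift" of the axes [X, Y, Z] yields [Z, X, Y], which in NumPy's `transpose` convention (where the argument lists *source* axes) is the permutation (2, 0, 1):

```python
import numpy as np

a = np.zeros((2, 3, 4))         # axes ordered [X, Y, Z]
b = np.transpose(a, (2, 0, 1))  # pick source axes Z, X, Y for the output
print(b.shape)  # (4, 2, 3), i.e. [Z, X, Y] -- the Z-pencil layout
```

So an input of shape [x, y, z] comes back with shape [z, x, y], which is why the frequency arrays must be matched to that order.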
let me try this
Ok, solved: the correct order of dimensions was kz, kx, ky. For clarity I also renamed the dimensions to 'x', 'y' so that the following works:

```python
devices = mesh_utils.create_device_mesh(pdims)
mesh = Mesh(devices.T, axis_names=('x', 'y'))

# Create a distributed field drawn from a Gaussian distribution in real space
delta = shmap(
    partial(jax.random.normal, shape=local_mesh_shape, dtype='float32'),
    in_specs=P(None),
    out_specs=P('x', 'y'))(key)
```

This way it follows what a user would naively expect regarding the names of the dimensions, i.e. x, y, z. Should be ready for another round of reviews @ASKabalan, I'm gonna take a look at the draft now.
examples/lpt_nbody_demo.py
```python
kx, ky, kz = kvec
kk = kx**2 + ky**2 + kz**2
laplace_kernel = jnp.where(kk == 0, 1., 1. / kk)
grav_kernel = [
    laplace_kernel * 1j / 6.0 * (8 * jnp.sin(kx) - jnp.sin(2 * kx)),
    laplace_kernel * 1j / 6.0 * (8 * jnp.sin(ky) - jnp.sin(2 * ky)),
    laplace_kernel * 1j / 6.0 * (8 * jnp.sin(kz) - jnp.sin(2 * kz))
]
# Note that we return frequency arrays in the transposed order [z, x, y]
# corresponding to the transposed FFT output
grav_kernel = (laplace_kernel * 1j * kz,
               laplace_kernel * 1j * kx,
               laplace_kernel * 1j * ky)
```
It would be simpler and easier to read to do `kz, kx, ky = kvec` and use the more intuitive ordering in the gradient kernel.
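For illustration, a self-contained sketch of the suggested reordering, with `np` standing in for `jnp` so it runs anywhere (the function name and shapes are taken from the diff above; this is a reading of the suggestion, not the merged code):

```python
import numpy as np


def gravity_kernel(kvec):
    # kvec arrives in the transposed FFT order [z, x, y], so unpack it
    # with names that match the physical axes
    kz, kx, ky = kvec
    kk = kx**2 + ky**2 + kz**2
    # Avoid division by zero at the k = 0 mode
    laplace_kernel = np.where(kk == 0, 1., 1. / kk)
    # The components now read in the same order they were unpacked
    return (laplace_kernel * 1j * kz,
            laplace_kernel * 1j * kx,
            laplace_kernel * 1j * ky)
```

The output tuple is unchanged from the diff; only the unpacking makes each named variable line up with its axis, which is what makes the kernel easier to read.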
```
mpirun -n 4 python lpt_nbody_demo.py --nc 256 --box_size 256 --pdims 4x4 --halo_size 32 --output out
```

We also include an example of a slurm script in [submit_rusty.sbatch](submit_rusty.sbatch) that can be used to run the example on a slurm cluster with:
isn't your sbatch specific to one cluster?
It won't work for JZ, for example.
yeah, but it gives an example; feel free to modify it and/or add a slurm script for Jean Zay.
```diff
 halo_periods=(True, True, True))

-@partial(shmap, in_specs=(P('z', 'y'),), out_specs=P('z', 'y'))
+@partial(shmap, in_specs=(P('x', 'y'),), out_specs=P('x', 'y'))
```
The axes were called Z and Y according to cuDecomp.
The fastest axis is called X, and it is laid out like this: ZYX (X pencil).
It is just a naming convention, so not important.
Yes, but here I'm thinking about user expectations. A naive user would want to call the dimensions "x", "y", "z", in that order.
In general, it's always better to design the code so that it aligns with what a user who has not read the documentation would expect, and to avoid surprising them.
So even if this goes against the nomenclature in cuDecomp, I do think it's a better choice for the end-user.
I am gonna merge this into my branch then, @EiffL, ok?
I changed the way I'm handling kx, ky, kz following your comments. I think it makes reasonable sense now. I'm happy with this; we can merge after you approve.
I'd like to propose a simplification of the demo script. I removed a bunch of stuff that wasn't strictly necessary and tried to reorganize it around logical functions.
So for instance, this:
becomes:
and then the main simulation function becomes:
I'd be happy to streamline it even further, to make it as simple as possible, in particular around the cic_paint function, which is not easy to understand. Also, this recoded example doesn't actually give a good result; I'm not sure why.
But take a look and let me know what you think.