-
Notifications
You must be signed in to change notification settings - Fork 11
/
test_pfft.py
96 lines (81 loc) · 2.33 KB
/
test_pfft.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
# Can be executed with:
# srun -n 4 -c 32 --gpus-per-task 1 --gpu-bind=none python test_pfft.py
from functools import partial
import jax
import jax.lax as lax
import jax.numpy as jnp
import numpy as np
from jax.experimental.maps import Mesh, xmap
from jax.experimental.pjit import PartitionSpec, pjit
jax.distributed.initialize()
cube_size = 2048
@partial(xmap,
in_axes=[...],
out_axes=['x', 'y', ...],
axis_sizes={
'x': cube_size,
'y': cube_size
},
axis_resources={
'x': 'nx',
'y': 'ny',
'key_x': 'nx',
'key_y': 'ny'
})
def pnormal(key):
return jax.random.normal(key, shape=[cube_size])
@partial(xmap,
in_axes={
0: 'x',
1: 'y'
},
out_axes=['x', 'y', ...],
axis_resources={
'x': 'nx',
'y': 'ny'
})
@jax.jit
def pfft3d(mesh):
# [x, y, z]
mesh = jnp.fft.fft(mesh) # Transform on z
mesh = lax.all_to_all(mesh, 'x', 0, 0) # Now x is exposed, [z,y,x]
mesh = jnp.fft.fft(mesh) # Transform on x
mesh = lax.all_to_all(mesh, 'y', 0, 0) # Now y is exposed, [z,x,y]
mesh = jnp.fft.fft(mesh) # Transform on y
# [z, x, y]
return mesh
@partial(xmap,
in_axes={
0: 'x',
1: 'y'
},
out_axes=['x', 'y', ...],
axis_resources={
'x': 'nx',
'y': 'ny'
})
@jax.jit
def pifft3d(mesh):
# [z, x, y]
mesh = jnp.fft.ifft(mesh) # Transform on y
mesh = lax.all_to_all(mesh, 'y', 0, 0) # Now x is exposed, [z,y,x]
mesh = jnp.fft.ifft(mesh) # Transform on x
mesh = lax.all_to_all(mesh, 'x', 0, 0) # Now z is exposed, [x,y,z]
mesh = jnp.fft.ifft(mesh) # Transform on z
# [x, y, z]
return mesh
key = jax.random.PRNGKey(42)
# keys = jax.random.split(key, 4).reshape((2,2,2))
# We reshape all our devices to the mesh shape we want
devices = np.array(jax.devices()).reshape((2, 4))
with Mesh(devices, ('nx', 'ny')):
mesh = pnormal(key)
kmesh = pfft3d(mesh)
kmesh.block_until_ready()
# jax.profiler.start_trace("tensorboard")
# with Mesh(devices, ('nx', 'ny')):
# mesh = pnormal(key)
# kmesh = pfft3d(mesh)
# kmesh.block_until_ready()
# jax.profiler.stop_trace()
print('Done')