[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/DifferentiableUniverseInitiative/jaxDecomp/blob/main/examples/tpu_fft_demo.ipynb)

# jaxDecomp: Distributed 3D FFT on TPUs

This notebook demonstrates `pfft3d` and `pifft3d` (distributed 3D FFTs) running on a **TPU runtime** in Google Colab.

Unlike the CPU-based `demo_features.ipynb`, which simulates multiple devices, this notebook uses **real TPU cores** for truly distributed computation.

> **Setup:** You must select a TPU runtime before running this notebook.  
> Go to **Runtime > Change runtime type > TPU**.

In [None]:
!pip install -q jaxdecomp

import jax

In [None]:
import jax.numpy as jnp
from jax.sharding import AxisType, NamedSharding
from jax.sharding import PartitionSpec as P

import jaxdecomp as jdp

devices = jax.devices()
print(f'Backend:      {jax.default_backend()}')
print(f'Device count: {jax.device_count()}')
print(f'Devices:      {devices}')

assert jax.device_count() > 1, (
    'Only 1 device found. Make sure you selected a TPU runtime: '
    'Runtime > Change runtime type > TPU'
)

In [None]:
# Compute pdims from the device count.
# For 8 TPU cores this yields a (4, 2) pencil decomposition.
import math

n_devices = jax.device_count()

# Find a balanced 2D factorization: start at floor(sqrt(n)) and step down
# to the nearest divisor, so that pz >= py and pz * py == n_devices.
py = math.isqrt(n_devices)
while n_devices % py != 0:
    py -= 1
pz = n_devices // py
pdims = (pz, py)

mesh = jax.make_mesh(
    pdims,
    ('z', 'y'),
    axis_types=(AxisType.Auto, AxisType.Auto),
)

print(f'Device count: {n_devices}')
print(f'Mesh pdims:   {pdims}')
print(f'Mesh shape:   {mesh.shape}')
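
The factorization used above can be checked without any TPU. The following is a minimal sketch that wraps the same loop in a helper (the name `balanced_pdims` is hypothetical, not part of jaxDecomp) and prints the result for a few device counts:

```python
import math


def balanced_pdims(n_devices):
    """Return (pz, py) with pz * py == n_devices and pz >= py, as square as possible."""
    if n_devices < 1:
        raise ValueError('n_devices must be positive')
    # Largest divisor of n_devices that does not exceed its square root
    py = math.isqrt(n_devices)
    while n_devices % py != 0:
        py -= 1
    return n_devices // py, py


for n in (1, 4, 8, 16, 32, 7):
    print(n, balanced_pdims(n))
```

Note that a prime device count such as 7 falls back to `(7, 1)`, which is a slab decomposition rather than a pencil decomposition.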

## Distributed 3D FFT

`pfft3d` computes a distributed 3D FFT. It performs local FFTs on each TPU core and uses `all_to_all` transposes to redistribute data between axes.

**Important:** The output shape is **transposed** relative to the input: `(X, Y, Z) -> (Y, Z, X)`. The data is redistributed across devices between the per-axis FFTs, and the result is returned in that final layout rather than being transposed back.

`pifft3d` computes the inverse, restoring the original shape and layout.
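
To make the transposed layout concrete, here is a single-device NumPy sketch of the general idea: one 1D FFT per axis, with an axis rotation between them standing in for the distributed transposes. It is an illustration only. The function name `fft3d_transposed` is hypothetical, and the order of operations inside jaxDecomp is not shown in this notebook and may differ.

```python
import numpy as np


def fft3d_transposed(x):
    """3D FFT built from 1D FFTs along the last axis; returns (Y, Z, X) layout."""
    a = np.fft.fft(x, axis=-1)   # (X, Y, Z): transform Z
    a = a.transpose(2, 0, 1)     # rotate to (Z, X, Y)
    a = np.fft.fft(a, axis=-1)   # transform Y
    a = a.transpose(2, 0, 1)     # rotate to (Y, Z, X)
    a = np.fft.fft(a, axis=-1)   # transform X
    return a


rng = np.random.default_rng(0)
x_small = rng.standard_normal((4, 6, 8))  # deliberately non-cubic
out = fft3d_transposed(x_small)

# Same values as a regular 3D FFT, with the axes reordered to (Y, Z, X)
reference = np.fft.fftn(x_small).transpose(1, 2, 0)
print(out.shape, np.allclose(out, reference))
```

A practical consequence is that anything indexed by wavenumber (for example a `k`-space filter) has to be built in the same `(Y, Z, X)` axis order before it is multiplied with the FFT output.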

In [None]:
global_shape = (512, 512, 512)

sharding = NamedSharding(mesh, P('z', 'y'))
key = jax.random.PRNGKey(42)
x = jax.device_put(jax.random.normal(key, global_shape), sharding)

# Forward FFT
k = jdp.pfft3d(x)

print(f'Input shape:  {x.shape}')
print(f'Output shape: {k.shape}  (transposed: X,Y,Z -> Y,Z,X)')
print(f'Output dtype: {k.dtype}')

print('\nInput sharding:')
jax.debug.visualize_array_sharding(x[..., 0])

print('\nFFT output sharding:')
jax.debug.visualize_array_sharding(k[..., 0])
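
As a rough guide to what each core holds, the input sharding `P('z', 'y')` splits array axis 0 over mesh axis `'z'` and axis 1 over `'y'`, leaving axis 2 whole. The sketch below computes the per-device block shape and memory footprint. It assumes the `(4, 2)` mesh of the 8-core case, a `float32` input (the JAX default unless 64-bit mode is enabled), and a `complex64` output. The helper `local_shard_shape` is hypothetical and, as a simplification, rejects sizes that do not divide evenly.

```python
def local_shard_shape(global_shape, pdims):
    """Per-device block shape for an array sharded as P('z', 'y') on a (pz, py) mesh."""
    nx, ny, nz = global_shape
    pz, py = pdims
    if nx % pz or ny % py:
        raise ValueError('sharded axes must divide evenly across the mesh')
    # Axis 0 is split over mesh axis 'z', axis 1 over 'y'; axis 2 stays whole.
    return (nx // pz, ny // py, nz)


shape = local_shard_shape((512, 512, 512), (4, 2))
n_elems = shape[0] * shape[1] * shape[2]
mib_f32 = n_elems * 4 / 2**20  # float32 input: 4 bytes per element
mib_c64 = n_elems * 8 / 2**20  # complex64 output (assumed): 8 bytes per element
print(f'Local block: {shape}, {mib_f32:.0f} MiB in, {mib_c64:.0f} MiB out')
```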

In [None]:
# Inverse FFT: recover the original array
x_rec = jdp.pifft3d(k)

print(f'Reconstructed shape: {x_rec.shape}  (matches original: {x_rec.shape == x.shape})')

# Verify round-trip correctness
is_close = jnp.allclose(x, x_rec.real, atol=1e-5)
print(f'Round-trip FFT -> IFFT allclose: {is_close}')

## Timing

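JAX uses JIT compilation, so the **first call** includes compilation overhead. Subsequent calls execute the cached compiled program and are much faster. Because JAX dispatches work asynchronously, each timed call ends with `block_until_ready()` so that the measurement covers the actual device execution.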

In [None]:
import time

pfft3d_jit = jax.jit(jdp.pfft3d)
pifft3d_jit = jax.jit(jdp.pifft3d)

# Warm up (includes compilation)
k_warmup = pfft3d_jit(x).block_until_ready()
_ = pifft3d_jit(k_warmup).block_until_ready()

# Benchmark forward FFT
n_iters = 10
start = time.perf_counter()
for _ in range(n_iters):
    k_out = pfft3d_jit(x).block_until_ready()
elapsed_fwd = (time.perf_counter() - start) / n_iters

# Benchmark inverse FFT
start = time.perf_counter()
for _ in range(n_iters):
    x_out = pifft3d_jit(k_out).block_until_ready()
elapsed_inv = (time.perf_counter() - start) / n_iters

print(f'Array shape: {global_shape}')
print(f'Mesh pdims:  {pdims}')
print(f'pfft3d:  {elapsed_fwd*1000:.1f} ms / call  ({n_iters} iterations)')
print(f'pifft3d: {elapsed_inv*1000:.1f} ms / call  ({n_iters} iterations)')
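
The warm-up-then-measure pattern above can be wrapped in a small reusable helper. This is a minimal sketch in plain Python. `time_calls` and its `sync` parameter are hypothetical names, and a stand-in workload is used so the block runs without a TPU.

```python
import time


def time_calls(fn, *args, n_warmup=1, n_iters=10, sync=lambda out: out):
    """Return (mean, best) seconds per call of fn(*args), measured after warm-up."""
    for _ in range(n_warmup):
        sync(fn(*args))  # warm-up absorbs one-time costs such as compilation
    samples = []
    for _ in range(n_iters):
        start = time.perf_counter()
        sync(fn(*args))  # sync must wait for the result before the clock stops
        samples.append(time.perf_counter() - start)
    return sum(samples) / len(samples), min(samples)


# Stand-in workload that records how often it was called
calls = []


def fake_fft(n):
    calls.append(n)
    return sum(i * i for i in range(n))


mean_s, best_s = time_calls(fake_fft, 10_000, n_warmup=2, n_iters=5)
print(f'{mean_s * 1e3:.3f} ms mean, {best_s * 1e3:.3f} ms best')
```

In this notebook the call would look like `time_calls(pfft3d_jit, x, sync=jax.block_until_ready)`. Reporting the best time alongside the mean helps separate steady-state performance from occasional slow iterations.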