
simple pallas kernel hangs when input size exceeds some threshold #21282

Open
zhixuan-lin opened this issue May 17, 2024 · 2 comments
Labels
bug Something isn't working pallas Issues pertaining to Pallas (GPU or TPU)

Comments


zhixuan-lin commented May 17, 2024

Description

The following simple Pallas kernel, which copies an array, hangs indefinitely:

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def copy_kernel(
    src,
    dst
):
    def body(i, carry):
        dst[i] = src[i]
        return carry

    _ = jax.lax.fori_loop(
        lower=0,
        upper=src.shape[0],
        body_fun=body,
        init_val=None
    )

@jax.jit
@jax.vmap
def copy_func(src):

    func = pl.pallas_call(
        f=copy_kernel,
        out_shape=jax.ShapeDtypeStruct(src.shape, src.dtype)
    )

    dst = func(src)
    return dst


if __name__ == '__main__':
    batch_size = 2 ** 16
    seq_length = 2 ** 16
    dtype = jnp.float32
    # dtype = jnp.bfloat16
    src = jnp.zeros((batch_size, seq_length), dtype=dtype)
    print(f'Array elements: {src.size}')
    print(f'Array size: {src.nbytes / 1e9:.4f}GB')
    dst = copy_func(src)
    dst.block_until_ready()

Program output:

2024-05-17 10:18:51.191896: W external/xla/xla/service/gpu/nvptx_compiler.cc:760] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.4.131). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
Array elements: 4294967296
Array size: 17.1799GB

I remember reading somewhere that the above warning can be ignored, so I think it is unlikely to be related to the issue I'm seeing.

It looks like as long as batch_size * seq_length <= 2 ** 31, the program does not get stuck. For example, if I change either batch_size or seq_length from 2 ** 16 to 2 ** 15, it works fine. However, changing dtype from float32 to bfloat16 does not fix the problem. Also, I'm using an A100 80GB, and with batch_size = seq_length = 2 ** 16 and dtype=float32 the array only takes roughly 17GB, so it probably has nothing to do with memory.
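For illustration of why a 2 ** 31 threshold is suggestive: if any element count or index in the pipeline is tracked as a signed 32-bit integer, values of 2 ** 31 or more wrap around. A minimal sketch of that wraparound (this is only an illustration of the hypothesis, not a confirmed location of the bug):

```python
import numpy as np

# Total element count in the failing configuration.
n_elements = np.int64(2 ** 16) * np.int64(2 ** 16)  # 2 ** 32

# If this count (or an index derived from it) is narrowed to int32,
# it wraps modulo 2 ** 32.
wrapped = n_elements.astype(np.int32)          # 2 ** 32 wraps to 0
just_over = np.int64(2 ** 31).astype(np.int32)  # 2 ** 31 wraps to -2 ** 31

print(int(wrapped), int(just_over))
```

A loop whose upper bound wraps to zero or negative (or whose counter wraps before reaching the bound) would never terminate in the expected way, which matches the observed hang.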

Also, when it hangs, both GPU and CPU utilization are zero.

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.28
jaxlib: 0.4.28
numpy:  1.26.4
python: 3.10.4 (main, Mar 31 2022, 08:41:55) [GCC 7.5.0]
jax.devices (1 total, 1 local): [cuda(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='cn-g020.server.mila.quebec', release='5.15.0-101-generic', version='#111-Ubuntu SMP Tue Mar 5 20:16:58 UTC 2024', machine='x86_64')


$ nvidia-smi
Fri May 17 10:13:08 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.161.08             Driver Version: 535.161.08   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA A100-SXM4-80GB          On  | 00000000:41:00.0 Off |                    0 |
| N/A   25C    P0              72W / 500W |    424MiB / 81920MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A   3144259      C   python                                      416MiB |
+---------------------------------------------------------------------------------------+

@zhixuan-lin zhixuan-lin added the bug Something isn't working label May 17, 2024
zhixuan-lin (Author) commented:

I also tested jax==0.4.25 with CUDA 11, with which I don't see the ptxas version warning (but the kernel still hangs indefinitely), so it likely has nothing to do with that warning.

@superbobry superbobry added the pallas Issues pertaining to Pallas (GPU or TPU) label May 17, 2024
superbobry (Member) commented:

It looks like something overflows and the loop iterates forever, but I'm not sure where the overflow actually happens.
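One way to sidestep a very large in-kernel loop bound, as a hedged sketch (not from the thread), is to split the copy across a Pallas grid so each program instance only touches one row. The sketch below uses `interpret=True` and a tiny array so it runs on CPU; whether this avoids the hang on GPU at the original sizes is an assumption, not something verified here:

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def copy_row_kernel(src_ref, dst_ref):
    # One grid program per row: no per-program counter ever approaches
    # the total element count, only the row length.
    i = pl.program_id(axis=0)
    dst_ref[i, :] = src_ref[i, :]

@jax.jit
def copy_func(src):
    return pl.pallas_call(
        copy_row_kernel,
        out_shape=jax.ShapeDtypeStruct(src.shape, src.dtype),
        grid=(src.shape[0],),
        interpret=True,  # CPU interpreter so the sketch runs anywhere
    )(src)

src = jnp.arange(8, dtype=jnp.float32).reshape(2, 4)
dst = copy_func(src)
```

Without `in_specs`/`out_specs`, each program sees the full refs, so the kernel indexes rows explicitly via `pl.program_id`.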
