
Runtime of kernel seems dependent on input data distribution. #15

@vvvm23

Description


I've been trying to use this kernel to compute the loss/grads over large activations and vocabularies, and I noticed significant changes in the kernel's runtime depending on the distribution of the input data.

This sounds a bit insane, but I have a small reproducer that shows this behaviour:

import cut_cross_entropy as cce_lib
import torch
import time

device = torch.device('cuda:0')

def benchmark_fn(embeddings, vocab, labels):
    # clone, move to device, and set requires_grad
    # (the .to(device) calls are no-ops here since the inputs already live on the GPU)
    embeddings = torch.clone(embeddings).to(device).requires_grad_(True)
    vocab = torch.clone(vocab).to(device).requires_grad_(True)
    labels = torch.clone(labels).to(device)

    torch.cuda.synchronize()
    start_time = time.time()
    loss = cce_lib.linear_cross_entropy(embeddings, vocab, labels)
    loss.backward()
    torch.cuda.synchronize()
    print(time.time() - start_time)

# generate inputs
embeddings = torch.randn(4*8192, 4096, device=device, dtype=torch.bfloat16)
vocab = torch.randn(256_000, 4096, device=device, dtype=torch.bfloat16)
labels = torch.randint(0, 256_000, (4*8192,), device=device)

# compile call, exclude from loops
print("first call")
benchmark_fn(embeddings, vocab, labels)

print("regular inputs")
for _ in range(5):
    benchmark_fn(embeddings, vocab, labels)

print("scaled down weights result in slowdown!")
vocab = vocab * 1/8
for _ in range(5):
    benchmark_fn(embeddings, vocab, labels)
vocab = vocab * 8

print("scaled up weights do not.")
vocab = vocab * 8
for _ in range(5):
    benchmark_fn(embeddings, vocab, labels)
vocab = vocab*1/8

print("scaled down activations also cause slowdown")
embeddings = embeddings * 1/8
for _ in range(5):
    benchmark_fn(embeddings, vocab, labels)
embeddings = embeddings*8

print("but not scaled up")
embeddings = embeddings * 8
for _ in range(5):
    benchmark_fn(embeddings, vocab, labels)
embeddings = embeddings*1/8

which will output:

first call
1.8656930923461914
regular inputs
0.4381716251373291
0.4551575183868408
0.4390103816986084
0.44695091247558594
0.44449806213378906
scaled down weights result in slowdown!
1.2427277565002441
1.2380731105804443
1.2460360527038574
1.2462165355682373
1.245213270187378
scaled up weights do not.
0.4104018211364746
0.41744208335876465
0.42533063888549805
0.41300249099731445
0.4272499084472656
scaled down activations also cause slowdown
1.2431302070617676
1.2402148246765137
1.2415409088134766
1.2405052185058594
1.2423794269561768
but not scaled up
0.41109371185302734
0.41783595085144043
0.42434120178222656
0.4148707389831543
0.42697978019714355

By simply scaling down either the weights or activations by a factor of 8, we get a completely different runtime.
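
To narrow down which pass is responsible, a variant of benchmark_fn that times the forward and backward passes separately with CUDA events might help. This is only a sketch of what I plan to try next (the name benchmark_fn_split is my own), reusing the tensors defined above:

import cut_cross_entropy as cce_lib
import torch

def benchmark_fn_split(embeddings, vocab, labels):
    # same setup as benchmark_fn above
    embeddings = torch.clone(embeddings).requires_grad_(True)
    vocab = torch.clone(vocab).requires_grad_(True)
    labels = torch.clone(labels)

    fwd_start = torch.cuda.Event(enable_timing=True)
    fwd_end = torch.cuda.Event(enable_timing=True)
    bwd_end = torch.cuda.Event(enable_timing=True)

    torch.cuda.synchronize()
    fwd_start.record()
    loss = cce_lib.linear_cross_entropy(embeddings, vocab, labels)
    fwd_end.record()
    loss.backward()
    bwd_end.record()
    torch.cuda.synchronize()

    # elapsed_time reports milliseconds between two recorded events
    print(f"forward: {fwd_start.elapsed_time(fwd_end):.1f} ms, "
          f"backward: {fwd_end.elapsed_time(bwd_end):.1f} ms")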

I am kinda at a loss (hah) to understand why this could happen! I have checked, via additional debug flags, that no extra recompilations are happening, so this is a single compiled kernel exhibiting this behaviour.
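
To rule out timing noise from the wall-clock measurement, the same comparison could also be taken with triton.testing.do_bench, which handles warmup, repetition, and synchronization. A minimal sketch (the time_scaled helper and its scale argument are just for illustration, reusing the embeddings/vocab/labels tensors from the reproducer above):

import cut_cross_entropy as cce_lib
import torch
import triton.testing

def time_scaled(embeddings, vocab, labels, scale=1.0):
    e = (embeddings * scale).detach().requires_grad_(True)
    v = vocab.detach().requires_grad_(True)

    def fwd_bwd():
        loss = cce_lib.linear_cross_entropy(e, v, labels)
        loss.backward()

    # grad_to_none clears the accumulated .grad between timed runs
    ms = triton.testing.do_bench(fwd_bwd, grad_to_none=[e, v])
    print(f"scale={scale}: {ms:.2f} ms")

time_scaled(embeddings, vocab, labels, scale=1.0)
time_scaled(embeddings, vocab, labels, scale=1/8)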

Relevant parts of my env:

PyTorch version: 2.5.1+cu124
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: 18.1.8 (++20240731024944+3b5b5c1ec4a3-1~exp1~20240731145000.144)
CMake version: version 3.22.1
Libc version: glibc-2.35

Python version: 3.10.11 (main, Nov 17 2024, 19:27:51) [GCC 11.4.0] (64-bit runtime)
Is CUDA available: True
CUDA runtime version: 12.6.68
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA H100 80GB HBM3

Nvidia driver version: 535.216.01
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.9.5.1
/usr/lib/x86_64-linux-gnu/libcudnn_adv.so.9.5.1
/usr/lib/x86_64-linux-gnu/libcudnn_cnn.so.9.5.1
/usr/lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.5.1
/usr/lib/x86_64-linux-gnu/libcudnn_engines_runtime_compiled.so.9.5.1
/usr/lib/x86_64-linux-gnu/libcudnn_graph.so.9.5.1
/usr/lib/x86_64-linux-gnu/libcudnn_heuristic.so.9.5.1
/usr/lib/x86_64-linux-gnu/libcudnn_ops.so.9.5.1
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] numpy==2.2.0
[pip3] torch==2.5.1
[pip3] triton==3.1.0
[conda] Could not collect

Thank you for your assistance 🙏
