
Bad XLA codegen for exact TopK #886

@x66ccff


Hi, I'm not sure if I'm using Reactant correctly. In my code, the current performance bottleneck seems to be selecting a small number of positions from a huge tensor and bringing them back to the CPU. Specifically:

val_R, idx_R = model.top_k_compiled(-mean_errors_R, model.PSRN_topk)

@info "idx_R:"
@time @info idx_R

I find that printing idx_R here is very slow. I understand this slowness may be due to the synchronization needed to move data from GPU to CPU, but I can't understand why the process is so much slower than in PyTorch. 🤔
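
To check whether the cost sits in the top-k computation itself or in moving the k indices back to the host, the two steps could be timed separately. This is only a minimal sketch on a smaller array; it assumes that collect(...) on the compiled result forces the device-to-host copy, which may not be the intended way to do this in Reactant:

using Reactant

n = 10_000_000  # smaller than the real case, just for illustration
k = 100
x = rand(Float32, n)
xr = Reactant.to_rarray(x)

topk_compiled = @compile partialsortperm(xr, 1:k)

@time res = topk_compiled(xr, 1:k)  # run the compiled exact top-k on the device
@time idx = collect(res)            # assumed to force the device->host copy of the k indices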

Julia Implementation with Reactant (~1.8 seconds)

using Reactant
n = 1_000_000_000  # 1 billion elements
k = 100
x = rand(Float32, n)
xr = Reactant.to_rarray(x)  # move the data onto the device as a Reactant array
partialsortperm_compiled = @compile partialsortperm(xr, 1:k)  # compile an exact top-k (indices of the k smallest values)

GC.gc()

# First benchmark: compute but don't access the result data
@time begin
    res = partialsortperm_compiled(xr, 1:k)
    @info 1
end
[ Info: 1
  0.870469 seconds (1.86 M allocations: 93.889 MiB, 99.52% compilation time)
GC.gc()

# Second benchmark: compute and access the result data (forcing GPU->CPU transfer)
@time begin
    res = partialsortperm_compiled(xr, 1:k)
    @info res
end
[ Info: SubArray{Int64, 1, ConcretePJRTArray{Int64, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, Tuple{UnitRange{Int64}}, true}([955852229, 932404690, 906993663, 899721285, 894668984, 877960574, 871792407, 796790415, 792612165, ..337289477, 332445067, 267677250, 222104083, 214054217, 194744078, 193590981, 193251762, 185722187, 114722941])
 16.426357 seconds (93.05 M allocations: 4.559 GiB, 3.10% gc time, 99.32% compilation time)

When I run this a second time, after compilation has already happened, the timing improves to around 1.8 seconds, but it is still significantly slower than the PyTorch equivalent.
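
A minimal sketch of how the steady-state time could be measured without the @info printing overhead, continuing from the snippet above (collect on the result is my assumption for forcing the device-to-host copy):

# Warm-up: compilation and first-call overhead happen here.
res = partialsortperm_compiled(xr, 1:k)
collect(res)

# Steady state: compiled top-k plus host access, no logging.
@time begin
    res = partialsortperm_compiled(xr, 1:k)
    idx = collect(res)  # assumed to force the device->host copy of the k indices
end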

Python Implementation with PyTorch for Comparison (0.001586 seconds)

For comparison, I ran the equivalent operation in Python with PyTorch, and it's significantly faster:

import torch
import time

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
x = torch.rand(1_000_000_000, dtype=torch.float32, device=device)  # 1 billion elements

start_time = time.time()
indices = torch.topk(-x, 100)[1]
print(f"Time: {time.time() - start_time:.6f}s, {indices.cpu().tolist()}")

Time: 0.001586s

As you can see, the Python/PyTorch implementation is about 1000× faster for the same operation, even when explicitly calling .cpu().tolist() to transfer data back to CPU.

Is this a known issue with Reactant's GPU-to-CPU data transfer? Even after accounting for Julia's compilation time, the actual data transfer seems much slower than PyTorch's. Are there any recommended workarounds or optimizations I could use to address this bottleneck?
