Open: optimize for GEMM regime #2
Comments
Hi, thanks for trying out the kernel! Yes, Marlin is indeed currently most optimized for the medium batchsize range; in fact, for batchsizes > 64, it will just run the batch-64 kernel repeatedly. In this medium <= 64 batchsize range, partitioning the problem enough to saturate the GPU is actually quite challenging, and Marlin employs some specific techniques to work around this (e.g., having multiple warps accumulate the same output tile, balanced striped partitioning, etc.). In contrast, at larger GEMM sizes these tricks are not needed and are even a bit suboptimal (hence the observed slowdown). That being said, your benchmark is also a relatively challenging case (for lower batchsizes as well); I would expect slowdowns to be a bit lower on larger weight matrices and/or on a weaker GPU than the A100. For example, on an A10 I get for your shape:
or with locked clocks:
The good news is that I think Marlin's very-large-batch performance can likely be improved noticeably by implementing a better partitioning scheme for larger problems, which, although maybe non-trivial, shouldn't require a massive change. I will think about what the easiest way to do that would be. One thing that would be useful in that regard is to know which problem shapes are most important to optimize for.
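For reference, a minimal sketch of the chunking behavior described above (running the batch-64-optimized kernel repeatedly for larger inputs); `marlin_gemm_64` is a hypothetical handle standing in for the actual kernel launch, not Marlin's real API:

```python
import torch

def chunked_marlin_matmul(x, marlin_gemm_64, chunk=64):
    # Illustration only: split the input rows into blocks of at most `chunk`
    # rows and run the batch-64-optimized kernel on each block in turn.
    outputs = []
    for start in range(0, x.shape[0], chunk):
        outputs.append(marlin_gemm_64(x[start:start + chunk]))
    return torch.cat(outputs, dim=0)
```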
Hi, so I just merged the new version. Benchmarking on the A100 on the same 8kx8k matrix of @fxmarty as in the thread above (using …) shows a major improvement over the prior version:
The results on Llama shapes are very similar, and on weaker GPUs, like the A10, it's even closer to FP16 at extremely large batchsizes (around 13% faster than what I posted above). Overall, it seems that Marlin should now perform very similarly to FP16 up to batchsizes of ~1k in many cases. While there are still a few more specific optimizations for very large batchsizes that could be done, I think from 2k onwards the most reliable solution is probably to just decompress and then run the FP16 kernel, which is so expensive at this point that the decompression overhead should be pretty much negligible. I believe some work on that is already being done here by @rib-2? Otherwise, I may look into it myself at some point.
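A rough sketch of that fallback strategy, with hypothetical `marlin_matmul` and `dequantize_to_fp16` helpers standing in for whatever the actual kernels expose; the only point here is the shape-based dispatch, and the 2k threshold is just the figure mentioned above:

```python
import torch

LARGE_BATCH_THRESHOLD = 2048  # assumption based on the ~2k figure above

def quantized_linear(x, packed_weight, scales, marlin_matmul, dequantize_to_fp16):
    # Hypothetical dispatch: below the threshold, use the fused int4 kernel;
    # above it, decompress once and let the fp16 GEMM handle the rest, since
    # the matmul cost dominates the decompression overhead at this size.
    if x.shape[0] < LARGE_BATCH_THRESHOLD:
        return marlin_matmul(x, packed_weight, scales)
    w_fp16 = dequantize_to_fp16(packed_weight, scales)
    return x @ w_fp16
```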
Thank you @efrantar! Here's an updated benchmark on 3aa5a05 & turboderp/exllamav2@0e9d9c1: Really cool! Unfortunately I can't lock the clock on the GPU I have.
Could be. One thing is also that exllamav2 preallocates a buffer for dequantization & computation through cuBLAS in fp16*fp16 for large shapes: https://github.com/turboderp/exllamav2/blob/0e9d9c10101f8faaa69647bec2a517aa4e06f715/exllamav2/exllamav2_ext/cuda/q_gemm.cu#L123-L132. But still, dequantization + a cuBLAS call should be slower? cc @turboderp, does this result look weird to you?
Yes, that's what the exllama kernel does. I'm just not sure how interesting it is in terms of memory, since it allocates a buffer for the fp16 decompressed weight. It would be fun to compare to https://github.com/NVIDIA/cutlass/releases/tag/v3.4.0 as well.
The dequantization pass is quite fast, completely memory-bound, so about the same latency as the matmul at q_len 1. And the extra global memory access is going to be offset by the subsequent cuBLAS matmul having most or all of those FP16 weights ready in the L2 cache. Overall, it's a pretty small overhead on top of cuBLAS, which I expect would be hard to even measure for m > 2000 or so. Needing the extra buffer is a sad tradeoff, but at least you can preallocate it from a pool shared between all the linear layers on the same device, so it's only about 40 MB in total for a 7B model. As for why the blue and green lines don't converge, I can only imagine the green line is Torch matmul, which is apparently just a bit slower than cuBLAS?
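A minimal sketch of the shared-buffer idea: one fp16 scratch tensor per device, sized for the largest weight matrix, reused by every linear layer. Names here are illustrative, not exllamav2's actual API:

```python
import torch

_scratch = {}  # one preallocated fp16 buffer per device, shared by all layers

def get_scratch(device, numel):
    # Grow-only pool: allocate once for the largest weight seen on this device,
    # then hand out views of it, so the total overhead stays at one buffer.
    buf = _scratch.get(device)
    if buf is None or buf.numel() < numel:
        buf = torch.empty(numel, dtype=torch.float16, device=device)
        _scratch[device] = buf
    return buf[:numel]
```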
Ohoh, after investigating a bit, PyTorch picks e.g. … for this shape. I wonder if this could be the cause of some of the NaN issues we've had. @efrantar it seems Marlin uses fp32 accumulation, as we have in marlin/marlin/marlin_cuda_kernel.cu, lines 92 to 98 in 3aa5a05.
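For anyone who wants to reproduce the observation about which GEMM kernel PyTorch ends up calling, the profiler shows the cuBLAS kernel names directly (the shapes here are just placeholders):

```python
import torch
from torch.profiler import profile, ProfilerActivity

a = torch.randn(8192, 8192, dtype=torch.float16, device="cuda")
b = torch.randn(8192, 8192, dtype=torch.float16, device="cuda")

with profile(activities=[ProfilerActivity.CUDA]) as prof:
    torch.matmul(a, b)
    torch.cuda.synchronize()

# The kernel name reveals which cuBLAS/cuBLASLt implementation was selected.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=5))
```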
Interesting; yes, Marlin accumulates in fp32. While the tensor cores should have the same speed for fp32 and fp16 accumulation, the latter allows larger tiles (since the accumulators require fewer registers), which could indeed be a bit faster for very large shapes. I am not sure if fp16 accumulation is always fully safe, hence I went for fp32 with Marlin. That being said, how exactly are you benchmarking? I have found benchmarking with unlocked clocks to be pretty messy (especially once operations become more expensive and may be throttled); e.g., results can differ significantly depending on order, number of repetitions and wait times (I would really like to know a good solution for this if there is one). In my A100 setup above at batchsize 8k, torch takes ~4.3ms, which seems to essentially match your exllamav2 and awq results, so maybe there is just something weird with the benchmarks?
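On the PyTorch side, one related knob is `torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction`, which controls whether fp16 GEMMs may use reduced-precision reductions; toggling it is a cheap way to probe whether accumulation/reduction precision explains timing or NaN differences. This is only a probe, not a statement about what any of the kernels discussed above actually do:

```python
import torch

# Default is True; setting it to False forces fp32 reductions in fp16 matmuls,
# which may be slightly slower but avoids fp16 accumulation of partial sums.
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False

a = torch.randn(8192, 8192, dtype=torch.float16, device="cuda")
b = torch.randn(8192, 8192, dtype=torch.float16, device="cuda")
out = a @ b  # the GEMM is now restricted to fp32 reductions
```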
@efrantar I use something along the lines of:

```python
import numpy as np
import torch

times = []
for _ in range(num_runs):
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start_event.record()
    res = forward_handle(inp)
    end_event.record()
    torch.cuda.synchronize()
    # elapsed_time returns the time between the two recorded events in milliseconds
    latency_ms = start_event.elapsed_time(end_event)
    times.append(latency_ms)
mean_latency = np.mean(times)
```

Unfortunately not on a clock-locked GPU. Should do that indeed. (Edit: simply trying …)
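One cheap mitigation for the order/repetition sensitivity mentioned above is to add warmup iterations before the timed loop; this is a common benchmarking practice, not something specific to either kernel, and it continues directly from the snippet above:

```python
# Warmup: run the op a few times so lazy initialization, kernel selection and
# clock ramp-up don't land inside the timed region.
for _ in range(10):
    forward_handle(inp)
torch.cuda.synchronize()
# ... then run the timed loop from the snippet above.
```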
Hi @efrantar, thanks a lot for sharing this optimized kernel!
I've given them a try on an A100, and the results are really nice in the regime with short sequences (<30-50). So far, for autoregressive decoding, Marlin would be the best-in-class kernel available, as it maintains good latency even for batch sizes larger than one.
However, for longer sequence lengths (or larger batch sizes, it does not matter which), both fp16xfp16 and exllamav2 outperform the Marlin kernel. AFAIK, exllamav2 just unpacks the weights to fp16 in this case and calls cuBLAS (see the sketch below); I'm not sure why it is faster than PyTorch's native fp16xfp16.
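To make that unpack-and-cuBLAS path concrete, here is an illustrative dequantization of 8 int4 values per int32 followed by a plain matmul; the packing layout and the symmetric-quantization assumption are mine, not any particular kernel's:

```python
import torch

def unpack_int4_and_matmul(x, packed, scales):
    # `packed`: int32 tensor of shape (out_features, in_features // 8), 8 nibbles per element.
    # `scales`: per-output-channel fp16 scales of shape (out_features, 1), symmetric quantization assumed.
    shifts = torch.arange(0, 32, 4, device=packed.device)
    nibbles = (packed.unsqueeze(-1) >> shifts) & 0xF           # values in [0, 15]
    w = (nibbles.to(torch.float16) - 8.0).reshape(packed.shape[0], -1)
    w = w * scales                                             # dequantized fp16 weight
    return x @ w.t()                                           # plain fp16 GEMM (cuBLAS)
```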
It seems optimizing int4xfp16 kernels in the GEMM regime is still an open problem. So far we have different kernels performing best for different shapes, but they all require different packing, etc., which is not very handy.
One could still argue it is worth trading off some extra latency in the prefill for lower latency in decoding, where we usually spend more time.