Open: optimize for GEMM regime #2

Closed · fxmarty opened this issue Jan 18, 2024 · 7 comments

@fxmarty (Contributor) commented Jan 18, 2024

Hi @efrantar, thanks a lot for sharing this optimized kernel!

I've given it a try on an A100, and the results are really nice in the short-sequence regime (<30-50). So far, for autoregressive decoding, Marlin would be the best-in-class kernel available, as it maintains good latency even for batch sizes larger than one.

[benchmark plot: Marlin vs. other kernels at short sequence lengths]

However, for longer sequence lengths (or larger batch sizes, it does not matter which), both fp16xfp16 and the exllamav2 kernel outperform the Marlin kernel. AFAIK, exllamav2 just unpacks the weights to fp16 in this case and calls cuBLAS; I'm not sure why it is faster than PyTorch's native fp16xfp16.

[benchmark plot: longer sequence lengths / larger batch sizes]

It seems optimizing int4xfp16 kernels in the GEMM regime is still an open problem. So far we have different kernels performing best for different shapes, but they all require different packing schemes, etc., which is not very handy.

One could still argue it is worth trading off higher latency in the prefill for lower latency in the decoding, where we usually spend more time.

@efrantar (Member) commented:

Hi, thanks for trying out the kernel!

Yes, Marlin is indeed currently the most optimized for the medium batchsize range; in fact, for batchsizes > 64, it will just run the batch 64 kernel repeatedly.
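
For intuition, here is a minimal sketch (hypothetical Python, not Marlin's actual dispatch code; the kernel_batch64, packed_w and scales names are placeholders) of what running the batch-64 kernel repeatedly over a larger input amounts to:

import torch

def gemm_via_batch64_kernel(kernel_batch64, a, packed_w, scales):
    # kernel_batch64 is assumed to handle at most 64 input rows per launch;
    # larger inputs are simply processed in consecutive 64-row chunks.
    outputs = []
    for start in range(0, a.shape[0], 64):
        outputs.append(kernel_batch64(a[start:start + 64], packed_w, scales))
    return torch.cat(outputs, dim=0)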

In this medium (<= 64) batchsize range, partitioning the problem enough to saturate the GPU is actually quite challenging, and Marlin employs some specific techniques to work around this (e.g., having multiple warps accumulate the same output tile, balanced striped partitioning, etc.). In contrast, at larger GEMM sizes, these tricks are not needed and are even a bit suboptimal (hence the observed slowdown).

That being said, your benchmark is also a relatively challenging case (for lower batchsizes as well); I would expect the slowdowns to be a bit smaller on larger weight matrices and/or on a weaker GPU than the A100. For example, on an A10 I get for your shape:

batch=1024: s=0.00147, TFLOP/s=093.778, GB/s=0046.505, speedup=1.15
batch=2048: s=0.00293, TFLOP/s=093.796, GB/s=0034.707, speedup=1.09
batch=4096: s=0.00660, TFLOP/s=083.306, GB/s=0025.582, speedup=1.17
batch=8192: s=0.01429, TFLOP/s=076.964, GB/s=0021.212, speedup=1.03

or with locked clocks:

batch=1024: s=0.00264, TFLOP/s=052.021, GB/s=0025.798, speedup=1.00
batch=2048: s=0.00527, TFLOP/s=052.208, GB/s=0019.318, speedup=0.87
batch=4096: s=0.01055, TFLOP/s=052.105, GB/s=0016.000, speedup=0.84
batch=8192: s=0.02106, TFLOP/s=052.206, GB/s=0014.389, speedup=0.83

The good news is that I think Marlin's very large batch performance can likely be improved noticeably by implementing a better partitioning scheme for larger problems, which, although maybe non-trivial, shouldn't require a massive change. I will think about what the easiest way would be to do that.

One thing that would be useful in that regard is to know what the most important problem shapes are to optimize for.

@efrantar (Member) commented:

Hi,

I just merged v0.1.1, which essentially allows the kernel to solve multiple batchsize=64 problems in parallel with a single launch, thus allowing better partitioning and consequently lower reduction overheads. This should mostly resolve the primary inefficiency Marlin previously faced with larger batchsizes (in particular, if the weight matrix isn't huge).

Benchmarking on the A100 on the same 8k×8k matrix of @fxmarty as in the thread above (using bench.py with warmup=100 and iters=1000 to stabilize clocks, since we can otherwise get very weird results on such expensive GEMMs due to temporary clock boosts; perhaps exllamav2 outperforming torch is related to this?), I now get:

batch=1024: s=0.00061, TFLOP/s=225.621, GB/s=0111.888, speedup=0.99
batch=2048: s=0.00122, TFLOP/s=225.192, GB/s=0083.327, speedup=0.91
batch=4096: s=0.00244, TFLOP/s=225.140, GB/s=0069.137, speedup=0.91
batch=8192: s=0.00489, TFLOP/s=225.012, GB/s=0062.016, speedup=0.88

a major improvement over the prior version:

batch=1024: s=0.00085, TFLOP/s=162.529, GB/s=0080.600, speedup=0.71
batch=2048: s=0.00170, TFLOP/s=161.578, GB/s=0059.788, speedup=0.65
batch=4096: s=0.00338, TFLOP/s=162.482, GB/s=0049.896, speedup=0.66
batch=8192: s=0.00680, TFLOP/s=161.678, GB/s=0044.560, speedup=0.63

The results on Llama shapes are very similar, and on weaker GPUs, like the A10, it's even closer to FP16 at extremely large batchsizes (around 13% faster than what I posted above). Overall, it seems that Marlin should now perform very similarly to FP16 up to batchsizes of ~1k in many cases.

While there are still a few more specific optimizations for very large batchsizes that could be done, I think from 2k onwards the most reliable solution is probably to just decompress and then run the FP16 kernel, which is so expensive at this point that the decompression overhead should be pretty much negligible. I believe some work on that is already being done here by @rib-2? Otherwise, I may look into it myself at some point.

@fxmarty (Contributor, Author) commented Feb 2, 2024

Thank you @efrantar! Here's an updated benchmark on 3aa5a05 & turboderp/exllamav2@0e9d9c1:

[updated benchmark plots: short and long sequence lengths]

Really cool! Unfortunately I can't lock the clock on the GPU I have.

> perhaps exllamav2 outperforming torch is related to this?

Could be. One thing is also that exllamav2 preallocates a buffer for dequantization & computation through cuBLAS in fp16×fp16 for large shapes: https://github.com/turboderp/exllamav2/blob/0e9d9c10101f8faaa69647bec2a517aa4e06f715/exllamav2/exllamav2_ext/cuda/q_gemm.cu#L123-L132. But still, dequantization + a cuBLAS call should be slower? cc @turboderp, does this result seem weird to you?

> I think from 2k onwards, the most reliable solution is probably to just decompress and then run the FP16 kernel

Yes, that's what the exllama kernel does. I'm just not sure how attractive it is in terms of memory, since it allocates a buffer for the fp16 decompressed weight.
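
For context, a rough sketch (hypothetical names and packing layout, not exllamav2's actual implementation; per-column scales/zeros assumed for simplicity) of that decompress-then-GEMM path:

import torch

def dequant_then_matmul(x, qweight_packed, scales, zeros, fp16_buffer):
    # qweight_packed: int32 tensor of shape (in_features // 8, out_features),
    #                 holding eight 4-bit weights per element
    # fp16_buffer:    preallocated (in_features, out_features) fp16 scratch tensor
    for i in range(8):
        nibbles = (qweight_packed >> (4 * i)) & 0xF            # extract the i-th 4-bit weight
        fp16_buffer[i::8] = (nibbles.half() - zeros) * scales  # dequantize into the buffer
    return x @ fp16_buffer                                     # plain fp16 GEMM (cuBLAS)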

It would be fun to compare to https://github.com/NVIDIA/cutlass/releases/tag/v3.4.0 as well.

@turboderp commented:

The dequantization pass is quite fast: it is completely memory-bound, so it has about the same latency as the matmul at q_len 1. And the extra global memory access is going to be offset by the subsequent cuBLAS matmul having most or all of those FP16 weights ready in the L2 cache. Overall, it's a pretty small overhead on top of cuBLAS, which I expect would be hard to even measure for m > 2000 or so.

Needing the extra buffer is a sad tradeoff, but at least you can preallocate it from a pool shared between all the linear layers on the same device, so it's only like 40 MB in total for a 7B model.
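
Something like the following minimal pool (an assumed design for illustration, not exllamav2's actual code) is all that's needed to share that scratch space per device; it only has to be as large as the biggest dequantized weight matrix on that device:

import torch

_scratch = {}  # device -> fp16 scratch tensor, sized for the largest weight seen so far

def get_scratch(device, numel):
    buf = _scratch.get(device)
    if buf is None or buf.numel() < numel:
        buf = torch.empty(numel, dtype=torch.float16, device=device)
        _scratch[device] = buf
    return buf[:numel]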

As for why the blue and green lines don't converge, I can only imagine the green line would be Torch matmul, which is apparently just a bit slower than cuBLAS?

@fxmarty (Contributor, Author) commented Feb 2, 2024

Oh, after investigating a bit: PyTorch picks e.g. cutlass_80_tensorop_f16_s16816gemm_relu_f16_256x128_32x3_nn_align8 for an fp16 @ fp16 GEMM, which does the accumulation in fp32. The cublasHgemm call in exllamav2 does the accumulation in fp16. So the difference in timing between PyTorch & exllamav2 could be the accumulation dtype? It is probably safer to go with fp32 / cublasSgemmEx.

I wonder if this could be the cause of some of the NaN issues we've had.

@efrantar It seems Marlin uses fp32 accumulation, as we have using FragC = Vec<float, 4>;?

// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 output/accumulation.
__device__ inline void mma(const FragA& a_frag, const FragB& frag_b, FragC& frag_c) {
  const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
  const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
  float* c = reinterpret_cast<float*>(&frag_c);
  asm volatile(
    "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
    "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
    : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
    :  "r"(a[0]),  "r"(a[1]),  "r"(a[2]),  "r"(a[3]),  "r"(b[0]),  "r"(b[1]),
       "f"(c[0]),  "f"(c[1]),  "f"(c[2]),  "f"(c[3])
  );
}
@efrantar (Member) commented Feb 2, 2024

Interesting; yes, Marlin accumulates in fp32. While the tensor cores should have the same speed for fp32 and fp16 accumulation, the latter allows larger tiles (since the accumulators require fewer registers), which could indeed be a bit faster for very large shapes. I am not sure if fp16 accumulation is always fully safe, hence I went with fp32 for Marlin.

That being said, how exactly are you benchmarking? I have found benchmarking with unlocked clocks to be pretty messy (especially once operations become more expensive and may be throttled); e.g., results can differ significantly depending on order, number of repetitions, and wait times (I would really like to know a good solution for this, if there is one). In my A100 setup above, at batchsize 8k, torch takes ~4.3 ms, which seems to essentially match your exllamav2 and AWQ results, so maybe there is just something weird with the benchmarks?

@fxmarty (Contributor, Author) commented Feb 5, 2024

@efrantar I use something along the lines of

import numpy as np
import torch

times = []
for _ in range(num_runs):
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    torch.cuda.synchronize()
    start_event.record()

    res = forward_handle(inp)  # the GEMM / linear forward being benchmarked

    end_event.record()
    torch.cuda.synchronize()

    # elapsed_time returns the time between the two events in milliseconds
    times.append(start_event.elapsed_time(end_event))

mean_latency = np.mean(times)

Unfortunately, this is not on a clock-locked GPU. I should indeed do that.

(Edit: simply comparing cublasHgemm vs. cublasGemmEx with an fp32 accumulation dtype, again without locked clocks, the latency difference between the two can be significant for large shapes, ~10-20%.)
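
If locking clocks isn't possible, one indirect probe from PyTorch is to toggle reduced-precision fp16 reductions and time the same GEMM. This is only a sketch: whether it reproduces the cublasHgemm vs. cublasGemmEx gap depends on which kernels cuBLAS actually selects, so treat the numbers as indicative only.

import torch

def time_fp16_gemm(m=8192, k=8192, n=8192, iters=100):
    a = torch.randn(m, k, dtype=torch.float16, device="cuda")
    b = torch.randn(k, n, dtype=torch.float16, device="cuda")
    for _ in range(10):  # warmup
        a @ b
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        a @ b
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # ms per GEMM

torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
print("fp16 reductions allowed:   ", time_fp16_gemm(), "ms")
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
print("fp16 reductions disallowed:", time_fp16_gemm(), "ms")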

fxmarty closed this as completed on Feb 14, 2024.