Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support CUDA without cuBLAS #82

Merged
merged 12 commits into from
Dec 12, 2023
Merged

Support CUDA without cuBLAS #82

merged 12 commits into from
Dec 12, 2023

Conversation

mrdomino
Copy link
Contributor

@mrdomino mrdomino commented Dec 10, 2023

Introduces a tinyblas library with naive CUDA implementations of the few remaining cublas operations used in llama.cpp/ggml-cuda.cu. Produces the same results with LLaVA at temperature 0 on the prompt I tried. Saves about 500MB of dependencies, but runs about 6x slower (but still quite a bit faster than CPU) on my machine.

The new mode is gated behind the GGML_USE_TINYBLAS cpp define.

Numbers match cublas, but using this code leads to LLaVA outputting
nothing but white squares.
@jart
Copy link
Collaborator

jart commented Dec 11, 2023

Here are some quick numbers using a GCE VM with a Xeon and NVIDIA L4.

  • llava goes 9.46 tokens per second with cpu (llama.cpp)
  • llava goes 9.26 tokens per second with cpu (llamafile)
  • llava goes 48 tokens per second with cublas (llama.cpp)
  • llava goes 17 tokens per second with tinyblas (llamafile)

So your TINYBLAS cublasGemmStridedBatchedEx() implementation goes 2x faster than CPU ✔️ even though it's currently 2x slower than cuBLAS. That's a step in the the right direction IMHO. On systems like Windows, we can only have cuBLAS if we ask the user to install both CUDA and MSVC so it might as well be fairy dust. TINYBLAS gives us a better "just works" fallback path.

Your TINYBLAS library doesn't increase the ~/.llamafile/ggml-cuda.dll DSO size by much, whose size built with nvcc -arch=all is only 8.9mb ✔️, That lets us continue to squeak under the 4gb Windows .exe size limit in the 30mb of space we have remaining.

The output of o//llama.cpp/main/main -m ~/weights/llava-v1.5-7b-Q4_K.gguf --temp 0 -p hello continues to be identical in both GPU and CPU mode under this branch, so I'd assume your implementation is identical to cuBLAS ✔️,. Check check check.

The only issue remaining is that the -lcublas flag is still needed ❌. How easily can we get rid of these?

readelf -Wa ~/.llamafile/ggml-cuda.so | grep -i cublas | c++filt | grep -Po '(?<=UND )\S+' | cat | sort -u
cublasCreate_v2@libcublas.so.12
cublasGemmBatchedEx@libcublas.so.12
cublasGemmEx@libcublas.so.12
cublasGetStatusString@libcublas.so.12
cublasGetStream_v2@libcublas.so.12
cublasSetMathMode@libcublas.so.12
cublasSetStream_v2@libcublas.so.12
cublasSgemm_v2@libcublas.so.12

See also https://www.cs.utexas.edu/~flame/pubs/GotoTOMS_revision.pdf for reading material on how to create a better-than-naive matrix multiplication function. Lastly your work might be of interest to ggerganov/ggml#293.

@jart
Copy link
Collaborator

jart commented Dec 11, 2023

I can get 25 tokens per second by slightly changing this PR to inline the constant parameters:

diff --git a/llama.cpp/naive-gemm.cu b/llama.cpp/naive-gemm.cu
index 82edfe9..4647b6b 100644
--- a/llama.cpp/naive-gemm.cu
+++ b/llama.cpp/naive-gemm.cu
@@ -1,3 +1,5 @@
+// -*- cuda -*-
+
 #include <cuda_runtime.h>
 #include <cuda_fp16.h>
 #include <cublas_v2.h>
@@ -6,9 +8,7 @@
 #define READ0(A, trans, ld, i, j) \
   (((trans) == CUBLAS_OP_N) ? (A)[(i) + (j) * (ld)] : (A)[(j) + (i) * (ld)])
 #define READ(A, type, trans, ld, i, j) \
-  ((type) == CUDA_R_16F                                         \
-   ? __half2float(READ0((half *)(A), (trans), (ld), (i), (j)))  \
-   : READ0((float *)(A), (trans), (ld), (i), (j)))
+  __half2float(READ0((half *)(A), (trans), (ld), (i), (j)))

 static __device__ __forceinline__ void matmul(cublasOperation_t transa,
                                               cublasOperation_t transb,
@@ -28,17 +28,11 @@ static __device__ __forceinline__ void matmul(cublasOperation_t transa,
     for (int j = 0; j < n; ++j) {
       float sum = 0.0;
       for (int l = 0; l < k; ++l) {
-        sum += READ(A, Atype, transa, lda, i, l) *
-               READ(B, Btype, transb, ldb, l, j);
-      }
-      if (Ctype == CUDA_R_16F) {
-        half *cptr = (half *)C + i + ldc * j;
-        *cptr = __float2half(MULZERO(alpha, sum) +
-                             MULZERO(beta, __half2float(*cptr)));
-      } else {
-        float *cptr = (float *)C + i + ldc * j;
-        *cptr = MULZERO(alpha, sum) + MULZERO(beta, *cptr);
+        sum += READ(A, Atype, CUBLAS_OP_T, lda, i, l) *
+               READ(B, Btype, CUBLAS_OP_N, ldb, l, j);
       }
+      half *cptr = (half *)C + i + ldc * j;
+      *cptr = __float2half(sum);
     }
   }
 }

@mrdomino
Copy link
Contributor Author

The only issue remaining is that the -lcublas flag is still needed ❌. How easily can we get rid of these?

For the other gemm routines, should be easy. For the rest, I'm not sure, but I'll see what I can do. Afaict the ones that are impactful are Create and Get / SetStream.

I can get 25 tokens per second by slightly changing this PR to inline the constant parameters:

That's surprising, I would've expected more of the branch predictor! Maybe it's worth doing template specializations after all.

@jart
Copy link
Collaborator

jart commented Dec 11, 2023

Template specialization would be good. It would also be perfectly acceptable to say:

  if (Atype != CUDA_R_16F || Btype != CUDA_R_16F || Ctype != CUDA_R_16F ||
      transa != CUBLAS_OP_T || transb != CUBLAS_OP_N ||
      computeType != CUBLAS_COMPUTE_16F ||
      __half2float(*(half *)pBeta) != 0.0f ||
      __half2float(*(half *)pAlpha) != 1.0f) {
    return CUBLAS_STATUS_NOT_SUPPORTED;
  }

Since that's the only way GGML currently uses this API.

@mrdomino
Copy link
Contributor Author

mrdomino commented Dec 11, 2023

I hardcoded it to the GGML use case, added a very naive and slow cublasGemmEx (that somehow still manages to beat my CPU by a bit?) and a cublasSgemm that I haven't managed to test yet. Down to just these now:

cublasCreate_v2@libcublas.so.12
cublasGemmBatchedEx@libcublas.so.12
cublasGetStatusString@libcublas.so.12
cublasGetStream_v2@libcublas.so.12
cublasSetMathMode@libcublas.so.12
cublasSetStream_v2@libcublas.so.12

Uses some fairly disgusting preprocessor macros to get the job done
while preserving behavior when `-DGGML_USE_CUBLAS`. With a bit of
investigation into `ggml_cuda_mul_mat_mat_batched_cublas`, these can
probably be removed or simplified.
@mrdomino
Copy link
Contributor Author

At this point there are no remaining cublas dependencies when compiled with -DGGML_USE_NAIVE.

@mrdomino mrdomino marked this pull request as ready for review December 12, 2023 17:56
@mrdomino mrdomino marked this pull request as draft December 12, 2023 17:57
N.B. we include the source file rather than the header file in
`ggml-cuda.cu` because `llamafile/cuda.c` assumes that everything lives
in a single compilation unit.
@mrdomino
Copy link
Contributor Author

The header dependency on cublas_v2.h has been removed.

@mrdomino mrdomino marked this pull request as ready for review December 12, 2023 19:40
@mrdomino
Copy link
Contributor Author

My inclination is to do performance improvements in another PR, and I'm not sure yet how you want to decide whether to link against cublas or not. So this PR is done on my end, pending review.

@mrdomino mrdomino changed the title wip naive cublasGemmStridedBatchedEx Support CUDA without cublas Dec 12, 2023
Copy link
Collaborator

@jart jart left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tested on Jetson and NVIDIA L4 GCE. Confirmed it doesn't link cuBLAS and goes significantly faster than CPU inference. I can add the code for compilation fallback. Looking forward to any additional performance improvements you can send us in a subsequent PR. Thank you!

@jart jart merged commit 72e1c72 into Mozilla-Ocho:main Dec 12, 2023
@mrdomino mrdomino deleted the naive branch December 12, 2023 20:28
@mrdomino mrdomino changed the title Support CUDA without cublas Support CUDA without cuBLAS Dec 12, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants