ggml: add MUSA backend build support

Adds the GGML_MUSA option, a FindMUSA module, the CMake language modules that
enable_language(MUSA) loads, and a header that maps the CUDA / cuBLAS names
used by ggml-cuda.cu onto their MUSA / muBLAS counterparts.

Review notes:
  * This patch was re-broken into lines from a whitespace-collapsed copy.
    Indentation and column alignment could not be recovered, so context lines
    may differ from the target files in whitespace (apply with
    `git apply --ignore-whitespace`).
  * The header names inside <...> in the src/ggml-cuda.cu hunk and the <...>
    placeholders of CMAKE_MUSA_COMPILE_OBJECT were missing from that copy and
    have been reconstructed; confirm both before merging.
  * The `index` lines are carried over unchanged and no longer correspond to
    the contents of the new files.

diff --git a/CMakeLists.txt b/CMakeLists.txt
index bf1491bf3..1dc9911f8 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -32,6 +32,7 @@ option(GGML_NO_ACCELERATE "ggml: disable Accelerate framework" OFF)
 option(GGML_OPENBLAS "ggml: use OpenBLAS" OFF)
 option(GGML_CLBLAST "ggml: use clBLAST" OFF)
 option(GGML_CUBLAS "ggml: use cuBLAS" OFF)
+option(GGML_MUSA "ggml: use MUSA" OFF)
 option(GGML_METAL "ggml: use Metal" OFF)
 
 # sanitizers
diff --git a/cmake/CMakeDetermineMUSACompiler.cmake b/cmake/CMakeDetermineMUSACompiler.cmake
new file mode 100644
index 000000000..35dadda54
--- /dev/null
+++ b/cmake/CMakeDetermineMUSACompiler.cmake
@@ -0,0 +1,17 @@
+# Compiler-detection step for the custom MUSA language, loaded by
+# enable_language(MUSA). It reads MUSA_ARCH and MUSA_MCC, which FindMUSA.cmake
+# sets, so find_package(MUSA) has to run first.
+
+set(CMAKE_MUSA_ARCHITECTURES "mp_${MUSA_ARCH}")
+set(CMAKE_MUSA_COMPILER "${MUSA_MCC}")
+set(CMAKE_MUSA_COMPILER_ID "Clang")
+set(CMAKE_MUSA_COMPILER_ARG1 "")
+set(CMAKE_MUSA_COMPILER_ENV_VAR "MCC")
+
+# Record the settings in the build tree. @ONLY restricts substitution to the
+# @VAR@ references used by the template.
+configure_file(
+    "${CMAKE_CURRENT_LIST_DIR}/CMakeMUSACompiler.cmake.in"
+    "${CMAKE_PLATFORM_INFO_DIR}/CMakeMUSACompiler.cmake"
+    @ONLY
+)
diff --git a/cmake/CMakeMUSACompiler.cmake.in b/cmake/CMakeMUSACompiler.cmake.in
new file mode 100644
index 000000000..cd611c4dc
--- /dev/null
+++ b/cmake/CMakeMUSACompiler.cmake.in
@@ -0,0 +1,12 @@
+# Template for the MUSA compiler record kept in the build tree. It is filled in
+# by the configure_file() call in CMakeDetermineMUSACompiler.cmake.
+set(CMAKE_MUSA_COMPILER "@CMAKE_MUSA_COMPILER@")
+set(CMAKE_MUSA_COMPILER_ARG1 "@CMAKE_MUSA_COMPILER_ARG1@")
+set(CMAKE_MUSA_COMPILER_ID "@CMAKE_MUSA_COMPILER_ID@")
+set(CMAKE_MUSA_COMPILER_LOADED 1)
+# Recorded because CMakeMUSAInformation.cmake builds the compile rule from it,
+# while the detection step that computes it is skipped once this file exists.
+set(CMAKE_MUSA_ARCHITECTURES "@CMAKE_MUSA_ARCHITECTURES@")
+set(CMAKE_MUSA_SOURCE_FILE_EXTENSIONS mu;cu)
+set(CMAKE_MUSA_OUTPUT_EXTENSION .o)
+set(CMAKE_MUSA_COMPILER_ENV_VAR "@CMAKE_MUSA_COMPILER_ENV_VAR@")
diff --git a/cmake/CMakeMUSAInformation.cmake b/cmake/CMakeMUSAInformation.cmake
new file mode 100644
index 000000000..bb0244d8e
--- /dev/null
+++ b/cmake/CMakeMUSAInformation.cmake
@@ -0,0 +1,27 @@
+# Language rules for MUSA sources; they reuse CMake's stock Clang support.
+
+include(CMakeLanguageInformation)
+include(CMakeCommonLanguageInclude)
+
+include(Compiler/Clang)
+
+__compiler_clang(MUSA)
+__compiler_clang_cxx_standards(MUSA)
+
+set(CMAKE_INCLUDE_FLAG_MUSA "-I")
+
+set(CMAKE_MUSA_RUNTIME_LIBRARY_DEFAULT "SHARED")
+set(CMAKE_MUSA_RUNTIME_LIBRARY_LINK_OPTIONS_STATIC "")
+set(CMAKE_MUSA_RUNTIME_LIBRARY_LINK_OPTIONS_SHARED "")
+
+# No implicit runtime libraries: the MUSA libraries are linked explicitly
+# through the imported MUSA::* targets created by FindMUSA.cmake.
+set(CMAKE_MUSA_RUNTIME_LIBRARIES_STATIC "")
+set(CMAKE_MUSA_RUNTIME_LIBRARIES_SHARED "")
+
+# Rule that compiles one MUSA source file into an object file. CMake replaces
+# the <...> placeholders for every source it compiles.
+if(NOT CMAKE_MUSA_COMPILE_OBJECT)
+    set(CMAKE_MUSA_COMPILE_OBJECT
+        "<CMAKE_MUSA_COMPILER> -x musa --cuda-gpu-arch=${CMAKE_MUSA_ARCHITECTURES} -fPIC <DEFINES> <INCLUDES> <FLAGS> -o <OBJECT> -c <SOURCE>")
+endif()
diff --git a/cmake/CMakeTestMUSACompiler.cmake b/cmake/CMakeTestMUSACompiler.cmake
new file mode 100644
index 000000000..6bc32198d
--- /dev/null
+++ b/cmake/CMakeTestMUSACompiler.cmake
@@ -0,0 +1,2 @@
+# Intentionally empty: enable_language(MUSA) looks for this module, and no
+# compiler check is performed.
diff --git a/cmake/FindMUSA.cmake b/cmake/FindMUSA.cmake
new file mode 100644
index 000000000..6841cf372
--- /dev/null
+++ b/cmake/FindMUSA.cmake
@@ -0,0 +1,114 @@
+# FindMUSA
+# --------
+# Locates the MUSA toolkit: the mcc compiler, the include directory and the
+# musa, mublas and musart libraries.
+#
+# Inputs:
+#   MUSA_HOME / ENV{MUSA_HOME}   toolkit root (default: /usr/local/musa)
+#   MUSA_ARCH / ENV{MUSA_ARCH}   architecture number (default: 21)
+#   MUSA_INCLUDE_DIR, MUSA_LIBRARY_DIR, MUSA_LIBRARIES   optional overrides
+#
+# Results:
+#   MUSA_FOUND, MUSA_MCC, MUSA_ARCH, MUSA_INCLUDE_DIR, MUSA_LIBRARY_DIR,
+#   MUSA_LIBRARIES, and the imported targets MUSA::musa, MUSA::mublas and
+#   MUSA::musart.
+
+include(${CMAKE_ROOT}/Modules/FindPackageHandleStandardArgs.cmake)
+include(${CMAKE_ROOT}/Modules/SelectLibraryConfigurations.cmake)
+include(${CMAKE_ROOT}/Modules/CMakeFindDependencyMacro.cmake)
+
+# An environment setting wins; otherwise a MUSA_HOME supplied by the caller is
+# kept, and only as a last resort is the default install prefix used.
+if(DEFINED ENV{MUSA_HOME})
+    set(MUSA_HOME "$ENV{MUSA_HOME}")
+elseif(NOT MUSA_HOME)
+    set(MUSA_HOME "/usr/local/musa")
+endif()
+
+set(MUSA_MCC "${MUSA_HOME}/bin/mcc")
+
+if(DEFINED ENV{MUSA_ARCH})
+    set(MUSA_ARCH "$ENV{MUSA_ARCH}")
+elseif(NOT MUSA_ARCH)
+    set(MUSA_ARCH "21")
+endif()
+
+if(NOT MUSA_INCLUDE_DIR)
+    set(MUSA_INCLUDE_DIR "${MUSA_HOME}/include")
+endif()
+
+if(NOT MUSA_LIBRARY_DIR)
+    set(MUSA_LIBRARY_DIR "${MUSA_HOME}/lib")
+endif()
+
+if(NOT MUSA_LIBRARIES)
+    find_library(
+        MUSA_MUSA_LIBRARY
+        NAMES libmusa.so
+        PATHS "${MUSA_LIBRARY_DIR}"
+    )
+
+    find_library(
+        MUSA_MUBLAS_LIBRARY
+        NAMES libmublas.so
+        PATHS "${MUSA_LIBRARY_DIR}"
+    )
+
+    find_library(
+        MUSA_MUSART_LIBRARY
+        NAMES libmusart.so
+        PATHS "${MUSA_LIBRARY_DIR}"
+    )
+
+    # find_library() stores each result in the cache itself, so only the
+    # combined list has to be assembled here.
+    mark_as_advanced(MUSA_MUSA_LIBRARY MUSA_MUBLAS_LIBRARY MUSA_MUSART_LIBRARY)
+    if(MUSA_MUSA_LIBRARY AND MUSA_MUBLAS_LIBRARY AND MUSA_MUSART_LIBRARY)
+        set(MUSA_LIBRARIES "${MUSA_MUSA_LIBRARY};${MUSA_MUBLAS_LIBRARY};${MUSA_MUSART_LIBRARY}")
+    endif()
+endif()
+
+if(MUSA_LIBRARIES)
+    if(NOT TARGET MUSA::musa)
+        add_library(MUSA::musa SHARED IMPORTED)
+        set_target_properties(MUSA::musa PROPERTIES
+            IMPORTED_LOCATION "${MUSA_MUSA_LIBRARY}"
+            INTERFACE_INCLUDE_DIRECTORIES "${MUSA_INCLUDE_DIR}"
+        )
+    endif()
+
+    if(NOT TARGET MUSA::mublas)
+        add_library(MUSA::mublas SHARED IMPORTED)
+        set_target_properties(MUSA::mublas PROPERTIES
+            IMPORTED_LOCATION "${MUSA_MUBLAS_LIBRARY}"
+            INTERFACE_INCLUDE_DIRECTORIES "${MUSA_INCLUDE_DIR}"
+        )
+    endif()
+
+    if(NOT TARGET MUSA::musart)
+        add_library(MUSA::musart SHARED IMPORTED)
+        set_target_properties(MUSA::musart PROPERTIES
+            IMPORTED_LOCATION "${MUSA_MUSART_LIBRARY}"
+            INTERFACE_INCLUDE_DIRECTORIES "${MUSA_INCLUDE_DIR}"
+        )
+    endif()
+
+    set(MUSA_INCLUDE_DIR "${MUSA_INCLUDE_DIR}" CACHE PATH "MUSA include directory")
+    set(MUSA_LIBRARY_DIR "${MUSA_LIBRARY_DIR}" CACHE PATH "MUSA library directory")
+    set(MUSA_LIBRARIES "${MUSA_LIBRARIES}" CACHE STRING "MUSA libraries to link")
+endif()
+
+find_package_handle_standard_args(
+    MUSA
+    REQUIRED_VARS
+        MUSA_ARCH
+        MUSA_MCC
+        MUSA_INCLUDE_DIR
+        MUSA_LIBRARIES
+        MUSA_LIBRARY_DIR
+)
+mark_as_advanced(
+    MUSA_INCLUDE_DIR
+    MUSA_LIBRARIES
+    MUSA_LIBRARY_DIR
+)
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 57287e41d..e9c1cdabb 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -217,6 +217,29 @@ if (GGML_CUBLAS)
     endif()
 endif()
 
+if (GGML_MUSA)
+    # Not a boolean, so it is a STRING cache entry rather than an option().
+    set(MUSA_ARCH "21" CACHE STRING "MUSA architecture")
+
+    # FindMUSA.cmake and the CMake*MUSA* language modules live in ../cmake.
+    list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_LIST_DIR}/../cmake")
+
+    find_package(MUSA REQUIRED)
+
+    message(STATUS "MUSA found")
+
+    # Must follow find_package(MUSA): compiler detection reads MUSA_MCC and MUSA_ARCH.
+    enable_language(MUSA)
+
+    add_compile_definitions(GGML_USE_MUSA GGML_USE_CUDA)
+
+    # The CUDA backend source is reused and compiled as MUSA.
+    set(GGML_CUDA_SOURCES ggml-cuda.cu ggml-cuda.h)
+    set_source_files_properties(ggml-cuda.cu PROPERTIES LANGUAGE MUSA)
+
+    set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} PUBLIC MUSA::musa MUSA::mublas MUSA::musart)
+endif()
+
 if (GGML_METAL)
     find_library(FOUNDATION_LIBRARY Foundation REQUIRED)
     find_library(METAL_FRAMEWORK Metal REQUIRED)
diff --git a/src/ggml-cuda.cu b/src/ggml-cuda.cu
index e0163ae0c..49b0399ed 100644
--- a/src/ggml-cuda.cu
+++ b/src/ggml-cuda.cu
@@ -64,6 +64,12 @@
 #define cudaStreamWaitEvent(stream, event) hipStreamWaitEvent(stream, event, 0)
 #define cudaStream_t hipStream_t
 #define cudaSuccess hipSuccess
+#elif defined(GGML_USE_MUSA)
+// NOTE(review): MUSA header names reconstructed - confirm against the MUSA SDK.
+#include <musa.h>
+#include <musa_runtime.h>
+#include <mublas.h>
+#include "musa_compatible.cuh"
 #else
 #include <cuda_runtime.h>
 #include <cuda.h>
diff --git a/src/musa_compatible.cuh b/src/musa_compatible.cuh
new file mode 100644
index 000000000..bfcc7a0fb
--- /dev/null
+++ b/src/musa_compatible.cuh
@@ -0,0 +1,156 @@
+// CUDA-to-MUSA compatibility shim.
+//
+// Maps the CUDA driver, CUDA runtime and cuBLAS identifiers used by
+// ggml-cuda.cu onto MUSA / muBLAS identifiers, so that the CUDA backend
+// source can be compiled when GGML_USE_MUSA is defined. ggml-cuda.cu includes
+// this file after the MUSA headers.
+
+#ifndef MUSA_COMPATIBLE_CUH
+#define MUSA_COMPATIBLE_CUH
+
+// Driver API types
+#define CUresult MUresult
+#define CUdevice MUdevice
+#define CUdeviceptr MUdeviceptr
+
+// Runtime API types
+#define cudaDataType_t musaDataType_t
+#define cudaError_t musaError_t
+#define cudaEvent_t musaEvent_t
+#define cudaStream_t musaStream_t
+#define cudaDeviceProp musaDeviceProp
+
+// BLAS types
+#define cublasStatus_t mublasStatus_t
+#define cublasHandle_t mublasHandle_t
+#define cublasComputeType_t musaDataType_t // reserved in musa
+
+// Driver API functions
+#define cuGetErrorString muGetErrorString
+#define cuDeviceGet muDeviceGet
+#define cuDeviceGetAttribute muDeviceGetAttribute
+// #define cuMemGetAllocationGranularity muMemGetAllocationGranularity // so far, not implemented
+// #define CUmemAllocationProp MUmemAllocationProp
+
+// Runtime API functions and enumerators
+#define cudaGetErrorString musaGetErrorString
+#define cudaGetLastError musaGetLastError
+#define cudaMemGetInfo musaMemGetInfo
+#define cudaMemcpy musaMemcpy
+#define cudaMemcpyKind musaMemcpyKind
+#define cudaMemset musaMemset
+#define cudaMalloc musaMalloc
+#define cudaMallocHost musaMallocHost
+#define cudaFree musaFree
+#define cudaFreeHost musaFreeHost
+#define cudaHostUnregister musaHostUnregister
+#define cudaMemcpyAsync musaMemcpyAsync
+#define cudaMemcpy2DAsync musaMemcpy2DAsync
+#define cudaMemcpyPeerAsync musaMemcpyPeerAsync
+#define cudaMemcpyHostToDevice musaMemcpyHostToDevice
+#define cudaMemcpyDeviceToHost musaMemcpyDeviceToHost
+#define cudaMemcpyDeviceToDevice musaMemcpyDeviceToDevice
+#define cudaDeviceSynchronize musaDeviceSynchronize
+#define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer
+#define cudaDeviceEnablePeerAccess musaDeviceEnablePeerAccess
+#define cudaDeviceDisablePeerAccess musaDeviceDisablePeerAccess
+#define cudaGetDevice musaGetDevice
+#define cudaGetDeviceCount musaGetDeviceCount
+#define cudaGetDeviceProperties musaGetDeviceProperties
+#define cudaSetDevice musaSetDevice
+#define cudaEventRecord musaEventRecord
+#define cudaEventDestroy musaEventDestroy
+#define cudaEventCreate musaEventCreate
+#define cudaEventSynchronize musaEventSynchronize
+#define cudaEventDisableTiming musaEventDisableTiming
+#define cudaEventCreateWithFlags musaEventCreateWithFlags
+#define cudaStreamPerThread musaStreamPerThread
+#define cudaStreamSynchronize musaStreamSynchronize
+#define cudaStreamCreateWithFlags musaStreamCreateWithFlags
+#define cudaStreamNonBlocking musaStreamNonBlocking
+#define cudaStreamDestroy musaStreamDestroy
+#define cudaStreamWaitEvent musaStreamWaitEvent
+
+// BLAS functions
+#define cublasCreate mublasCreate
+#define cublasDestroy mublasDestroy
+#define cublasSetMathMode mublasSetMathMode
+#define cublasSetStream mublasSetStream
+#define cublasGemmEx mublasGemmEx
+#define cublasSgemm mublasSgemm
+
+// NOTE(review): presumably guards against a macro of this name coming from the
+// muBLAS header - confirm whether the #undef is still required.
+#ifdef mublasGemmStridedBatchedEx
+#undef mublasGemmStridedBatchedEx
+#endif // mublasGemmStridedBatchedEx
+
+// The muBLAS call receives C's arguments a second time (labelled D) and two
+// trailing arguments that the cuBLAS-style call does not have; both are 0.
+#define cublasGemmStridedBatchedEx( \
+    handle, transA, transB, m, n, k, alpha, \
+    A, Atype, lda, strideA, \
+    B, Btype, ldb, strideB, \
+    beta, \
+    C, Ctype, ldc, strideC, \
+    batchCount, computeType, algo) \
+    mublasGemmStridedBatchedEx( \
+        handle, transA, transB, m, n, k, alpha, \
+        A, Atype, lda, strideA, \
+        B, Btype, ldb, strideB, \
+        beta, \
+        C, Ctype, ldc, strideC, \
+        C /* D */, Ctype, ldc, strideC, \
+        batchCount, computeType, algo, \
+        0 /* solution type, reserved */, \
+        0 /* flags */)
+
+// Same adaptation for the batched, non-strided variant.
+#define cublasGemmBatchedEx( \
+    handle, transA, transB, m, n, k, alpha, \
+    A, Atype, lda, \
+    B, Btype, ldb, \
+    beta, \
+    C, Ctype, ldc, \
+    batchCount, computeType, algo) \
+    mublasGemmBatchedEx( \
+        handle, transA, transB, m, n, k, alpha, \
+        A, Atype, lda, \
+        B, Btype, ldb, \
+        beta, \
+        C, Ctype, ldc, \
+        C /* D */, Ctype, ldc, \
+        batchCount, computeType, algo, \
+        0 /* solution type, reserved */, \
+        0 /* flags */)
+
+#define CUDART_VERSION MUSART_VERSION
+
+// Driver API constants
+#define CU_MEM_LOCATION_TYPE_DEVICE MU_MEM_LOCATION_TYPE_DEVICE
+// #define CU_MEM_ALLOCATION_TYPE_PINNED MU_MEM_ALLOCATION_TYPE_PINNED
+// #define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED MU_MEM_ALLOC_GRANULARITY_RECOMMENDED
+#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED MU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED
+
+// BLAS constants
+#define CUBLAS_STATUS_SUCCESS MUBLAS_STATUS_SUCCESS
+// NOTE(review): the next two map distinct cuBLAS codes onto one muBLAS code, so
+// they cannot both be used as case labels of one switch - confirm the mapping.
+#define CUBLAS_STATUS_NOT_INITIALIZED MUBLAS_STATUS_NOT_IMPLEMENTED
+#define CUBLAS_STATUS_ALLOC_FAILED MUBLAS_STATUS_NOT_IMPLEMENTED
+#define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_MATH_MODE_TP32_TENSOR // NOTE(review): "TP32" spelling unverified - confirm
+#define CUBLAS_OP_T MUBLAS_OP_T
+#define CUBLAS_OP_N MUBLAS_OP_N
+#define CUBLAS_COMPUTE_16F MUSA_R_16F // reserved in musa
+#define CUBLAS_COMPUTE_32F MUSA_R_32F // reserved in musa
+#define CUBLAS_GEMM_DEFAULT_TENSOR_OP MUBLAS_GEMM_DEFAULT_TENSOR_OP
+
+// Status codes and data types
+#define CUDA_SUCCESS MUSA_SUCCESS
+#define CUDA_R_16F MUSA_R_16F
+#define CUDA_R_32F MUSA_R_32F
+#define cudaSuccess musaSuccess
+#define cudaErrorPeerAccessAlreadyEnabled musaErrorPeerAccessAlreadyEnabled
+#define cudaErrorPeerAccessNotEnabled musaErrorPeerAccessNotEnabled
+
+#endif // MUSA_COMPATIBLE_CUH