Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ROCm support #756

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
24 changes: 23 additions & 1 deletion Makefile
Expand Up @@ -6,17 +6,32 @@ GPP:= /usr/bin/g++
ifeq ($(CUDA_HOME),)
CUDA_HOME:= $(shell which nvcc | rev | cut -d'/' -f3- | rev)
endif
ifeq ($(ROCM_HOME),)
ROCM_HOME:= $(shell which hipcc | rev | cut -d'/' -f3- | rev)
endif

ifneq ($(CUDA_HOME),)
ifndef CUDA_VERSION
ifneq ($(MAKECMDGOALS),clean)
$(warning WARNING: CUDA_VERSION not set. Call make with CUDA string, for example: make cuda11x CUDA_VERSION=115 or make cpuonly CUDA_VERSION=CPU)
CUDA_VERSION:=
endif
endif

else ifneq ($(ROCM_HOME),)
ifndef ROCM_TARGET
$(error ERROR: ROCM_TARGET not set. Call make with ROCM string (see https://www.llvm.org/docs/AMDGPUUsage.html#processors), for example: make hip ROCM_TARGET=gfx1030)
ROCM_TARGET:=
endif
else
$(warning WARNING: Unable to find hipcc in path, fallback to ROCM_HOME /opt/rocm)
ROCM_HOME:=/opt/rocm
endif



NVCC := $(CUDA_HOME)/bin/nvcc
HIPCC:= $(ROCM_HOME)/bin/hipcc

###########################################

Expand All @@ -28,7 +43,8 @@ FILES_CPP := $(CSRC)/common.cpp $(CSRC)/cpu_ops.cpp $(CSRC)/pythonInterface.c

INCLUDE := -I $(CUDA_HOME)/include -I $(ROOT_DIR)/csrc -I $(CONDA_PREFIX)/include -I $(ROOT_DIR)/include
LIB := -L $(CUDA_HOME)/lib64 -lcudart -lcublas -lcublasLt -lcusparse -L $(CONDA_PREFIX)/lib

HIP_INCLUDE := -I $(ROCM_HOME)/include -I $(ROOT_DIR)/csrc -I $(ROOT_DIR)/include
HIP_LIB := -L $(ROCM_HOME)/lib -lhipblas -lhiprand -lhipsparse #-lhipblaslt, currently only gfx90a
# NVIDIA NVCC compilation flags
COMPUTE_CAPABILITY += -gencode arch=compute_50,code=sm_50 # Maxwell
COMPUTE_CAPABILITY += -gencode arch=compute_52,code=sm_52 # Maxwell
Expand Down Expand Up @@ -115,6 +131,12 @@ cuda12x: $(BUILD_DIR) env
cpuonly: $(BUILD_DIR) env
$(GPP) -std=c++14 -shared -fPIC -I $(ROOT_DIR)/csrc -I $(ROOT_DIR)/include $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cpu.so

hip: $(BUILD_DIR)
$(HIPCC) -std=c++14 -c -fPIC --offload-arch=$(ROCM_TARGET) $(HIP_INCLUDE) -o $(BUILD_DIR)/ops.o -DNO_CUBLASLT -DBNB_USE_HIP $(CSRC)/ops.cu
$(HIPCC) -std=c++14 -c -fPIC --offload-arch=$(ROCM_TARGET) $(HIP_INCLUDE) -o $(BUILD_DIR)/kernels.o -DNO_CUBLASLT -DBNB_USE_HIP $(CSRC)/kernels.cu
# HCC is deprecated, but used by hipBLASlt header. Since blas isn't even used doesn't matter, this is just so that it even compiles
$(GPP) -std=c++14 -D__HIP_PLATFORM_HCC__ -D__HIP_PLATFORM_AMD__ -DBUILD_CUDA -DBNB_USE_HIP -shared -fPIC $(HIP_INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(FILES_CPP) $(HIP_LIB) -o ./bitsandbytes/libbitsandbytes_hip_nohipblaslt.so

env:
@echo "ENVIRONMENT"
@echo "============================"
Expand Down
15 changes: 15 additions & 0 deletions README.md
Expand Up @@ -22,17 +22,32 @@ Python >=3.8. Linux distribution (Ubuntu, MacOS, etc.) + CUDA > 10.0.
In some cases it can happen that you need to compile from source. If this happens please consider submitting a bug report with `python -m bitsandbytes` information. What now follows is some short instructions which might work out of the box if `nvcc` is installed. If these do not work see further below.

Compilation quickstart:

```bash
git clone https://github.com/timdettmers/bitsandbytes.git
cd bitsandbytes
```

For CUDA
```bash
# CUDA_VERSIONS in {110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 120}
# make argument in {cuda110, cuda11x, cuda12x}
# if you do not know what CUDA you have, try looking at the output of: python -m bitsandbytes
CUDA_VERSION=117 make cuda11x
python setup.py install
```

For ROCm
```bash
# Requiers ROCm 5.6+
# Check if your GPU supports Wave32 with rocminfo | grep "Wavefront Size"
# If this doesn't output 32 and instead 64 this library won't work

# Your ROCm target can be found with rocminfo | grep gfx
ROCM_TARGET=gfx1030 make hip
pip install .
```

**Using Int8 inference with HuggingFace Transformers**

```python
Expand Down
2 changes: 1 addition & 1 deletion bitsandbytes/autograd/_functions.py
Expand Up @@ -224,7 +224,7 @@ def backward(ctx, grad_output):

def supports_igemmlt(device: torch.device) -> bool:
"""check if this device supports the optimized int8 kernel"""
if torch.cuda.get_device_capability(device=device) < (7, 5):
if torch.cuda.get_device_capability(device=device) < (7, 5) or torch.version.hip:
return False
device_name = torch.cuda.get_device_name(device=device)
nvidia16_models = ('GTX 1630', 'GTX 1650', 'GTX 1660') # https://en.wikipedia.org/wiki/GeForce_16_series
Expand Down
2 changes: 2 additions & 0 deletions bitsandbytes/cuda_setup/main.py
Expand Up @@ -338,7 +338,9 @@ def evaluate_cuda_setup():
cuda_setup.add_log_entry(('Welcome to bitsandbytes. For bug reports, please run\n\npython -m bitsandbytes\n\n'),
('and submit this information together with your error trace to: https://github.com/TimDettmers/bitsandbytes/issues'))
cuda_setup.add_log_entry('='*80)

if not torch.cuda.is_available(): return 'libbitsandbytes_cpu.so', None, None, None
if torch.version.hip: return 'libbitsandbytes_hip_nohipblaslt.so', None, None, None

cudart_path = determine_cuda_runtime_lib_path()
ccs = get_compute_capabilities()
Expand Down
13 changes: 13 additions & 0 deletions compile_from_source.md
Expand Up @@ -38,3 +38,16 @@ If you have problems compiling the library with these instructions from source,

Since 0.39.1 bitsandbytes installed via pip no longer provides Kepler binaries and these need to be compiled from source. Follow the steps above and instead of `cuda11x_nomatmul` etc use `cuda11x_nomatmul_kepler`

## Compilation with ROCm

Since this library requires hipblasLt this only supports **ROCm 5.6+**.
Works well with these docker images:
- [rocm/pytorch](https://hub.docker.com/r/rocm/pytorch)
- [rocm/pytorch-nightly](https://hub.docker.com/r/rocm/pytorch-nightly).

For installation do:
```bash
make hip ROCM_TARGET=gfx1030
pip install .
```
see https://www.llvm.org/docs/AMDGPUUsage.html#processors for finding ROCM_TARGET (e.g. gfx1030 for 6800XT,6900XT) or do `rocminfo | grep gfx`.
62 changes: 43 additions & 19 deletions csrc/kernels.cu
Expand Up @@ -4,25 +4,41 @@
// LICENSE file in the root directory of this source tree.

#include <kernels.cuh>

#ifdef BNB_USE_HIP
#include <hip/hip_runtime.h>
#include <hipcub/hipcub.hpp>
#include <hipcub/block/block_radix_sort.hpp>
#include <hipcub/warp/warp_reduce.hpp>
#include <hipcub/block/block_load.hpp>
#include <hipcub/block/block_discontinuity.hpp>
#include <hipcub/block/block_store.hpp>
#include <hipcub/block/block_reduce.hpp>
#include <hip/hip_math_constants.h>
#define cub hipcub
#define __syncwarp __syncthreads //HIP doesn't have this, so just sync threads

#else
#include <math_constants.h>
#include <mma.h>
#include <cub/block/block_radix_sort.cuh>
#include <cub/warp/warp_reduce.cuh>
#include <cub/block/block_load.cuh>
#include <cub/block/block_discontinuity.cuh>
#include <cub/block/block_store.cuh>
#include <cub/block/block_reduce.cuh>
#include <cub/cub.cuh>
#include <math_constants.h>
#endif

#include <thrust/host_vector.h>
#include <thrust/device_vector.h>
#include <mma.h>


#define HLF_MAX 65504
#define TH 1024
#define NUM 4
#define NUM_BLOCK 4096


#ifndef BNB_USE_HIP
// source: https://stackoverflow.com/questions/17399119/how-do-i-use-atomicmax-on-floating-point-values-in-cuda
__device__ float atomicMax(float* address, float val) {
int* address_as_i = reinterpret_cast<int*>(address);
Expand All @@ -47,6 +63,7 @@ __device__ float atomicMin(float* address, float val) {
} while (assumed != old);
return __int_as_float(old);
}
#endif

__device__ float dDequantizeFP4(unsigned char val, float absmax)
{
Expand Down Expand Up @@ -723,21 +740,28 @@ template<typename T, int BLOCK_SIZE, int NUM_PER_TH, int STOCHASTIC, int DATA_TY
//__launch_bounds__(TH, 4)
__global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n)
{
#ifdef BNB_USE_HIP
const int CUB_NUM_PER_TH=(BLOCK_SIZE/NUM_PER_TH % __AMDGCN_WAVEFRONT_SIZE == 0) ? NUM_PER_TH : NUM_PER_TH/2;
#else
const int CUB_NUM_PER_TH=NUM_PER_TH;
#endif
const int DATA_NUM_PER_TH=(DATA_TYPE > 0) ? NUM_PER_TH/2 : CUB_NUM_PER_TH;

const int n_full = gridDim.x * BLOCK_SIZE;
int valid_items = 0;
const int base_idx = (blockIdx.x * BLOCK_SIZE);

T vals[NUM_PER_TH];
float rand_vals[NUM_PER_TH];
unsigned char qvals[(DATA_TYPE > 0) ? NUM_PER_TH/2 : NUM_PER_TH];
T vals[CUB_NUM_PER_TH];
float rand_vals[CUB_NUM_PER_TH];
unsigned char qvals[DATA_NUM_PER_TH];
//float local_abs_max = -FLT_MAX;
float local_abs_max = 0.0f;
int local_rand_idx = 0;

typedef cub::BlockLoad<T, BLOCK_SIZE/NUM_PER_TH, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
typedef cub::BlockStore<unsigned char, BLOCK_SIZE/NUM_PER_TH, (DATA_TYPE > 0) ? NUM_PER_TH/2 : NUM_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
typedef cub::BlockReduce<float, BLOCK_SIZE/NUM_PER_TH> BlockReduce;
typedef cub::BlockLoad<float, BLOCK_SIZE/NUM_PER_TH, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
typedef cub::BlockLoad<T, BLOCK_SIZE/CUB_NUM_PER_TH, CUB_NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
typedef cub::BlockStore<unsigned char, BLOCK_SIZE/CUB_NUM_PER_TH, DATA_NUM_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
typedef cub::BlockReduce<float, BLOCK_SIZE/CUB_NUM_PER_TH> BlockReduce;
typedef cub::BlockLoad<float, BLOCK_SIZE/CUB_NUM_PER_TH, CUB_NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;

__shared__ typename LoadT::TempStorage loadt;
__shared__ typename LoadFloat::TempStorage loadf;
Expand All @@ -762,8 +786,8 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
// 2. broadcast local max
// 3. normalize inputs and quantize

#pragma unroll NUM_PER_TH
for(int j = 0; j < NUM_PER_TH; j++)
#pragma unroll CUB_NUM_PER_TH
for(int j = 0; j < CUB_NUM_PER_TH; j++)
local_abs_max = fmaxf(local_abs_max, fabsf((float)vals[j]));

local_abs_max = BlockReduce(reduce).Reduce(local_abs_max, cub::Max(), valid_items);
Expand Down Expand Up @@ -792,8 +816,8 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
switch(DATA_TYPE)
{
case General8bit:
#pragma unroll NUM_PER_TH
for(int j = 0; j < NUM_PER_TH; j++)
#pragma unroll CUB_NUM_PER_TH
for(int j = 0; j < CUB_NUM_PER_TH; j++)
{
if(!STOCHASTIC)
qvals[j] = dQuantize<0>(smem_code, 0.0f, ((float)vals[j])*local_abs_max);
Expand All @@ -802,17 +826,17 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
}
break;
case FP4:
#pragma unroll NUM_PER_TH
for(int j = 0; j < NUM_PER_TH/2; j++)
#pragma unroll CUB_NUM_PER_TH
for(int j = 0; j < DATA_NUM_PER_TH; j++)
{
packed_4bit |= dQuantizeFP4(((float)vals[2*j])*local_abs_max) << 4;
packed_4bit |= dQuantizeFP4(((float)vals[2*j+1])*local_abs_max);
qvals[j] = packed_4bit;
}
break;
case NF4:
#pragma unroll NUM_PER_TH
for(int j = 0; j < NUM_PER_TH/2; j++)
#pragma unroll CUB_NUM_PER_TH
for(int j = 0; j < DATA_NUM_PER_TH; j++)
{
packed_4bit |= dQuantizeNF4(((float)vals[2*j])*local_abs_max) << 4;
packed_4bit |= dQuantizeNF4(((float)vals[2*j+1])*local_abs_max);
Expand Down
7 changes: 6 additions & 1 deletion csrc/ops.cu
Expand Up @@ -5,12 +5,17 @@

#include <ops.cuh>
#include <kernels.cuh>
#include <cub/device/device_scan.cuh>
#include <limits>
#include <BinSearch.h>
#include <cassert>
#include <common.h>

#ifdef BNB_USE_HIP
#include <hipcub/device/device_scan.hpp>
#else
#include <cub/device/device_scan.cuh>
#endif


using namespace BinSearch;
using std::cout;
Expand Down
50 changes: 48 additions & 2 deletions csrc/ops.cuh
Expand Up @@ -12,16 +12,62 @@
#include <unistd.h>
#include <assert.h>


#ifdef BNB_USE_HIP

#include <hip/hip_runtime_api.h>
#include <hip/hip_fp16.h>
#include <hipblas/hipblas.h>
#include <hipblaslt/hipblaslt.h> //only using header to allow redefines
#include <hipsparse/hipsparse.h>

#define cudaPeekAtLastError hipPeekAtLastError
#define cudaMemset hipMemset
#define cudaMemAttachHost hipMemAttachHost
#define cudaMemPrefetchAsync hipMemPrefetchAsync
#define cudaMallocManaged hipMallocManaged
#define cudaDevAttrConcurrentManagedAccess hipDeviceAttributeConcurrentManagedAccess
#define cudaDeviceGetAttribute hipDeviceGetAttribute
#define cublasGemmEx hipblasGemmEx
#define cublasStatus_t hipblasStatus_t
#define CUBLAS_OP_T HIPBLAS_OP_T
#define CUBLAS_OP_N HIPBLAS_OP_N
#define CUDA_R_8I HIPBLAS_R_8I
#define CUDA_R_32I HIPBLAS_R_32I
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
#define cublasStatus_t hipblasStatus_t
#define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx
#define cublasOperation_t hipblasOperation_t
#define cublasLtMatrixLayoutCreate hipblasLtMatrixLayoutCreate
#define cudaError_t hipError_t
#define cudaGetErrorString hipGetErrorString
#define cudaSuccess hipSuccess
#define cusparseStatus_t hipsparseStatus_t
#define CUSPARSE_STATUS_SUCCESS HIPSPARSE_STATUS_SUCCESS
#define cublasStatus_t hipblasStatus_t
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
#define cublasHandle_t hipblasHandle_t
#define cublasCreate_v2 hipblasCreate
#define cusparseHandle_t hipsparseHandle_t
#define cusparseCreate hipsparseCreate
#define __nv_bfloat16 hip_bfloat16
#define cublasLtHandle_t hipblasLtHandle_t
#define cublasLtCreate hipblasLtCreate
#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT

#else
#include <cuda_runtime_api.h>
#include <cuda_fp16.h>
#include <cublas_v2.h>
#include <cublasLt.h>
#include <cusparse.h>
#include <vector>
#include <functional>

#include <thrust/host_vector.h>
#include <thrust/device_vector.h>
#endif
#include <vector>
#include <functional>



Expand Down
16 changes: 8 additions & 8 deletions include/Algo-Direct2.h
Expand Up @@ -93,8 +93,8 @@ struct AlgoVecBase<I, T, A, typename std::enable_if<DirectAux::IsDirect2<A>::val
__m128 vxp = _mm_shuffle_ps(xp01, xp23, (1) + (3 << 2) + (1 << 4) + (3 << 6));
#endif
IVec<SSE, float> i(u.vec);
IVec<SSE, float> vlem = vz < vxm;
IVec<SSE, float> vlep = vz < vxp;
IVec<SSE, float> vlem = operator< (vz,vxm);
IVec<SSE, float> vlep = operator< (vz,vxp);
i = i + vlem + vlep;
i.store(pr);
}
Expand Down Expand Up @@ -123,8 +123,8 @@ struct AlgoVecBase<I, T, A, typename std::enable_if<DirectAux::IsDirect2<A>::val
__m128d vxp = _mm_shuffle_pd(vx0, vx1, 3);

IVec<SSE, double> i(b1, b0);
IVec<SSE, double> vlem = (vz < vxm);
IVec<SSE, double> vlep = (vz < vxp);
IVec<SSE, double> vlem = operator< (vz, vxm);
IVec<SSE, double> vlep = operator< (vz, vxp);
i = i + vlem + vlep;

union {
Expand Down Expand Up @@ -227,8 +227,8 @@ struct AlgoVecBase<I, T, A, typename std::enable_if<DirectAux::IsDirect2<A>::val

#endif

IVec<AVX, float> vlem = vz < vxm;
IVec<AVX, float> vlep = vz < vxp;
IVec<AVX, float> vlem = operator< (vz, vxm);
IVec<AVX, float> vlep = operator< (vz, vxp);
ip = ip + vlem + vlep;

ip.store(pr);
Expand Down Expand Up @@ -277,8 +277,8 @@ struct AlgoVecBase<I, T, A, typename std::enable_if<DirectAux::IsDirect2<A>::val
// FVec<AVX, double> vxp = _mm256_insertf128_pd(_mm256_castpd128_pd256(h01p), h23p, 1);

IVec<AVX, double> i(u.vec);
IVec<AVX, double> vlem = vz < vxm;
IVec<AVX, double> vlep = vz < vxp;
IVec<AVX, double> vlem = operator< (vz,vxm);
IVec<AVX, double> vlep = operator< (vz,vxp);
i = i + vlem + vlep;
i.extractLo32s().store(pr);
}
Expand Down