From 83acb0f7d03258d5ec5c7cfed267569b76369476 Mon Sep 17 00:00:00 2001 From: lizimin Date: Thu, 31 Oct 2024 15:34:19 +0800 Subject: [PATCH 01/10] Add Expand operator --- include/ops/expand/expand.h | 25 +++++ operatorspy/tests/expand.py | 180 +++++++++++++++++++++++++++++++ src/ops/expand/cpu/expand_cpu.cc | 67 ++++++++++++ src/ops/expand/cpu/expand_cpu.h | 28 +++++ src/ops/expand/cuda/expand.cc | 55 ++++++++++ src/ops/expand/cuda/expand.cu | 53 +++++++++ src/ops/expand/cuda/expand.cuh | 34 ++++++ src/ops/expand/operator.cc | 72 +++++++++++++ 8 files changed, 514 insertions(+) create mode 100644 include/ops/expand/expand.h create mode 100644 operatorspy/tests/expand.py create mode 100644 src/ops/expand/cpu/expand_cpu.cc create mode 100644 src/ops/expand/cpu/expand_cpu.h create mode 100644 src/ops/expand/cuda/expand.cc create mode 100644 src/ops/expand/cuda/expand.cu create mode 100644 src/ops/expand/cuda/expand.cuh create mode 100644 src/ops/expand/operator.cc diff --git a/include/ops/expand/expand.h b/include/ops/expand/expand.h new file mode 100644 index 00000000..ee28b70c --- /dev/null +++ b/include/ops/expand/expand.h @@ -0,0 +1,25 @@ +#ifndef EXPAND_H +#define EXPAND_H + +#include "../../export.h" +#include "../../operators.h" + +typedef struct ExpandDescriptor { + Device device; +} ExpandDescriptor; + +typedef ExpandDescriptor *infiniopExpandDescriptor_t; + +__C __export infiniopStatus_t infiniopCreateExpandDescriptor(infiniopHandle_t handle, + infiniopExpandDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y, + infiniopTensorDescriptor_t x); + +__C __export infiniopStatus_t infiniopExpand(infiniopExpandDescriptor_t desc, + void *y, + void const *x, + void *stream); + +__C __export infiniopStatus_t infiniopDestroyExpandDescriptor(infiniopExpandDescriptor_t desc); + +#endif diff --git a/operatorspy/tests/expand.py b/operatorspy/tests/expand.py new file mode 100644 index 00000000..fea84d19 --- /dev/null +++ b/operatorspy/tests/expand.py @@ -0,0 +1,180 @@ +from ctypes import POINTER, Structure, c_int32, c_void_p +import ctypes +import sys +import os +import time + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))) +from operatorspy import ( + open_lib, + to_tensor, + DeviceEnum, + infiniopHandle_t, + infiniopTensorDescriptor_t, + create_handle, + destroy_handle, + check_error, + rearrange_tensor, +) + +from operatorspy.tests.test_utils import get_args +import torch + +# constant for control whether profile the pytorch and lib functions +# NOTE: need to manually add synchronization function to the lib function, +# e.g., cudaDeviceSynchronize() for CUDA +PROFILE = False +NUM_PRERUN = 10 +NUM_ITERATIONS = 1000 + + +class ExpandDescriptor(Structure): + _fields_ = [("device", c_int32)] + + +infiniopExpandDescriptor_t = POINTER(ExpandDescriptor) + + +def expand(x, y): + if PROFILE: + ans = x.expand_as(y).clone() + torch.cuda.synchronize() + return ans + return x.expand_as(y) + + +def test( + lib, + handle, + torch_device, + y_shape, + x_shape, + y_stride=None, + x_stride=None, + tensor_dtype=torch.float16, +): + print( + f"Testing Expand on {torch_device} with x_shape:{x_shape} y_shape:{y_shape} x_stride:{x_stride} y_stride:{y_stride} dtype:{tensor_dtype}" + ) + + x = torch.rand(x_shape, dtype=tensor_dtype).to(torch_device) + y = torch.rand(y_shape, dtype=tensor_dtype).to(torch_device) + + if x_stride is not None: + x = rearrange_tensor(x, x_stride) + if y_stride is not None: + y = rearrange_tensor(y, y_stride) + + for i in range(NUM_PRERUN if PROFILE 
else 1): + ans = expand(x, y) + if PROFILE: + start_time = time.time() + for i in range(NUM_ITERATIONS): + _ = expand(x, y) + elapsed = (time.time() - start_time) / NUM_ITERATIONS + print(f"pytorch time: {elapsed :6f}") + + x_tensor = to_tensor(x, lib) + y_tensor = to_tensor(y, lib) + descriptor = infiniopExpandDescriptor_t() + + check_error( + lib.infiniopCreateExpandDescriptor( + handle, + ctypes.byref(descriptor), + y_tensor.descriptor, + x_tensor.descriptor, + ) + ) + + for i in range(NUM_PRERUN if PROFILE else 1): + lib.infiniopExpand( + descriptor, y_tensor.data, x_tensor.data, None + ) + if PROFILE: + start_time = time.time() + for i in range(NUM_ITERATIONS): + lib.infiniopExpand( + descriptor, y_tensor.data, x_tensor.data, None + ) + elapsed = (time.time() - start_time) / NUM_ITERATIONS + print(f" lib time: {elapsed :6f}") + assert torch.allclose(y, ans, atol=0, rtol=1e-3) + check_error(lib.infiniopDestroyExpandDescriptor(descriptor)) + + +def test_cpu(lib, test_cases): + device = DeviceEnum.DEVICE_CPU + handle = create_handle(lib, device) + for y_shape, x_shape, y_stride, x_stride in test_cases: + test(lib, handle, "cpu", y_shape, x_shape, y_stride, x_stride, tensor_dtype=torch.float16) + test(lib, handle, "cpu", y_shape, x_shape, y_stride, x_stride, tensor_dtype=torch.float32) + destroy_handle(lib, handle) + + +def test_cuda(lib, test_cases): + device = DeviceEnum.DEVICE_CUDA + handle = create_handle(lib, device) + for y_shape, x_shape, y_stride, x_stride in test_cases: + test(lib, handle, "cuda", y_shape, x_shape, y_stride, x_stride, tensor_dtype=torch.float16) + test(lib, handle, "cuda", y_shape, x_shape, y_stride, x_stride, tensor_dtype=torch.float32) + destroy_handle(lib, handle) + + +def test_bang(lib, test_cases): + import torch_mlu + + device = DeviceEnum.DEVICE_BANG + handle = create_handle(lib, device) + for y_shape, x_shape, y_stride, x_stride in test_cases: + test(lib, handle, "mlu", y_shape, x_shape, y_stride, x_stride, tensor_dtype=torch.float16) + test(lib, handle, "mlu", y_shape, x_shape, y_stride, x_stride, tensor_dtype=torch.float32) + destroy_handle(lib, handle) + + +if __name__ == "__main__": + test_cases = [ + # y_shape, x_shape, y_stride, x_stride + ((), (), None, None), + # ((4, 2048), (2048,), (4096, 1), (1,)), + ((3, 3), (1,), None, None), + ((5, 4, 3), (4, 3,), None, (6, 1)), + ((99, 111), (111,), None, None), + ((2, 4, 3), (1, 3), None, None), + ((2, 20, 3), (2, 1, 3), None, None), + ((2, 3, 4, 5), (5,), None, None), + ((3, 2, 4, 5), (3, 2, 1, 1), None, None), + ((32, 256, 112, 112), (32, 256, 112, 1), None, None), + # ((32, 256, 112, 112), (32, 1, 1, 1), None, None), + # ((32, 150, 51200), (32, 150, 1), None, None), + ] + args = get_args() + lib = open_lib() + lib.infiniopCreateExpandDescriptor.restype = c_int32 + lib.infiniopCreateExpandDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopExpandDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + ] + lib.infiniopExpand.restype = c_int32 + lib.infiniopExpand.argtypes = [ + infiniopExpandDescriptor_t, + c_void_p, + c_void_p, + c_void_p, + ] + lib.infiniopDestroyExpandDescriptor.restype = c_int32 + lib.infiniopDestroyExpandDescriptor.argtypes = [ + infiniopExpandDescriptor_t, + ] + + if args.cpu: + test_cpu(lib, test_cases) + if args.cuda: + test_cuda(lib, test_cases) + if args.bang: + test_bang(lib, test_cases) + if not (args.cpu or args.cuda or args.bang): + test_cpu(lib, test_cases) + print("\033[92mTest passed!\033[0m") diff --git a/src/ops/expand/cpu/expand_cpu.cc 
b/src/ops/expand/cpu/expand_cpu.cc new file mode 100644 index 00000000..b5fe2698 --- /dev/null +++ b/src/ops/expand/cpu/expand_cpu.cc @@ -0,0 +1,67 @@ +#include "expand_cpu.h" +#include "../../../devices/cpu/common_cpu.h" +#include "../../utils.h" +#include + +infiniopStatus_t cpuCreateExpandDescriptor(infiniopHandle_t, + ExpandCpuDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y, + infiniopTensorDescriptor_t x) { + uint64_t ndim = y->ndim; + if (!isValidBroadcastShape(y, x)) { + return STATUS_BAD_TENSOR_SHAPE; + } + if (y->dt != x->dt) { + return STATUS_BAD_TENSOR_DTYPE; + } + + uint64_t y_data_size = std::accumulate(y->shape, y->shape + y->ndim, 1ULL, std::multiplies()); + + // get the adjusted strides for x in terms of y + int64_t *x_strides = new int64_t[ndim]; +#pragma omp parallel for + for (size_t i = 0; i < ndim; ++i) { + x_strides[i] = (i < ndim - x->ndim || y->shape[i] != x->shape[i + x->ndim - ndim]) ? 0 : x->strides[i + x->ndim - ndim]; + } + + *desc_ptr = new ExpandCpuDescriptor{ + DevCpu, + y->dt, + ndim, + y_data_size, + x_strides, + y->strides, + }; + + return STATUS_SUCCESS; +} + +infiniopStatus_t cpuDestroyExpandDescriptor(ExpandCpuDescriptor_t desc) { + delete[] desc->x_strides; + delete desc; + return STATUS_SUCCESS; +} + +template +infiniopStatus_t expand_cpu(ExpandCpuDescriptor_t desc, void *y, void const *x) { + auto x_ = reinterpret_cast(x); + auto y_ = reinterpret_cast(y); + +#pragma omp parallel for + for (uint64_t i = 0; i < desc->y_data_size; ++i) { + y_[i] = x_[getDstIndex(i, desc->ndim, desc->y_strides, desc->x_strides)]; + } + return STATUS_SUCCESS; +} + +infiniopStatus_t cpuExpand(ExpandCpuDescriptor_t desc, + void *y, void const *x, + void *stream) { + if (desc->dtype == F16) { + return expand_cpu(desc, y, x); + } + if (desc->dtype == F32) { + return expand_cpu(desc, y, x); + } + return STATUS_BAD_TENSOR_DTYPE; +} diff --git a/src/ops/expand/cpu/expand_cpu.h b/src/ops/expand/cpu/expand_cpu.h new file mode 100644 index 00000000..c1796dc3 --- /dev/null +++ b/src/ops/expand/cpu/expand_cpu.h @@ -0,0 +1,28 @@ +#ifndef __CPU_EXPAND_H__ +#define __CPU_EXPAND_H__ + +#include "operators.h" +#include + +struct ExpandCpuDescriptor { + Device device; + DT dtype; + uint64_t ndim; + uint64_t y_data_size; + int64_t const *x_strides; + int64_t const *y_strides; +}; + +typedef struct ExpandCpuDescriptor *ExpandCpuDescriptor_t; + +infiniopStatus_t cpuCreateExpandDescriptor(infiniopHandle_t, + ExpandCpuDescriptor_t *, + infiniopTensorDescriptor_t y, + infiniopTensorDescriptor_t x); + +infiniopStatus_t cpuExpand(ExpandCpuDescriptor_t desc, + void *y, void const *x, void *stream); + +infiniopStatus_t cpuDestroyExpandDescriptor(ExpandCpuDescriptor_t desc); + +#endif diff --git a/src/ops/expand/cuda/expand.cc b/src/ops/expand/cuda/expand.cc new file mode 100644 index 00000000..bd21b34c --- /dev/null +++ b/src/ops/expand/cuda/expand.cc @@ -0,0 +1,55 @@ +#include "expand.cuh" +#include "../../../devices/cuda/common_cuda.h" +#include "../../utils.h" + +infiniopStatus_t cudaCreateExpandDescriptor(CudaHandle_t handle, + ExpandCudaDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y, + infiniopTensorDescriptor_t x) { + uint64_t ndim = y->ndim; + if (!isValidBroadcastShape(y, x)) { + return STATUS_BAD_TENSOR_SHAPE; + } + if (y->dt != x->dt) { + return STATUS_BAD_TENSOR_DTYPE; + } + + uint64_t y_data_size = std::accumulate(y->shape, y->shape + y->ndim, 1ULL, std::multiplies()); + + // get the adjusted strides for x in terms of y + int64_t *x_strides = new int64_t[ndim]; + for 
(size_t i = 0; i < ndim; ++i) { + x_strides[i] = (i < ndim - x->ndim || y->shape[i] != x->shape[i + x->ndim - ndim]) ? 0 : x->strides[i + x->ndim - ndim]; + } + + cudaDeviceProp prop; + cudaGetDeviceProperties(&prop, handle->device_id); + + int64_t *x_strides_d, *y_strides_d; + checkCudaErrorWithCode(cudaMalloc(&x_strides_d, ndim * sizeof(int64_t)), STATUS_MEMORY_NOT_ALLOCATED); + checkCudaErrorWithCode(cudaMalloc(&y_strides_d, ndim * sizeof(int64_t)), STATUS_MEMORY_NOT_ALLOCATED); + checkCudaErrorWithCode(cudaMemcpy(x_strides_d, x_strides, ndim * sizeof(int64_t), cudaMemcpyHostToDevice), STATUS_EXECUTION_FAILED); + checkCudaErrorWithCode(cudaMemcpy(y_strides_d, y->strides, ndim * sizeof(int64_t), cudaMemcpyHostToDevice), STATUS_EXECUTION_FAILED); + + *desc_ptr = new ExpandCudaDescriptor{ + DevNvGpu, + y->dt, + handle->device_id, + ndim, + y_data_size, + static_cast(prop.maxGridSize[0]), + x_strides_d, + y_strides_d, + }; + + delete[] x_strides; + + return STATUS_SUCCESS; +} + +infiniopStatus_t cudaDestroyExpandDescriptor(ExpandCudaDescriptor_t desc) { + cudaFree((void *) desc->x_strides); + cudaFree((void *) desc->y_strides); + delete desc; + return STATUS_SUCCESS; +} diff --git a/src/ops/expand/cuda/expand.cu b/src/ops/expand/cuda/expand.cu new file mode 100644 index 00000000..a879fb20 --- /dev/null +++ b/src/ops/expand/cuda/expand.cu @@ -0,0 +1,53 @@ +#include "../../../devices/cuda/common_cuda.h" +#include "../../utils.h" +#include "expand.cuh" + +template +__global__ void expand( + Tdata *y, + const Tdata *x, + const int64_t *y_strides, + const int64_t *x_strides, + uint64_t y_data_size, + uint64_t ndim, + uint64_t offset) { + uint64_t idx = blockIdx.x * blockDim.x + threadIdx.x + offset; + + if (idx < y_data_size) { + y[idx] = x[getDstIndex(idx, ndim, y_strides, x_strides)]; + } +} + +template +infiniopStatus_t expand_nv_gpu(ExpandCudaDescriptor_t desc, void *y, void const *x, void *stream) { + if (desc->y_data_size == 0) { + return STATUS_SUCCESS; + } + dim3 blockDims = dim3(std::min(static_cast(256), desc->y_data_size)); + dim3 gridDims = dim3(std::min(ROUND_UP_DIV(desc->y_data_size, blockDims.x), desc->max_grid_size)); + uint64_t step = gridDims.x * blockDims.x; + + const auto x_ = reinterpret_cast(x); + const auto y_ = reinterpret_cast(y); + cudaStream_t cuda_stream = reinterpret_cast(stream); + +#pragma unroll + for (uint64_t i = 0; i < desc->y_data_size; i += step) { + expand<<>>( + y_, x_, desc->y_strides, desc->x_strides, i + desc->y_data_size, desc->ndim, i); + } + return STATUS_SUCCESS; +} + +infiniopStatus_t cudaExpand(ExpandCudaDescriptor_t desc, + void *y, void const *x, + void *stream) { + checkCudaError(cudaSetDevice(desc->device_id)); + if (desc->dtype == F16) { + return expand_nv_gpu(desc, y, x, stream); + } + if (desc->dtype == F32) { + return expand_nv_gpu(desc, y, x, stream); + } + return STATUS_BAD_TENSOR_DTYPE; +} diff --git a/src/ops/expand/cuda/expand.cuh b/src/ops/expand/cuda/expand.cuh new file mode 100644 index 00000000..2f18a82f --- /dev/null +++ b/src/ops/expand/cuda/expand.cuh @@ -0,0 +1,34 @@ +#ifndef __CUDA_EXPAND_H__ +#define __CUDA_EXPAND_H__ + +#include "../../../devices/cuda/common_cuda.h" +#include "../../../devices/cuda/cuda_handle.h" +#include "operators.h" +#include +#include + +struct ExpandCudaDescriptor { + Device device; + DT dtype; + int device_id; + uint64_t ndim; + uint64_t y_data_size; + uint64_t max_grid_size; + int64_t const *x_strides; + int64_t const *y_strides; +}; + +typedef struct ExpandCudaDescriptor *ExpandCudaDescriptor_t; + 
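The x_strides and y_strides members above are what drive the broadcast copy: when the descriptor is created, every dimension along which x is broadcast gets its x stride forced to 0, so the flat-index mapping helper can turn a position in y into an offset in x that simply ignores the broadcast dimensions. Below is a minimal, self-contained sketch of that same index walk with made-up shapes (it mirrors the getDstIndex loop rather than calling the library):

    // Illustration only: y shape (2, 3), contiguous  -> y_strides = {3, 1}
    //                    x shape (3,) broadcast to y -> adjusted x_strides = {0, 1}
    #include <cstdint>
    #include <cstdio>

    int main() {
        const int64_t y_strides[] = {3, 1};
        const int64_t x_strides[] = {0, 1};  // stride 0 marks the broadcast dimension
        uint64_t flat = 4;                   // flat index of y element (1, 1)
        uint64_t x_offset = 0;
        for (int i = 0; i < 2; ++i) {
            x_offset += flat / y_strides[i] * x_strides[i];
            flat %= y_strides[i];
        }
        std::printf("%llu\n", (unsigned long long) x_offset);  // prints 1: y[1][1] is filled from x[1]
        return 0;
    }

Note that walking a flat index through y_strides like this assumes y itself is contiguous; a later patch in this series adds a shape-based offset helper so the CUDA path can also write into noncontiguous y tensors.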
+infiniopStatus_t cudaCreateExpandDescriptor(CudaHandle_t, + ExpandCudaDescriptor_t *, + infiniopTensorDescriptor_t y, + infiniopTensorDescriptor_t x); + +infiniopStatus_t cudaExpand(ExpandCudaDescriptor_t desc, + void *y, void const *x, + void *stream); + +infiniopStatus_t cudaDestroyExpandDescriptor(ExpandCudaDescriptor_t desc); + +#endif diff --git a/src/ops/expand/operator.cc b/src/ops/expand/operator.cc new file mode 100644 index 00000000..0572acd0 --- /dev/null +++ b/src/ops/expand/operator.cc @@ -0,0 +1,72 @@ +#include "../utils.h" +#include "operators.h" +#include "ops/expand/expand.h" + +#ifdef ENABLE_CPU +#include "cpu/expand_cpu.h" +#endif +#ifdef ENABLE_NV_GPU +#include "../../devices/cuda/cuda_handle.h" +#include "cuda/expand.cuh" +#endif + +__C infiniopStatus_t infiniopCreateExpandDescriptor( + infiniopHandle_t handle, + infiniopExpandDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y, + infiniopTensorDescriptor_t x) { + switch (handle->device) { +#ifdef ENABLE_CPU + case DevCpu: + return cpuCreateExpandDescriptor(handle, (ExpandCpuDescriptor_t *) desc_ptr, y, x); +#endif +#ifdef ENABLE_NV_GPU + case DevNvGpu: { + return cudaCreateExpandDescriptor((CudaHandle_t) handle, (ExpandCudaDescriptor_t *) desc_ptr, y, x); + } + +#endif +#ifdef ENABLE_CAMBRICON_MLU + // TODO +#endif + } + return STATUS_BAD_DEVICE; +} + +__C infiniopStatus_t infiniopExpand(infiniopExpandDescriptor_t desc, void *y, void const *x, void *stream) { + switch (desc->device) { +#ifdef ENABLE_CPU + case DevCpu: + return cpuExpand((ExpandCpuDescriptor_t) desc, y, x, stream); +#endif +#ifdef ENABLE_NV_GPU + case DevNvGpu: { + return cudaExpand((ExpandCudaDescriptor_t) desc, y, x, stream); + } + +#endif +#ifdef ENABLE_CAMBRICON_MLU + // TODO +#endif + } + return STATUS_BAD_DEVICE; +} + +__C infiniopStatus_t infiniopDestroyExpandDescriptor(infiniopExpandDescriptor_t desc) { + switch (desc->device) { +#ifdef ENABLE_CPU + case DevCpu: + return cpuDestroyExpandDescriptor((ExpandCpuDescriptor_t) desc); +#endif +#ifdef ENABLE_NV_GPU + case DevNvGpu: { + return cudaDestroyExpandDescriptor((ExpandCudaDescriptor_t) desc); + } + +#endif +#ifdef ENABLE_CAMBRICON_MLU + // TODO +#endif + } + return STATUS_BAD_DEVICE; +} From 65724c1680ea31e492641908276380f3e74286ce Mon Sep 17 00:00:00 2001 From: lizimin Date: Thu, 31 Oct 2024 15:36:57 +0800 Subject: [PATCH 02/10] Add fp32 support for matmul, move getDstIndex to common utils --- operatorspy/tests/matmul.py | 58 ++++++++++++++++++++++-------- src/devices/cpu/common_cpu.cc | 9 +++++ src/devices/cpu/common_cpu.h | 5 ++- src/devices/cuda/common_cuda.h | 10 ++++++ src/ops/add/cuda/add.cu | 10 ------ src/ops/matmul/cpu/matmul_cpu.cc | 53 ++++++++++++++++----------- src/ops/matmul/cpu/matmul_cpu.h | 2 -- src/ops/matmul/cuda/matmul_cuda.cc | 17 +-------- src/ops/matmul/cuda/matmul_cuda.cu | 50 +++++++++++++++++++++----- src/ops/matmul/cuda/matmul_cuda.h | 2 +- src/ops/utils.h | 16 +++++++++ 11 files changed, 157 insertions(+), 75 deletions(-) diff --git a/operatorspy/tests/matmul.py b/operatorspy/tests/matmul.py index c625f1ce..45a1fb9b 100644 --- a/operatorspy/tests/matmul.py +++ b/operatorspy/tests/matmul.py @@ -2,6 +2,7 @@ import ctypes import sys import os +import time sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))) from operatorspy import ( @@ -21,6 +22,13 @@ from operatorspy.tests.test_utils import get_args import torch +# constant for control whether profile the pytorch and lib functions +# NOTE: need to manually add synchronization 
function to the lib function, +# e.g., cudaDeviceSynchronize() for CUDA +PROFILE = False +NUM_PRERUN = 10 +NUM_ITERATIONS = 1000 + class MatmulDescriptor(Structure): _fields_ = [("device", c_int32)] @@ -30,10 +38,13 @@ class MatmulDescriptor(Structure): def matmul(c, beta, a, b, alpha): input_dtype = c.dtype - return ( + ans = ( alpha * torch.matmul(a.to(torch.float32), b.to(torch.float32)).to(input_dtype) + beta * c ) + if PROFILE: + torch.cuda.synchronize() + return ans def test( @@ -66,7 +77,15 @@ def test( if c_stride is not None: c = rearrange_tensor(c, c_stride) - ans = matmul(c, beta, a, b, alpha) + for i in range(NUM_PRERUN if PROFILE else 1): + ans = matmul(c, beta, a, b, alpha) + if PROFILE: + start_time = time.time() + for i in range(NUM_ITERATIONS): + _ = matmul(c, beta, a, b, alpha) + elapsed = (time.time() - start_time) / NUM_ITERATIONS + print(f"pytorch time: {elapsed :6f}") + a_tensor = to_tensor(a, lib) b_tensor = to_tensor(b, lib) @@ -90,7 +109,8 @@ def test( ) workspace = create_workspace(workspace_size.value, a.device) - check_error( + for i in range(NUM_PRERUN if PROFILE else 1): + check_error( lib.infiniopMatmul( descriptor, workspace.data_ptr() if workspace is not None else None, @@ -101,6 +121,20 @@ def test( None, ) ) + if PROFILE: + start_time = time.time() + for i in range(NUM_ITERATIONS): + lib.infiniopMatmul( + descriptor, + workspace.data_ptr() if workspace is not None else None, + workspace_size.value, + c_tensor.data, + a_tensor.data, + b_tensor.data, + None, + ) + elapsed = (time.time() - start_time) / NUM_ITERATIONS + print(f" lib time: {elapsed :6f}") assert torch.allclose(c, ans, atol=0, rtol=1e-2) @@ -211,17 +245,11 @@ def test_bang(lib, test_cases): test_cases = [ # alpha, beta, a_shape, b_shape, c_shape, a_stride, b_stride, c_stride, dtype (1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), None, None, None, torch.float16), - ( - 1.0, - 0.0, - (1, 2048), - (2048, 2048), - (1, 2048), - (4096, 1), - (4096, 1), - (4096, 1), - torch.float16, - ), + (1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), None, None, None, torch.float32), + (1.0, 0.0, (2, 4, 2048), (2, 2048, 2048), (2, 4, 2048), None, None, None, torch.float16), + (1.0, 0.0, (2, 4, 2048), (2, 2048, 2048), (2, 4, 2048), None, None, None, torch.float32), + (1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), (4096, 1), (4096, 1), (4096, 1), torch.float16), + (1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), (4096, 1), (4096, 1), (4096, 1), torch.float32), ] args = get_args() lib = open_lib() @@ -267,4 +295,4 @@ def test_bang(lib, test_cases): test_bang(lib, test_cases) if not (args.cpu or args.cuda or args.bang): test_cpu(lib, test_cases) - print("Test passed!") + print("\033[92mTest passed!\033[0m") diff --git a/src/devices/cpu/common_cpu.cc b/src/devices/cpu/common_cpu.cc index c89c7491..685b2a23 100644 --- a/src/devices/cpu/common_cpu.cc +++ b/src/devices/cpu/common_cpu.cc @@ -65,3 +65,12 @@ uint16_t f32_to_f16(float val) { return sign; } } + +uint64_t getDstIndex(uint64_t flat_index, uint64_t ndim, int64_t const *src_strides, int64_t const *dst_strides) { + uint64_t res = 0; + for (uint64_t i = 0; i < ndim; ++i) { + res += flat_index / src_strides[i] * dst_strides[i]; + flat_index %= src_strides[i]; + } + return res; +} diff --git a/src/devices/cpu/common_cpu.h b/src/devices/cpu/common_cpu.h index 20f1a2d8..f5c770ab 100644 --- a/src/devices/cpu/common_cpu.h +++ b/src/devices/cpu/common_cpu.h @@ -15,4 +15,7 @@ float f16_to_f32(uint16_t code); // convert single-precision float to half-precision float uint16_t 
f32_to_f16(float val); -#endif // __COMMON_CPU_H__ +// get the corresponding index in the destination given the flat index of the source +uint64_t getDstIndex(uint64_t flat_index, uint64_t ndim, int64_t const *src_strides, int64_t const *dst_strides); + +#endif// __COMMON_CPU_H__ diff --git a/src/devices/cuda/common_cuda.h b/src/devices/cuda/common_cuda.h index fa89e6c6..0c23aa68 100644 --- a/src/devices/cuda/common_cuda.h +++ b/src/devices/cuda/common_cuda.h @@ -54,4 +54,14 @@ typedef struct DataLayoutMap { constexpr DTMap dataTypeMap; +// get the corresponding index in the destination given the flat index of the source +inline __device__ uint64_t getDstIndex(uint64_t flat_index, uint64_t ndim, int64_t const *src_strides, int64_t const *dst_strides) { + uint64_t res = 0; + for (uint64_t i = 0; i < ndim; ++i) { + res += flat_index / src_strides[i] * dst_strides[i]; + flat_index %= src_strides[i]; + } + return res; +} + #endif// __COMMON_CUDA_H__ diff --git a/src/ops/add/cuda/add.cu b/src/ops/add/cuda/add.cu index 6c1dfec4..087db878 100644 --- a/src/ops/add/cuda/add.cu +++ b/src/ops/add/cuda/add.cu @@ -35,16 +35,6 @@ struct vecN { } }; -// get the corresponding index in the destination given the flat index of the source -__device__ uint64_t getDstIndex(uint64_t flat_index, uint64_t ndim, int64_t const *src_strides, int64_t const *dst_strides) { - uint64_t res = 0; - for (uint64_t i = 0; i < ndim; ++i) { - res += flat_index / src_strides[i] * dst_strides[i]; - flat_index %= src_strides[i]; - } - return res; -} - template __global__ void add( Tdata *c, diff --git a/src/ops/matmul/cpu/matmul_cpu.cc b/src/ops/matmul/cpu/matmul_cpu.cc index 88ced7a1..b6148852 100644 --- a/src/ops/matmul/cpu/matmul_cpu.cc +++ b/src/ops/matmul/cpu/matmul_cpu.cc @@ -12,7 +12,7 @@ infiniopStatus_t cpuCreateMatmulDescriptor(CpuHandle_t handle, float beta) { DT dtype = c_desc->dt; - if (!dtype_eq(dtype, F16)) { + if (dtype != F16 && dtype != F32) { return STATUS_BAD_TENSOR_DTYPE; } @@ -31,20 +31,6 @@ infiniopStatus_t cpuCreateMatmulDescriptor(CpuHandle_t handle, return STATUS_SUCCESS; } -infiniopStatus_t cpuMatmul(MatmulCpuDescriptor_t desc, - void *workspace, - uint64_t workspace_size, - void *c, - void const *a, - void const *b) { - if (dtype_eq(desc->dtype, F16)) { - matmul_cpu_f16(desc, c, desc->beta, a, b, desc->alpha); - return STATUS_SUCCESS; - } - - return STATUS_BAD_TENSOR_DTYPE; -} - infiniopStatus_t cpuGetMatmulWorkspaceSize(MatmulCpuDescriptor_t desc, uint64_t *size) { *size = 0; return STATUS_SUCCESS; @@ -55,7 +41,8 @@ infiniopStatus_t cpuDestroyMatmulDescriptor(MatmulCpuDescriptor_t desc) { return STATUS_SUCCESS; } -void matmul_cpu_f16(MatmulCpuDescriptor_t desc, void *c, float beta, void const *a, void const *b, float alpha) { +template +infiniopStatus_t matmul_cpu(MatmulCpuDescriptor_t desc, void *c, float beta, void const *a, void const *b, float alpha) { auto info = desc->info; if (info.is_transed) { @@ -65,15 +52,39 @@ void matmul_cpu_f16(MatmulCpuDescriptor_t desc, void *c, float beta, void const for (int i = 0; i < info.batch; ++i) { for (int m_ = 0; m_ < info.m; ++m_) { for (int n_ = 0; n_ < info.n; ++n_) { - auto c_ = reinterpret_cast(c) + i * info.c_matrix.stride + m_ * info.c_matrix.row_stride + n_ * info.c_matrix.col_stride; + auto c_ = reinterpret_cast(c) + i * info.c_matrix.stride + m_ * info.c_matrix.row_stride + n_ * info.c_matrix.col_stride; float sum = 0; for (int k_ = 0; k_ < info.k; ++k_) { - auto a_ = reinterpret_cast(a) + i * info.a_matrix.stride + m_ * info.a_matrix.row_stride + k_ 
* info.a_matrix.col_stride; - auto b_ = reinterpret_cast(b) + i * info.b_matrix.stride + n_ * info.b_matrix.col_stride + k_ * info.b_matrix.row_stride; - sum += f16_to_f32(*a_) * f16_to_f32(*b_); + auto a_ = reinterpret_cast(a) + i * info.a_matrix.stride + m_ * info.a_matrix.row_stride + k_ * info.a_matrix.col_stride; + auto b_ = reinterpret_cast(b) + i * info.b_matrix.stride + n_ * info.b_matrix.col_stride + k_ * info.b_matrix.row_stride; + if constexpr (std::is_same::value) { + sum += f16_to_f32(*a_) * f16_to_f32(*b_); + } else { + sum += *a_ * (*b_); + } + } + if constexpr (std::is_same::value) { + *c_ = f32_to_f16(beta * f16_to_f32(*c_) + alpha * sum); + } else { + *c_ = beta * (*c_) + alpha * sum; } - *c_ = f32_to_f16(beta * f16_to_f32(*c_) + alpha * sum); } } } + return STATUS_SUCCESS; +} + +infiniopStatus_t cpuMatmul(MatmulCpuDescriptor_t desc, + void *workspace, + uint64_t workspace_size, + void *c, + void const *a, + void const *b) { + if (desc->dtype == F16) { + return matmul_cpu(desc, c, desc->beta, a, b, desc->alpha); + } + if (desc->dtype == F32) { + return matmul_cpu(desc, c, desc->beta, a, b, desc->alpha); + } + return STATUS_BAD_TENSOR_DTYPE; } diff --git a/src/ops/matmul/cpu/matmul_cpu.h b/src/ops/matmul/cpu/matmul_cpu.h index fcbd4c50..3a5970e8 100644 --- a/src/ops/matmul/cpu/matmul_cpu.h +++ b/src/ops/matmul/cpu/matmul_cpu.h @@ -34,6 +34,4 @@ infiniopStatus_t cpuMatmul(MatmulCpuDescriptor_t desc, infiniopStatus_t cpuDestroyMatmulDescriptor(MatmulCpuDescriptor_t desc); -void matmul_cpu_f16(MatmulCpuDescriptor_t desc, void *c, float beta, void const *a, void const *b, float alpha); - #endif// __CPU_MATMUL_H__ diff --git a/src/ops/matmul/cuda/matmul_cuda.cc b/src/ops/matmul/cuda/matmul_cuda.cc index 71f66cf6..8bac48d4 100644 --- a/src/ops/matmul/cuda/matmul_cuda.cc +++ b/src/ops/matmul/cuda/matmul_cuda.cc @@ -11,7 +11,7 @@ infiniopStatus_t cudaCreateMatmulDescriptor(CudaHandle_t handle, float beta) { DT dtype = c_desc->dt; - if (!dtype_eq(dtype, F16)) { + if (dtype != F16 && dtype != F32) { return STATUS_BAD_TENSOR_DTYPE; } @@ -32,21 +32,6 @@ infiniopStatus_t cudaCreateMatmulDescriptor(CudaHandle_t handle, return STATUS_SUCCESS; } -infiniopStatus_t cudaMatmul(MatmulCudaDescriptor_t desc, - void *workspace, - uint64_t workspace_size, - void *c, - void const *a, - void const *b, - void *stream) { - if (dtype_eq(desc->dtype, F16)) { - matmul_cuda_f16(desc, c, desc->beta, a, b, desc->alpha, stream); - return STATUS_SUCCESS; - } - - return STATUS_BAD_TENSOR_DTYPE; -} - infiniopStatus_t cudaGetMatmulWorkspaceSize(MatmulCudaDescriptor_t desc, uint64_t *size) { *size = 0; return STATUS_SUCCESS; diff --git a/src/ops/matmul/cuda/matmul_cuda.cu b/src/ops/matmul/cuda/matmul_cuda.cu index 32d0cf74..1dc93430 100644 --- a/src/ops/matmul/cuda/matmul_cuda.cu +++ b/src/ops/matmul/cuda/matmul_cuda.cu @@ -5,15 +5,29 @@ #include #include -void matmul_cuda_f16(MatmulCudaDescriptor_t desc, void *c, float beta, void const *a, void const *b, float alpha, void *stream) { +template +infiniopStatus_t matmul_cuda(MatmulCudaDescriptor_t desc, void *c, float beta, void const *a, void const *b, float alpha, void *stream) { auto info = desc->info; if (info.is_transed) { std::swap(a, b); } - auto alpha_f16 = __float2half(alpha); - auto beta_f16 = __float2half(beta); + Tdata alpha_, beta_; + cudaDataType a_type, b_type, c_type; + cublasComputeType_t compute_type; + + if constexpr (std::is_same::value) { + alpha_ = __float2half(alpha); + beta_ = __float2half(beta); + a_type = b_type = c_type = CUDA_R_16F; + 
compute_type = CUBLAS_COMPUTE_16F; + } else { + alpha_ = alpha; + beta_ = beta; + a_type = b_type = c_type = CUDA_R_32F; + compute_type = CUBLAS_COMPUTE_32F_FAST_16F; + } auto op_a = info.a_matrix.row_stride == 1 ? CUBLAS_OP_N : CUBLAS_OP_T; auto op_b = info.b_matrix.row_stride == 1 ? CUBLAS_OP_N : CUBLAS_OP_T; @@ -26,21 +40,39 @@ void matmul_cuda_f16(MatmulCudaDescriptor_t desc, void *c, float beta, void cons info.m, info.n, info.k, - &alpha_f16, + &alpha_, a, - CUDA_R_16F, + a_type, info.a_matrix.ld(), info.a_matrix.stride, b, - CUDA_R_16F, + b_type, info.b_matrix.ld(), info.b_matrix.stride, - &beta_f16, + &beta_, c, - CUDA_R_16F, + c_type, info.c_matrix.ld(), info.c_matrix.stride, info.batch, - CUBLAS_COMPUTE_16F, + compute_type, CUBLAS_GEMM_DEFAULT_TENSOR_OP); }); + cudaDeviceSynchronize(); + return STATUS_SUCCESS; } + +infiniopStatus_t cudaMatmul(MatmulCudaDescriptor_t desc, + void *workspace, + uint64_t workspace_size, + void *c, + void const *a, + void const *b, + void *stream) { + if (desc->dtype == F16) { + return matmul_cuda(desc, c, desc->beta, a, b, desc->alpha, stream); + } + if (desc->dtype == F32) { + return matmul_cuda(desc, c, desc->beta, a, b, desc->alpha, stream); + } + return STATUS_BAD_TENSOR_DTYPE; +} \ No newline at end of file diff --git a/src/ops/matmul/cuda/matmul_cuda.h b/src/ops/matmul/cuda/matmul_cuda.h index 671ac14c..f13531e8 100644 --- a/src/ops/matmul/cuda/matmul_cuda.h +++ b/src/ops/matmul/cuda/matmul_cuda.h @@ -1,8 +1,8 @@ #ifndef __CUDA_MATMUL_H__ #define __CUDA_MATMUL_H__ -#include "../blas.h" #include "../../../devices/cuda/cuda_handle.h" +#include "../blas.h" #include "operators.h" #include diff --git a/src/ops/utils.h b/src/ops/utils.h index fd2afcf0..ad2b65cc 100644 --- a/src/ops/utils.h +++ b/src/ops/utils.h @@ -101,6 +101,22 @@ inline bool isValidBroadcastShape(infiniopTensorDescriptor_t a, infiniopTensorDe return std::equal(broadcast_shape, broadcast_shape + broadcast_ndim, c->shape); } +// check if the shape of tensor src can be validly broadcasted to that of the tensor dst +inline bool isValidBroadcastShape(infiniopTensorDescriptor_t dst, infiniopTensorDescriptor_t src) { + if (dst->ndim < src->ndim) { + return false; + } + uint64_t padded_shape[dst->ndim]; + std::fill(padded_shape, padded_shape + dst->ndim, 1); + std::copy(src->shape, src->shape + src->ndim, padded_shape + dst->ndim - src->ndim); + for (size_t i = 0; i < dst->ndim; ++i) { + if (padded_shape[i] != dst->shape[i] && padded_shape[i] != 1) { + return false; + } + } + return true; +} + // check if the shape of tensor c is valid after broadcasting tensors a and b inline bool isValidBroadcastShape(infiniopTensorDescriptor_t a, infiniopTensorDescriptor_t b, infiniopTensorDescriptor_t c) { uint64_t broadcast_ndim = std::max(a->ndim, b->ndim); From ab31840217eefbd55b3e1115def0a19d035fbff4 Mon Sep 17 00:00:00 2001 From: lizimin Date: Thu, 31 Oct 2024 15:38:36 +0800 Subject: [PATCH 03/10] Add GEMM operator --- include/infini_operators.h | 2 + include/ops/gemm/gemm.h | 36 ++++ operatorspy/tests/gemm.py | 339 +++++++++++++++++++++++++++++++++++++ src/ops/gemm/operator.cc | 85 ++++++++++ 4 files changed, 462 insertions(+) create mode 100644 include/ops/gemm/gemm.h create mode 100644 operatorspy/tests/gemm.py create mode 100644 src/ops/gemm/operator.cc diff --git a/include/infini_operators.h b/include/infini_operators.h index ca076d79..5031d011 100644 --- a/include/infini_operators.h +++ b/include/infini_operators.h @@ -2,6 +2,8 @@ #include "ops/add/add.h" #include "ops/attention/attention.h" 
#include "ops/causal_softmax/causal_softmax.h" +#include "ops/expand/expand.h" +#include "ops/gemm/gemm.h" #include "ops/matmul/matmul.h" #include "ops/mlp/mlp.h" #include "ops/random_sample/random_sample.h" diff --git a/include/ops/gemm/gemm.h b/include/ops/gemm/gemm.h new file mode 100644 index 00000000..4a39da39 --- /dev/null +++ b/include/ops/gemm/gemm.h @@ -0,0 +1,36 @@ +#ifndef GEMM_H +#define GEMM_H + +#include "../../export.h" +#include "../../operators.h" + +typedef struct GEMMDescriptor { + Device device; +} GEMMDescriptor; + +typedef GEMMDescriptor *infiniopGEMMDescriptor_t; + +__C __export infiniopStatus_t infiniopCreateGEMMDescriptor(infiniopHandle_t handle, + infiniopGEMMDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t b_desc, + infiniopTensorDescriptor_t c_desc, + float alpha, + float beta, + bool transA, + bool transB); + +__C __export infiniopStatus_t infiniopGetGEMMWorkspaceSize(infiniopGEMMDescriptor_t desc, uint64_t *size); + +__C __export infiniopStatus_t infiniopGEMM(infiniopGEMMDescriptor_t desc, + void *workspace, + uint64_t workspace_size, + void *y, + void const *a, + void const *b, + void const *c, + void *stream); + +__C __export infiniopStatus_t infiniopDestroyGEMMDescriptor(infiniopGEMMDescriptor_t desc); +#endif diff --git a/operatorspy/tests/gemm.py b/operatorspy/tests/gemm.py new file mode 100644 index 00000000..402a3d9b --- /dev/null +++ b/operatorspy/tests/gemm.py @@ -0,0 +1,339 @@ +from ctypes import POINTER, Structure, c_int32, c_uint64, c_void_p, c_float, c_bool +import ctypes +import sys +import os +import time + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))) +from operatorspy import ( + open_lib, + to_tensor, + DeviceEnum, + infiniopHandle_t, + infiniopTensorDescriptor_t, + create_handle, + destroy_handle, + check_error, + rearrange_tensor, +) + +from operatorspy.tests.test_utils import get_args +import torch + +# constant for control whether profile the pytorch and lib functions +# NOTE: need to manually add synchronization function to the lib function, +# e.g., cudaDeviceSynchronize() for CUDA +PROFILE = False +NUM_PRERUN = 10 +NUM_ITERATIONS = 1000 + +class GEMMDescriptor(Structure): + _fields_ = [("device", c_int32)] + + +infiniopGEMMDescriptor_t = POINTER(GEMMDescriptor) + + +def gemm(A, B, C=None, transA=False, transB=False, alpha=1.0, beta=0.0, dtype=torch.float32): + A = A.T if transA else A + B = B.T if transB else B + result = alpha * torch.matmul(A if dtype != torch.float16 else A.to(torch.float32), B if dtype != torch.float16 else B.to(torch.float32)).to(dtype) + if C is not None: + result += beta * C if dtype != torch.float16 else C.to(torch.float32) + if PROFILE: + torch.cuda.synchronize() + return result + + +def test( + lib, + handle, + torch_device, + alpha, + beta, + transA, + transB, + a_shape, + b_shape, + c_shape, + y_shape, + a_stride=None, + b_stride=None, + c_stride=None, + y_stride=None, + dtype=torch.float16, +): + print( + f"Testing GEMM on {torch_device} with transA: {transA} transB: {transB} " + f"a_shape:{a_shape} b_shape:{b_shape} c_shape:{c_shape} y_shape:{y_shape} " + f"a_stride:{a_stride} b_stride:{b_stride} c_stride:{c_stride} y_stride:{y_stride} dtype:{dtype}" + ) + + a = torch.ones(a_shape, dtype=dtype).to(torch_device) + b = torch.ones(b_shape, dtype=dtype).to(torch_device) + c = torch.ones(c_shape, dtype=dtype).to(torch_device) + y = torch.zeros(y_shape, dtype=dtype).to(torch_device) + + if 
a_stride is not None: + a = rearrange_tensor(a, a_stride) + if b_stride is not None: + b = rearrange_tensor(b, b_stride) + if c_stride is not None: + c = rearrange_tensor(c, c_stride) + if y_stride is not None: + y = rearrange_tensor(y, y_stride) + + for i in range(NUM_PRERUN if PROFILE else 1): + ans = gemm(a, b, c, transA, transB, alpha, beta, dtype) + if PROFILE: + start_time = time.time() + for i in range(NUM_ITERATIONS): + _ = gemm(a, b, c, transA, transB, alpha, beta, dtype) + elapsed = (time.time() - start_time) / NUM_ITERATIONS + print(f"pytorch time: {elapsed :6f}") + + + a_tensor = to_tensor(a, lib) + b_tensor = to_tensor(b, lib) + c_tensor = to_tensor(c, lib) + y_tensor = to_tensor(y, lib) + descriptor = infiniopGEMMDescriptor_t() + check_error( + lib.infiniopCreateGEMMDescriptor( + handle, + ctypes.byref(descriptor), + y_tensor.descriptor, + a_tensor.descriptor, + b_tensor.descriptor, + c_tensor.descriptor, + alpha, + beta, + transA, + transB, + ) + ) + + workspace_size = ctypes.c_uint64(0) + check_error( + lib.infiniopGetGEMMWorkspaceSize( + descriptor, ctypes.byref(workspace_size) + ) + ) + workspace = torch.zeros(int(workspace_size.value), dtype=torch.uint8).to( + torch_device + ) + workspace_ptr = ctypes.cast(workspace.data_ptr(), ctypes.POINTER(ctypes.c_uint8)) + + for i in range(NUM_PRERUN if PROFILE else 2): + check_error( + lib.infiniopGEMM( + descriptor, + workspace_ptr, + workspace_size, + y_tensor.data, + a_tensor.data, + b_tensor.data, + c_tensor.data, + None, + ) + ) + if PROFILE: + start_time = time.time() + for i in range(NUM_ITERATIONS): + lib.infiniopGEMM( + descriptor, + workspace_ptr, + workspace_size, + y_tensor.data, + a_tensor.data, + b_tensor.data, + c_tensor.data, + None, + ) + elapsed = (time.time() - start_time) / NUM_ITERATIONS + print(f" lib time: {elapsed :6f}") + + # print(" - y:\n", y, y.shape, "\n - ans:\n", ans, ans.shape) + assert torch.allclose(y, ans, atol=0, rtol=1e-2) + check_error(lib.infiniopDestroyGEMMDescriptor(descriptor)) + + +def test_cpu(lib, test_cases): + device = DeviceEnum.DEVICE_CPU + handle = create_handle(lib, device) + for ( + alpha, + beta, + transA, + transB, + a_shape, + b_shape, + c_shape, + y_shape, + a_stride, + b_stride, + c_stride, + y_stride, + ) in test_cases: + test(lib, handle, "cpu", alpha, beta, transA, transB, a_shape, b_shape, c_shape, y_shape, a_stride, b_stride, c_stride, y_stride, dtype=torch.float16) + test(lib, handle, "cpu", alpha, beta, transA, transB, a_shape, b_shape, c_shape, y_shape, a_stride, b_stride, c_stride, y_stride, dtype=torch.float32) + destroy_handle(lib, handle) + + +def test_cuda(lib, test_cases): + device = DeviceEnum.DEVICE_CUDA + handle = create_handle(lib, device) + for ( + alpha, + beta, + transA, + transB, + a_shape, + b_shape, + c_shape, + y_shape, + a_stride, + b_stride, + c_stride, + y_stride, + ) in test_cases: + test(lib, handle, "cuda", alpha, beta, transA, transB, a_shape, b_shape, c_shape, y_shape, a_stride, b_stride, c_stride, y_stride, dtype=torch.float16) + test(lib, handle, "cuda", alpha, beta, transA, transB, a_shape, b_shape, c_shape, y_shape, a_stride, b_stride, c_stride, y_stride, dtype=torch.float32) + destroy_handle(lib, handle) + + +def test_bang(lib, test_cases): + import torch_mlu + + device = DeviceEnum.DEVICE_BANG + handle = create_handle(lib, device) + + for ( + alpha, + beta, + transA, + transB, + a_shape, + b_shape, + c_shape, + y_shape, + a_stride, + b_stride, + c_stride, + y_stride, + ) in test_cases: + test(lib, handle, "mlu", alpha, beta, transA, 
transB, a_shape, b_shape, c_shape, y_shape, a_stride, b_stride, c_stride, y_stride, dtype=torch.float16) + test(lib, handle, "mlu", alpha, beta, transA, transB, a_shape, b_shape, c_shape, y_shape, a_stride, b_stride, c_stride, y_stride, dtype=torch.float32) + + destroy_handle(lib, handle) + + +if __name__ == "__main__": + test_cases = [ + # alpha, beta, transA, transB, a_shape, b_shape, c_shape, y_shape, a_stride, b_stride, c_stride, y_stride + ( + 1.0, + 1.0, + False, + False, + (1, 2048), + (2048, 2048), + (1, 2048), + (1, 2048), + None, + None, + None, + None, + ), + ( + 1.0, + 1.0, + True, + True, + (2048, 4), + (2048, 2048), + (4, 2048), + (4, 2048), + None, + None, + None, + None, + ), + ( + 1.0, + 1.0, + False, + True, + (1, 2048), + (1000, 2048), + (1000), + (1, 1000), + None, + None, + None, + None, + ), + # ( + # 1.0, + # 1.0, + # True, + # False, + # (2048, 4), + # (2048, 2048), + # (4, 2048), + # (4, 2048), + # (4096, 1), + # (4096, 1), + # (4096, 1), + # (4096, 1), + # ), + ] + args = get_args() + lib = open_lib() + + lib.infiniopCreateGEMMDescriptor.restype = c_int32 + lib.infiniopCreateGEMMDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopGEMMDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + c_float, + c_float, + c_bool, + c_bool, + ] + + lib.infiniopGetGEMMWorkspaceSize.restype = c_int32 + lib.infiniopGetGEMMWorkspaceSize.argtypes = [ + infiniopGEMMDescriptor_t, + POINTER(c_uint64), + ] + + lib.infiniopGEMM.restype = c_int32 + lib.infiniopGEMM.argtypes = [ + infiniopGEMMDescriptor_t, + c_void_p, + c_uint64, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + ] + + lib.infiniopDestroyGEMMDescriptor.restype = c_int32 + lib.infiniopDestroyGEMMDescriptor.argtypes = [ + infiniopGEMMDescriptor_t, + ] + + if args.cpu: + test_cpu(lib, test_cases) + if args.cuda: + test_cuda(lib, test_cases) + if args.bang: + test_bang(lib, test_cases) + if not (args.cpu or args.cuda or args.bang): + test_cpu(lib, test_cases) + print("\033[92mTest passed!\033[0m") diff --git a/src/ops/gemm/operator.cc b/src/ops/gemm/operator.cc new file mode 100644 index 00000000..d22464f1 --- /dev/null +++ b/src/ops/gemm/operator.cc @@ -0,0 +1,85 @@ +#include "../utils.h" +#include "ops/expand/expand.h" +#include "ops/gemm/gemm.h" +#include "ops/matmul/matmul.h" +#include "tensor/tensor_descriptor.h" + +struct _GEMMDescriptor { + Device device; + infiniopMatmulDescriptor_t matmul_desc; + infiniopExpandDescriptor_t expand_desc; + uint64_t workspace_size; +}; + +typedef struct _GEMMDescriptor *_GEMMDescriptor_t; + +__C __export infiniopStatus_t infiniopCreateGEMMDescriptor(infiniopHandle_t handle, + infiniopGEMMDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t b_desc, + infiniopTensorDescriptor_t c_desc, + float alpha, + float beta, + bool transA, + bool transB) { + // transpose a and b if needed + a_desc = transA ? permute(a_desc, {1, 0}) : a_desc; + b_desc = transB ? 
permute(b_desc, {1, 0}) : b_desc; + + // expand desc + infiniopExpandDescriptor_t expand_desc = new ExpandDescriptor{handle->device}; + CHECK_STATUS(infiniopCreateExpandDescriptor(handle, &expand_desc, y_desc, c_desc), STATUS_SUCCESS); + + // matmul desc + infiniopMatmulDescriptor_t matmul_desc = new MatmulDescriptor{handle->device}; + CHECK_STATUS(infiniopCreateMatmulDescriptor(handle, &matmul_desc, y_desc, alpha, a_desc, b_desc, beta), STATUS_SUCCESS); + uint64_t workspace_size = 0; + CHECK_STATUS(infiniopGetMatmulWorkspaceSize(matmul_desc, &workspace_size), STATUS_SUCCESS); + + *(_GEMMDescriptor_t *) desc_ptr = new _GEMMDescriptor{ + handle->device, + matmul_desc, + expand_desc, + workspace_size, + }; + + return STATUS_SUCCESS; +} + +__C __export infiniopStatus_t infiniopGetGEMMWorkspaceSize(infiniopGEMMDescriptor_t desc, uint64_t *size) { + *size = ((_GEMMDescriptor_t) desc)->workspace_size; + return STATUS_SUCCESS; +} + +__C __export infiniopStatus_t infiniopGEMM(infiniopGEMMDescriptor_t desc, + void *workspace, + uint64_t workspace_size, + void *y, + void const *a, + void const *b, + void const *c, + void *stream) { + auto _desc = (_GEMMDescriptor_t) desc; + if (workspace_size < _desc->workspace_size) { + return STATUS_MEMORY_NOT_ALLOCATED; + } + + CHECK_STATUS(infiniopExpand(_desc->expand_desc, + y, c, stream), + STATUS_SUCCESS); + + CHECK_STATUS(infiniopMatmul(_desc->matmul_desc, + workspace, + workspace_size, + y, a, b, stream), + STATUS_SUCCESS); + + return STATUS_SUCCESS; +} + +__C __export infiniopStatus_t infiniopDestroyGEMMDescriptor(infiniopGEMMDescriptor_t desc) { + CHECK_STATUS(infiniopDestroyMatmulDescriptor(((_GEMMDescriptor_t) desc)->matmul_desc), STATUS_SUCCESS); + CHECK_STATUS(infiniopDestroyExpandDescriptor(((_GEMMDescriptor_t) desc)->expand_desc), STATUS_SUCCESS); + return STATUS_SUCCESS; +} From 3203a8328e85f4b2101e273ee7954bdc2b8d7d68 Mon Sep 17 00:00:00 2001 From: lizimin Date: Thu, 31 Oct 2024 17:49:11 +0800 Subject: [PATCH 04/10] Allow Expand opeartor to handle noncontiguous data --- operatorspy/tests/expand.py | 1 - operatorspy/tests/gemm.py | 38 +++++++++++++++---------------- src/devices/cpu/common_cpu.cc | 9 ++++++++ src/devices/cpu/common_cpu.h | 3 +++ src/devices/cuda/common_cuda.h | 10 ++++++++ src/ops/expand/cuda/expand.cc | 5 ++++ src/ops/expand/cuda/expand.cu | 6 +++-- src/ops/expand/cuda/expand.cuh | 1 + src/ops/matmul/cuda/matmul_cuda.h | 2 -- 9 files changed, 51 insertions(+), 24 deletions(-) diff --git a/operatorspy/tests/expand.py b/operatorspy/tests/expand.py index fea84d19..c8f2399d 100644 --- a/operatorspy/tests/expand.py +++ b/operatorspy/tests/expand.py @@ -136,7 +136,6 @@ def test_bang(lib, test_cases): test_cases = [ # y_shape, x_shape, y_stride, x_stride ((), (), None, None), - # ((4, 2048), (2048,), (4096, 1), (1,)), ((3, 3), (1,), None, None), ((5, 4, 3), (4, 3,), None, (6, 1)), ((99, 111), (111,), None, None), diff --git a/operatorspy/tests/gemm.py b/operatorspy/tests/gemm.py index 402a3d9b..1b4ace6b 100644 --- a/operatorspy/tests/gemm.py +++ b/operatorspy/tests/gemm.py @@ -69,10 +69,10 @@ def test( f"a_stride:{a_stride} b_stride:{b_stride} c_stride:{c_stride} y_stride:{y_stride} dtype:{dtype}" ) - a = torch.ones(a_shape, dtype=dtype).to(torch_device) - b = torch.ones(b_shape, dtype=dtype).to(torch_device) - c = torch.ones(c_shape, dtype=dtype).to(torch_device) - y = torch.zeros(y_shape, dtype=dtype).to(torch_device) + a = torch.rand(a_shape, dtype=dtype).to(torch_device) + b = torch.rand(b_shape, dtype=dtype).to(torch_device) + c = 
torch.rand(c_shape, dtype=dtype).to(torch_device) + y = torch.rand(y_shape, dtype=dtype).to(torch_device) if a_stride is not None: a = rearrange_tensor(a, a_stride) @@ -124,7 +124,7 @@ def test( ) workspace_ptr = ctypes.cast(workspace.data_ptr(), ctypes.POINTER(ctypes.c_uint8)) - for i in range(NUM_PRERUN if PROFILE else 2): + for i in range(NUM_PRERUN if PROFILE else 1): check_error( lib.infiniopGEMM( descriptor, @@ -273,20 +273,20 @@ def test_bang(lib, test_cases): None, None, ), - # ( - # 1.0, - # 1.0, - # True, - # False, - # (2048, 4), - # (2048, 2048), - # (4, 2048), - # (4, 2048), - # (4096, 1), - # (4096, 1), - # (4096, 1), - # (4096, 1), - # ), + ( + 1.0, + 1.0, + True, + False, + (2048, 4), + (2048, 2048), + (2048), + (4, 2048), + (4096, 1), + (4096, 1), + (2,), + (4096, 1), + ), ] args = get_args() lib = open_lib() diff --git a/src/devices/cpu/common_cpu.cc b/src/devices/cpu/common_cpu.cc index 685b2a23..cd27e0b7 100644 --- a/src/devices/cpu/common_cpu.cc +++ b/src/devices/cpu/common_cpu.cc @@ -74,3 +74,12 @@ uint64_t getDstIndex(uint64_t flat_index, uint64_t ndim, int64_t const *src_stri } return res; } + +uint64_t getNextIndex(uint64_t flat_index, uint64_t ndim, uint64_t const *shape, int64_t const *strides) { + uint64_t res = 0; + for (long i = ndim - 1; i >= 0; --i) { + res += (flat_index % shape[i]) * strides[i]; + flat_index /= shape[i]; + } + return res; +} diff --git a/src/devices/cpu/common_cpu.h b/src/devices/cpu/common_cpu.h index f5c770ab..9ae12847 100644 --- a/src/devices/cpu/common_cpu.h +++ b/src/devices/cpu/common_cpu.h @@ -18,4 +18,7 @@ uint16_t f32_to_f16(float val); // get the corresponding index in the destination given the flat index of the source uint64_t getDstIndex(uint64_t flat_index, uint64_t ndim, int64_t const *src_strides, int64_t const *dst_strides); +// get the offset of the next element in a tensor given its flat index +uint64_t getNextIndex(uint64_t flat_index, uint64_t ndim, uint64_t const *shape, int64_t const *strides); + #endif// __COMMON_CPU_H__ diff --git a/src/devices/cuda/common_cuda.h b/src/devices/cuda/common_cuda.h index 0c23aa68..fb7bc598 100644 --- a/src/devices/cuda/common_cuda.h +++ b/src/devices/cuda/common_cuda.h @@ -64,4 +64,14 @@ inline __device__ uint64_t getDstIndex(uint64_t flat_index, uint64_t ndim, int64 return res; } +// get the offset of the next element in a tensor given its flat index +inline __device__ uint64_t getNextIndex(uint64_t flat_index, uint64_t ndim, uint64_t const *shape, int64_t const *strides) { + uint64_t res = 0; + for (long i = ndim - 1; i >= 0; --i) { + res += (flat_index % shape[i]) * strides[i]; + flat_index /= shape[i]; + } + return res; +} + #endif// __COMMON_CUDA_H__ diff --git a/src/ops/expand/cuda/expand.cc b/src/ops/expand/cuda/expand.cc index bd21b34c..deb171b0 100644 --- a/src/ops/expand/cuda/expand.cc +++ b/src/ops/expand/cuda/expand.cc @@ -26,10 +26,13 @@ infiniopStatus_t cudaCreateExpandDescriptor(CudaHandle_t handle, cudaGetDeviceProperties(&prop, handle->device_id); int64_t *x_strides_d, *y_strides_d; + uint64_t *y_shape_d; checkCudaErrorWithCode(cudaMalloc(&x_strides_d, ndim * sizeof(int64_t)), STATUS_MEMORY_NOT_ALLOCATED); checkCudaErrorWithCode(cudaMalloc(&y_strides_d, ndim * sizeof(int64_t)), STATUS_MEMORY_NOT_ALLOCATED); + checkCudaErrorWithCode(cudaMalloc(&y_shape_d, ndim * sizeof(uint64_t)), STATUS_MEMORY_NOT_ALLOCATED); checkCudaErrorWithCode(cudaMemcpy(x_strides_d, x_strides, ndim * sizeof(int64_t), cudaMemcpyHostToDevice), STATUS_EXECUTION_FAILED); 
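    // The expand kernel walks these stride/shape arrays for every output element,
    // which is why they are staged in device memory here; the host-side x_strides
    // buffer computed above is only scratch and is freed once the descriptor is built.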
checkCudaErrorWithCode(cudaMemcpy(y_strides_d, y->strides, ndim * sizeof(int64_t), cudaMemcpyHostToDevice), STATUS_EXECUTION_FAILED); + checkCudaErrorWithCode(cudaMemcpy(y_shape_d, y->shape, ndim * sizeof(uint64_t), cudaMemcpyHostToDevice), STATUS_EXECUTION_FAILED); *desc_ptr = new ExpandCudaDescriptor{ DevNvGpu, @@ -38,6 +41,7 @@ infiniopStatus_t cudaCreateExpandDescriptor(CudaHandle_t handle, ndim, y_data_size, static_cast(prop.maxGridSize[0]), + y_shape_d, x_strides_d, y_strides_d, }; @@ -50,6 +54,7 @@ infiniopStatus_t cudaCreateExpandDescriptor(CudaHandle_t handle, infiniopStatus_t cudaDestroyExpandDescriptor(ExpandCudaDescriptor_t desc) { cudaFree((void *) desc->x_strides); cudaFree((void *) desc->y_strides); + cudaFree((void *) desc->y_shape); delete desc; return STATUS_SUCCESS; } diff --git a/src/ops/expand/cuda/expand.cu b/src/ops/expand/cuda/expand.cu index a879fb20..6d64a75a 100644 --- a/src/ops/expand/cuda/expand.cu +++ b/src/ops/expand/cuda/expand.cu @@ -8,13 +8,15 @@ __global__ void expand( const Tdata *x, const int64_t *y_strides, const int64_t *x_strides, + const uint64_t *y_shape, uint64_t y_data_size, uint64_t ndim, uint64_t offset) { uint64_t idx = blockIdx.x * blockDim.x + threadIdx.x + offset; if (idx < y_data_size) { - y[idx] = x[getDstIndex(idx, ndim, y_strides, x_strides)]; + uint64_t y_idx = getNextIndex(idx, ndim, y_shape, y_strides); + y[y_idx] = x[getDstIndex(y_idx, ndim, y_strides, x_strides)]; } } @@ -34,7 +36,7 @@ infiniopStatus_t expand_nv_gpu(ExpandCudaDescriptor_t desc, void *y, void const #pragma unroll for (uint64_t i = 0; i < desc->y_data_size; i += step) { expand<<>>( - y_, x_, desc->y_strides, desc->x_strides, i + desc->y_data_size, desc->ndim, i); + y_, x_, desc->y_strides, desc->x_strides, desc->y_shape, i + desc->y_data_size, desc->ndim, i); } return STATUS_SUCCESS; } diff --git a/src/ops/expand/cuda/expand.cuh b/src/ops/expand/cuda/expand.cuh index 2f18a82f..0764243a 100644 --- a/src/ops/expand/cuda/expand.cuh +++ b/src/ops/expand/cuda/expand.cuh @@ -14,6 +14,7 @@ struct ExpandCudaDescriptor { uint64_t ndim; uint64_t y_data_size; uint64_t max_grid_size; + uint64_t const *y_shape; int64_t const *x_strides; int64_t const *y_strides; }; diff --git a/src/ops/matmul/cuda/matmul_cuda.h b/src/ops/matmul/cuda/matmul_cuda.h index f13531e8..3e82c1ed 100644 --- a/src/ops/matmul/cuda/matmul_cuda.h +++ b/src/ops/matmul/cuda/matmul_cuda.h @@ -38,6 +38,4 @@ infiniopStatus_t cudaMatmul(MatmulCudaDescriptor_t desc, infiniopStatus_t cudaDestroyMatmulDescriptor(MatmulCudaDescriptor_t desc); -void matmul_cuda_f16(MatmulCudaDescriptor_t desc, void *c, float beta, void const *a, void const *b, float alpha, void *stream); - #endif// __CUDA_MATMUL_H__ From 0e9375301bca0acf1d64d329d1f1bb4e34d5f200 Mon Sep 17 00:00:00 2001 From: lizimin Date: Thu, 31 Oct 2024 17:56:41 +0800 Subject: [PATCH 05/10] Add 3D GEMM test case --- operatorspy/tests/gemm.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/operatorspy/tests/gemm.py b/operatorspy/tests/gemm.py index 1b4ace6b..f7da6a11 100644 --- a/operatorspy/tests/gemm.py +++ b/operatorspy/tests/gemm.py @@ -287,6 +287,20 @@ def test_bang(lib, test_cases): (2,), (4096, 1), ), + ( + 1.0, + 1.0, + False, + False, + (3, 1, 2048), + (3, 2048, 2048), + (1,), + (3, 1, 2048), + None, + None, + None, + None, + ), ] args = get_args() lib = open_lib() From d7365b50c22d57aca2c5819dcc1183bd2d165369 Mon Sep 17 00:00:00 2001 From: lizimin Date: Fri, 1 Nov 2024 10:08:28 +0800 Subject: [PATCH 06/10] Remove cudaDeviceSynchronize() in 
matmul.cu --- src/ops/matmul/cuda/matmul_cuda.cu | 1 - 1 file changed, 1 deletion(-) diff --git a/src/ops/matmul/cuda/matmul_cuda.cu b/src/ops/matmul/cuda/matmul_cuda.cu index 1dc93430..b1f00726 100644 --- a/src/ops/matmul/cuda/matmul_cuda.cu +++ b/src/ops/matmul/cuda/matmul_cuda.cu @@ -57,7 +57,6 @@ infiniopStatus_t matmul_cuda(MatmulCudaDescriptor_t desc, void *c, float beta, v info.batch, compute_type, CUBLAS_GEMM_DEFAULT_TENSOR_OP); }); - cudaDeviceSynchronize(); return STATUS_SUCCESS; } From ea8dbeac3d5010ac868b461f0b53d3c1733786c3 Mon Sep 17 00:00:00 2001 From: Zimin Li Date: Fri, 1 Nov 2024 18:34:30 +0800 Subject: [PATCH 07/10] Change util function names --- src/devices/cpu/common_cpu.cc | 4 ++-- src/devices/cpu/common_cpu.h | 8 ++++---- src/devices/cuda/common_cuda.h | 8 ++++---- src/ops/add/cuda/add.cu | 4 ++-- src/ops/expand/cpu/expand_cpu.cc | 3 +-- src/ops/expand/cuda/expand.cu | 4 ++-- 6 files changed, 15 insertions(+), 16 deletions(-) diff --git a/src/devices/cpu/common_cpu.cc b/src/devices/cpu/common_cpu.cc index cd27e0b7..b5b5f0fd 100644 --- a/src/devices/cpu/common_cpu.cc +++ b/src/devices/cpu/common_cpu.cc @@ -66,7 +66,7 @@ uint16_t f32_to_f16(float val) { } } -uint64_t getDstIndex(uint64_t flat_index, uint64_t ndim, int64_t const *src_strides, int64_t const *dst_strides) { +uint64_t getDstOffset(uint64_t flat_index, uint64_t ndim, int64_t const *src_strides, int64_t const *dst_strides) { uint64_t res = 0; for (uint64_t i = 0; i < ndim; ++i) { res += flat_index / src_strides[i] * dst_strides[i]; @@ -75,7 +75,7 @@ uint64_t getDstIndex(uint64_t flat_index, uint64_t ndim, int64_t const *src_stri return res; } -uint64_t getNextIndex(uint64_t flat_index, uint64_t ndim, uint64_t const *shape, int64_t const *strides) { +uint64_t getOffset(uint64_t flat_index, uint64_t ndim, uint64_t const *shape, int64_t const *strides) { uint64_t res = 0; for (long i = ndim - 1; i >= 0; --i) { res += (flat_index % shape[i]) * strides[i]; diff --git a/src/devices/cpu/common_cpu.h b/src/devices/cpu/common_cpu.h index 9ae12847..caf3dd73 100644 --- a/src/devices/cpu/common_cpu.h +++ b/src/devices/cpu/common_cpu.h @@ -15,10 +15,10 @@ float f16_to_f32(uint16_t code); // convert single-precision float to half-precision float uint16_t f32_to_f16(float val); -// get the corresponding index in the destination given the flat index of the source -uint64_t getDstIndex(uint64_t flat_index, uint64_t ndim, int64_t const *src_strides, int64_t const *dst_strides); +// get the corresponding offset in the destination given the flat index of the source (for element mapping in shape broadcast) +uint64_t getDstOffset(uint64_t flat_index, uint64_t ndim, int64_t const *src_strides, int64_t const *dst_strides); -// get the offset of the next element in a tensor given its flat index -uint64_t getNextIndex(uint64_t flat_index, uint64_t ndim, uint64_t const *shape, int64_t const *strides); +// get the memory offset of the given element in a tensor given its flat index +uint64_t getOffset(uint64_t flat_index, uint64_t ndim, uint64_t const *shape, int64_t const *strides); #endif// __COMMON_CPU_H__ diff --git a/src/devices/cuda/common_cuda.h b/src/devices/cuda/common_cuda.h index fb7bc598..3bd7e856 100644 --- a/src/devices/cuda/common_cuda.h +++ b/src/devices/cuda/common_cuda.h @@ -54,8 +54,8 @@ typedef struct DataLayoutMap { constexpr DTMap dataTypeMap; -// get the corresponding index in the destination given the flat index of the source -inline __device__ uint64_t getDstIndex(uint64_t flat_index, uint64_t ndim, int64_t 
const *src_strides, int64_t const *dst_strides) { +// get the corresponding offset in the destination given the flat index of the source (for element mapping in shape broadcast) +inline __device__ uint64_t getDstOffset(uint64_t flat_index, uint64_t ndim, int64_t const *src_strides, int64_t const *dst_strides) { uint64_t res = 0; for (uint64_t i = 0; i < ndim; ++i) { res += flat_index / src_strides[i] * dst_strides[i]; @@ -64,8 +64,8 @@ inline __device__ uint64_t getDstIndex(uint64_t flat_index, uint64_t ndim, int64 return res; } -// get the offset of the next element in a tensor given its flat index -inline __device__ uint64_t getNextIndex(uint64_t flat_index, uint64_t ndim, uint64_t const *shape, int64_t const *strides) { +// get the memory offset of the given element in a tensor given its flat index +inline __device__ uint64_t getOffset(uint64_t flat_index, uint64_t ndim, uint64_t const *shape, int64_t const *strides) { uint64_t res = 0; for (long i = ndim - 1; i >= 0; --i) { res += (flat_index % shape[i]) * strides[i]; diff --git a/src/ops/add/cuda/add.cu b/src/ops/add/cuda/add.cu index 087db878..9d9aefcb 100644 --- a/src/ops/add/cuda/add.cu +++ b/src/ops/add/cuda/add.cu @@ -58,8 +58,8 @@ __global__ void add( auto c_ = reinterpret_cast(c); #pragma unroll for (size_t i = 0; i < pack_size; ++i) { - auto a_idx = getDstIndex(idx + i, ndim, c_strides, a_strides); - auto b_idx = getDstIndex(idx + i, ndim, c_strides, b_strides); + auto a_idx = getDstOffset(idx + i, ndim, c_strides, a_strides); + auto b_idx = getDstOffset(idx + i, ndim, c_strides, b_strides); c_[idx + i] = a_[a_idx] + b_[b_idx]; } return; diff --git a/src/ops/expand/cpu/expand_cpu.cc b/src/ops/expand/cpu/expand_cpu.cc index b5fe2698..19c2c074 100644 --- a/src/ops/expand/cpu/expand_cpu.cc +++ b/src/ops/expand/cpu/expand_cpu.cc @@ -1,7 +1,6 @@ #include "expand_cpu.h" #include "../../../devices/cpu/common_cpu.h" #include "../../utils.h" -#include infiniopStatus_t cpuCreateExpandDescriptor(infiniopHandle_t, ExpandCpuDescriptor_t *desc_ptr, @@ -49,7 +48,7 @@ infiniopStatus_t expand_cpu(ExpandCpuDescriptor_t desc, void *y, void const *x) #pragma omp parallel for for (uint64_t i = 0; i < desc->y_data_size; ++i) { - y_[i] = x_[getDstIndex(i, desc->ndim, desc->y_strides, desc->x_strides)]; + y_[i] = x_[getDstOffset(i, desc->ndim, desc->y_strides, desc->x_strides)]; } return STATUS_SUCCESS; } diff --git a/src/ops/expand/cuda/expand.cu b/src/ops/expand/cuda/expand.cu index 6d64a75a..d307e4d1 100644 --- a/src/ops/expand/cuda/expand.cu +++ b/src/ops/expand/cuda/expand.cu @@ -15,8 +15,8 @@ __global__ void expand( uint64_t idx = blockIdx.x * blockDim.x + threadIdx.x + offset; if (idx < y_data_size) { - uint64_t y_idx = getNextIndex(idx, ndim, y_shape, y_strides); - y[y_idx] = x[getDstIndex(y_idx, ndim, y_strides, x_strides)]; + uint64_t y_idx = getOffset(idx, ndim, y_shape, y_strides); + y[y_idx] = x[getDstOffset(y_idx, ndim, y_strides, x_strides)]; } } From 1fe5d07fabc203a4193acf811779faabfd4d81bd Mon Sep 17 00:00:00 2001 From: Zimin Li Date: Mon, 4 Nov 2024 20:02:09 +0800 Subject: [PATCH 08/10] Make c tensor optional, change GEMM CUDA fp32 compute type, merge cudaMalloc in expand into one, etc. 
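Of the items above, the cudaMalloc merge is the most mechanical: instead of three separate device allocations for the adjusted x strides, the y strides and the y shape, the Expand descriptor below keeps a single packed buffer with the three arrays laid back to back, so creating and destroying a descriptor costs one allocator call instead of three. A minimal sketch of that layout, assuming the offsets written by the new cudaMemcpy calls; ExpandViews and unpackStridesAndShape are illustrative names only, not part of the patch:

    #include <cstdint>

    // Packed device buffer written by cudaCreateExpandDescriptor:
    //   [ x_strides : ndim * int64_t ][ y_strides : ndim * int64_t ][ y_shape : ndim * uint64_t ]
    struct ExpandViews {               // hypothetical helper type, for illustration
        const int64_t *x_strides;
        const int64_t *y_strides;
        const uint64_t *y_shape;
    };

    inline ExpandViews unpackStridesAndShape(const char *strides_and_shape_d, uint64_t ndim) {
        ExpandViews v;
        v.x_strides = reinterpret_cast<const int64_t *>(strides_and_shape_d);
        v.y_strides = reinterpret_cast<const int64_t *>(strides_and_shape_d) + ndim;
        v.y_shape   = reinterpret_cast<const uint64_t *>(strides_and_shape_d + 2 * ndim * sizeof(int64_t));
        return v;
    }

The actual kernel-launch code may unpack the buffer differently; only the byte offsets (0, ndim * sizeof(int64_t), 2 * ndim * sizeof(int64_t)) are taken from the diff.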
--- operatorspy/tests/expand.py | 2 -- operatorspy/tests/gemm.py | 27 ++++++++++++++++++++------- src/ops/expand/cpu/expand_cpu.cc | 5 ++++- src/ops/expand/cpu/expand_cpu.h | 1 + src/ops/expand/cuda/expand.cc | 20 +++++++------------- src/ops/expand/cuda/expand.cu | 5 ++++- src/ops/expand/cuda/expand.cuh | 4 +--- src/ops/gemm/operator.cc | 23 +++++++++++++++++------ src/ops/matmul/cuda/matmul_cuda.cu | 4 ++-- 9 files changed, 56 insertions(+), 35 deletions(-) diff --git a/operatorspy/tests/expand.py b/operatorspy/tests/expand.py index c8f2399d..15b3909d 100644 --- a/operatorspy/tests/expand.py +++ b/operatorspy/tests/expand.py @@ -144,8 +144,6 @@ def test_bang(lib, test_cases): ((2, 3, 4, 5), (5,), None, None), ((3, 2, 4, 5), (3, 2, 1, 1), None, None), ((32, 256, 112, 112), (32, 256, 112, 1), None, None), - # ((32, 256, 112, 112), (32, 1, 1, 1), None, None), - # ((32, 150, 51200), (32, 150, 1), None, None), ] args = get_args() lib = open_lib() diff --git a/operatorspy/tests/gemm.py b/operatorspy/tests/gemm.py index f7da6a11..3fce2394 100644 --- a/operatorspy/tests/gemm.py +++ b/operatorspy/tests/gemm.py @@ -71,14 +71,14 @@ def test( a = torch.rand(a_shape, dtype=dtype).to(torch_device) b = torch.rand(b_shape, dtype=dtype).to(torch_device) - c = torch.rand(c_shape, dtype=dtype).to(torch_device) + c = torch.rand(c_shape, dtype=dtype).to(torch_device) if c_shape else None y = torch.rand(y_shape, dtype=dtype).to(torch_device) if a_stride is not None: a = rearrange_tensor(a, a_stride) if b_stride is not None: b = rearrange_tensor(b, b_stride) - if c_stride is not None: + if c_stride is not None and c is not None: c = rearrange_tensor(c, c_stride) if y_stride is not None: y = rearrange_tensor(y, y_stride) @@ -95,7 +95,7 @@ def test( a_tensor = to_tensor(a, lib) b_tensor = to_tensor(b, lib) - c_tensor = to_tensor(c, lib) + c_tensor = to_tensor(c, lib) if c is not None else None y_tensor = to_tensor(y, lib) descriptor = infiniopGEMMDescriptor_t() check_error( @@ -105,7 +105,7 @@ def test( y_tensor.descriptor, a_tensor.descriptor, b_tensor.descriptor, - c_tensor.descriptor, + c_tensor.descriptor if c_tensor else None, alpha, beta, transA, @@ -133,7 +133,7 @@ def test( y_tensor.data, a_tensor.data, b_tensor.data, - c_tensor.data, + c_tensor.data if c_tensor else None, None, ) ) @@ -147,13 +147,12 @@ def test( y_tensor.data, a_tensor.data, b_tensor.data, - c_tensor.data, + c_tensor.data if c_tensor else None, None, ) elapsed = (time.time() - start_time) / NUM_ITERATIONS print(f" lib time: {elapsed :6f}") - # print(" - y:\n", y, y.shape, "\n - ans:\n", ans, ans.shape) assert torch.allclose(y, ans, atol=0, rtol=1e-2) check_error(lib.infiniopDestroyGEMMDescriptor(descriptor)) @@ -301,6 +300,20 @@ def test_bang(lib, test_cases): None, None, ), + ( + 1.0, + 1.0, + True, + False, + (2048, 4), + (2048, 2048), + None, + (4, 2048), + (4096, 1), + (4096, 1), + (2,), + (4096, 1), + ), ] args = get_args() lib = open_lib() diff --git a/src/ops/expand/cpu/expand_cpu.cc b/src/ops/expand/cpu/expand_cpu.cc index 19c2c074..d3bcb866 100644 --- a/src/ops/expand/cpu/expand_cpu.cc +++ b/src/ops/expand/cpu/expand_cpu.cc @@ -18,10 +18,12 @@ infiniopStatus_t cpuCreateExpandDescriptor(infiniopHandle_t, // get the adjusted strides for x in terms of y int64_t *x_strides = new int64_t[ndim]; + int64_t *y_strides = new int64_t[ndim]; #pragma omp parallel for for (size_t i = 0; i < ndim; ++i) { x_strides[i] = (i < ndim - x->ndim || y->shape[i] != x->shape[i + x->ndim - ndim]) ? 
0 : x->strides[i + x->ndim - ndim]; } + memcpy(y_strides, y->strides, ndim * sizeof(int64_t)); *desc_ptr = new ExpandCpuDescriptor{ DevCpu, @@ -29,7 +31,7 @@ infiniopStatus_t cpuCreateExpandDescriptor(infiniopHandle_t, ndim, y_data_size, x_strides, - y->strides, + y_strides, }; return STATUS_SUCCESS; @@ -37,6 +39,7 @@ infiniopStatus_t cpuCreateExpandDescriptor(infiniopHandle_t, infiniopStatus_t cpuDestroyExpandDescriptor(ExpandCpuDescriptor_t desc) { delete[] desc->x_strides; + delete[] desc->y_strides; delete desc; return STATUS_SUCCESS; } diff --git a/src/ops/expand/cpu/expand_cpu.h b/src/ops/expand/cpu/expand_cpu.h index c1796dc3..868fefe8 100644 --- a/src/ops/expand/cpu/expand_cpu.h +++ b/src/ops/expand/cpu/expand_cpu.h @@ -2,6 +2,7 @@ #define __CPU_EXPAND_H__ #include "operators.h" +#include #include struct ExpandCpuDescriptor { diff --git a/src/ops/expand/cuda/expand.cc b/src/ops/expand/cuda/expand.cc index deb171b0..a32be90a 100644 --- a/src/ops/expand/cuda/expand.cc +++ b/src/ops/expand/cuda/expand.cc @@ -26,13 +26,11 @@ infiniopStatus_t cudaCreateExpandDescriptor(CudaHandle_t handle, cudaGetDeviceProperties(&prop, handle->device_id); int64_t *x_strides_d, *y_strides_d; - uint64_t *y_shape_d; - checkCudaErrorWithCode(cudaMalloc(&x_strides_d, ndim * sizeof(int64_t)), STATUS_MEMORY_NOT_ALLOCATED); - checkCudaErrorWithCode(cudaMalloc(&y_strides_d, ndim * sizeof(int64_t)), STATUS_MEMORY_NOT_ALLOCATED); - checkCudaErrorWithCode(cudaMalloc(&y_shape_d, ndim * sizeof(uint64_t)), STATUS_MEMORY_NOT_ALLOCATED); - checkCudaErrorWithCode(cudaMemcpy(x_strides_d, x_strides, ndim * sizeof(int64_t), cudaMemcpyHostToDevice), STATUS_EXECUTION_FAILED); - checkCudaErrorWithCode(cudaMemcpy(y_strides_d, y->strides, ndim * sizeof(int64_t), cudaMemcpyHostToDevice), STATUS_EXECUTION_FAILED); - checkCudaErrorWithCode(cudaMemcpy(y_shape_d, y->shape, ndim * sizeof(uint64_t), cudaMemcpyHostToDevice), STATUS_EXECUTION_FAILED); + char *strides_and_shape_d; + checkCudaErrorWithCode(cudaMalloc(&strides_and_shape_d, ndim * (2 * sizeof(int64_t) + sizeof(uint64_t))), STATUS_MEMORY_NOT_ALLOCATED); + checkCudaErrorWithCode(cudaMemcpy(strides_and_shape_d, x_strides, ndim * sizeof(int64_t), cudaMemcpyHostToDevice), STATUS_EXECUTION_FAILED); + checkCudaErrorWithCode(cudaMemcpy(strides_and_shape_d + ndim * sizeof(int64_t), y->strides, ndim * sizeof(int64_t), cudaMemcpyHostToDevice), STATUS_EXECUTION_FAILED); + checkCudaErrorWithCode(cudaMemcpy(strides_and_shape_d + 2 * ndim * sizeof(int64_t), y->shape, ndim * sizeof(uint64_t), cudaMemcpyHostToDevice), STATUS_EXECUTION_FAILED); *desc_ptr = new ExpandCudaDescriptor{ DevNvGpu, @@ -41,9 +39,7 @@ infiniopStatus_t cudaCreateExpandDescriptor(CudaHandle_t handle, ndim, y_data_size, static_cast(prop.maxGridSize[0]), - y_shape_d, - x_strides_d, - y_strides_d, + strides_and_shape_d, }; delete[] x_strides; @@ -52,9 +48,7 @@ infiniopStatus_t cudaCreateExpandDescriptor(CudaHandle_t handle, } infiniopStatus_t cudaDestroyExpandDescriptor(ExpandCudaDescriptor_t desc) { - cudaFree((void *) desc->x_strides); - cudaFree((void *) desc->y_strides); - cudaFree((void *) desc->y_shape); + cudaFree((void *) desc->strides_and_shape_d); delete desc; return STATUS_SUCCESS; } diff --git a/src/ops/expand/cuda/expand.cu b/src/ops/expand/cuda/expand.cu index d307e4d1..6d75e651 100644 --- a/src/ops/expand/cuda/expand.cu +++ b/src/ops/expand/cuda/expand.cu @@ -31,12 +31,15 @@ infiniopStatus_t expand_nv_gpu(ExpandCudaDescriptor_t desc, void *y, void const const auto x_ = reinterpret_cast(x); const auto y_ = 
reinterpret_cast(y); + const auto x_strides = reinterpret_cast(desc->strides_and_shape_d); + const auto y_strides = reinterpret_cast(desc->strides_and_shape_d + desc->ndim * sizeof(int64_t)); + const auto y_shape = reinterpret_cast(desc->strides_and_shape_d + 2 * desc->ndim * sizeof(int64_t)); cudaStream_t cuda_stream = reinterpret_cast(stream); #pragma unroll for (uint64_t i = 0; i < desc->y_data_size; i += step) { expand<<>>( - y_, x_, desc->y_strides, desc->x_strides, desc->y_shape, i + desc->y_data_size, desc->ndim, i); + y_, x_, y_strides, x_strides, y_shape, i + desc->y_data_size, desc->ndim, i); } return STATUS_SUCCESS; } diff --git a/src/ops/expand/cuda/expand.cuh b/src/ops/expand/cuda/expand.cuh index 0764243a..17cc1337 100644 --- a/src/ops/expand/cuda/expand.cuh +++ b/src/ops/expand/cuda/expand.cuh @@ -14,9 +14,7 @@ struct ExpandCudaDescriptor { uint64_t ndim; uint64_t y_data_size; uint64_t max_grid_size; - uint64_t const *y_shape; - int64_t const *x_strides; - int64_t const *y_strides; + char const *strides_and_shape_d; }; typedef struct ExpandCudaDescriptor *ExpandCudaDescriptor_t; diff --git a/src/ops/gemm/operator.cc b/src/ops/gemm/operator.cc index d22464f1..071c2870 100644 --- a/src/ops/gemm/operator.cc +++ b/src/ops/gemm/operator.cc @@ -28,8 +28,15 @@ __C __export infiniopStatus_t infiniopCreateGEMMDescriptor(infiniopHandle_t hand b_desc = transB ? permute(b_desc, {1, 0}) : b_desc; // expand desc - infiniopExpandDescriptor_t expand_desc = new ExpandDescriptor{handle->device}; - CHECK_STATUS(infiniopCreateExpandDescriptor(handle, &expand_desc, y_desc, c_desc), STATUS_SUCCESS); + infiniopExpandDescriptor_t expand_desc = nullptr; + + // c is optional, set beta to 0 when c is not provided + if (!c_desc || c_desc->ndim == 0 || c_desc->shape == nullptr || c_desc->shape[0] == 0) { + beta = 0; + } else { + expand_desc = new ExpandDescriptor{handle->device}; + CHECK_STATUS(infiniopCreateExpandDescriptor(handle, &expand_desc, y_desc, c_desc), STATUS_SUCCESS); + } // matmul desc infiniopMatmulDescriptor_t matmul_desc = new MatmulDescriptor{handle->device}; @@ -65,9 +72,11 @@ __C __export infiniopStatus_t infiniopGEMM(infiniopGEMMDescriptor_t desc, return STATUS_MEMORY_NOT_ALLOCATED; } - CHECK_STATUS(infiniopExpand(_desc->expand_desc, - y, c, stream), - STATUS_SUCCESS); + if (_desc->expand_desc != nullptr) { + CHECK_STATUS(infiniopExpand(_desc->expand_desc, + y, c, stream), + STATUS_SUCCESS); + } CHECK_STATUS(infiniopMatmul(_desc->matmul_desc, workspace, @@ -79,7 +88,9 @@ __C __export infiniopStatus_t infiniopGEMM(infiniopGEMMDescriptor_t desc, } __C __export infiniopStatus_t infiniopDestroyGEMMDescriptor(infiniopGEMMDescriptor_t desc) { + if (((_GEMMDescriptor_t) desc)->expand_desc) { + CHECK_STATUS(infiniopDestroyExpandDescriptor(((_GEMMDescriptor_t) desc)->expand_desc), STATUS_SUCCESS); + } CHECK_STATUS(infiniopDestroyMatmulDescriptor(((_GEMMDescriptor_t) desc)->matmul_desc), STATUS_SUCCESS); - CHECK_STATUS(infiniopDestroyExpandDescriptor(((_GEMMDescriptor_t) desc)->expand_desc), STATUS_SUCCESS); return STATUS_SUCCESS; } diff --git a/src/ops/matmul/cuda/matmul_cuda.cu b/src/ops/matmul/cuda/matmul_cuda.cu index b1f00726..a75b164e 100644 --- a/src/ops/matmul/cuda/matmul_cuda.cu +++ b/src/ops/matmul/cuda/matmul_cuda.cu @@ -26,7 +26,7 @@ infiniopStatus_t matmul_cuda(MatmulCudaDescriptor_t desc, void *c, float beta, v alpha_ = alpha; beta_ = beta; a_type = b_type = c_type = CUDA_R_32F; - compute_type = CUBLAS_COMPUTE_32F_FAST_16F; + compute_type = CUBLAS_COMPUTE_32F_FAST_TF32; } auto 
op_a = info.a_matrix.row_stride == 1 ? CUBLAS_OP_N : CUBLAS_OP_T; @@ -74,4 +74,4 @@ infiniopStatus_t cudaMatmul(MatmulCudaDescriptor_t desc, return matmul_cuda(desc, c, desc->beta, a, b, desc->alpha, stream); } return STATUS_BAD_TENSOR_DTYPE; -} \ No newline at end of file +} From 514cc2779610ae6adabd92b0913a411544fa8474 Mon Sep 17 00:00:00 2001 From: Zimin Li Date: Tue, 5 Nov 2024 11:35:08 +0800 Subject: [PATCH 09/10] Add cudaDeviceProp and compute capability numbers into cuda handle --- src/devices/cuda/cuda_handle.cc | 20 +++++++++++++++++++- src/devices/cuda/cuda_handle.h | 6 +++++- src/ops/expand/cuda/expand.cc | 5 +---- 3 files changed, 25 insertions(+), 6 deletions(-) diff --git a/src/devices/cuda/cuda_handle.cc b/src/devices/cuda/cuda_handle.cc index e2475f0d..7d7db662 100644 --- a/src/devices/cuda/cuda_handle.cc +++ b/src/devices/cuda/cuda_handle.cc @@ -23,7 +23,25 @@ infiniopStatus_t createCudaHandle(CudaHandle_t *handle_ptr, int device_id) { checkCudnnError(cudnnCreate(&cudnn_handle)); cudnn_pool->push(std::move(cudnn_handle)); - *handle_ptr = new CudaContext{DevNvGpu, device_id, std::move(pool), std::move(cudnn_pool)}; + // set CUDA device property + cudaDeviceProp prop; + cudaGetDeviceProperties(&prop, device_id); + + // set device compute capability numbers + int capability_major; + int capability_minor; + cudaDeviceGetAttribute(&capability_major, cudaDevAttrComputeCapabilityMajor, device_id); + cudaDeviceGetAttribute(&capability_minor, cudaDevAttrComputeCapabilityMinor, device_id); + + *handle_ptr = new CudaContext{ + DevNvGpu, + device_id, + std::move(pool), + std::move(cudnn_pool), + std::move(prop), + capability_major, + capability_minor, + }; return STATUS_SUCCESS; } diff --git a/src/devices/cuda/cuda_handle.h b/src/devices/cuda/cuda_handle.h index 0df79cd0..aa293377 100644 --- a/src/devices/cuda/cuda_handle.h +++ b/src/devices/cuda/cuda_handle.h @@ -15,6 +15,9 @@ struct CudaContext { int device_id; std::shared_ptr> cublas_handles_t; std::shared_ptr> cudnn_handles_t; + cudaDeviceProp prop; + int compute_capability_major; + int compute_capability_minor; }; typedef struct CudaContext *CudaHandle_t; @@ -35,12 +38,13 @@ void use_cublas(std::shared_ptr> cublas_handles_t, int devi } template -cudnnStatus_t use_cudnn(std::shared_ptr> cudnn_handles_t, int device_id, T const &f) { +cudnnStatus_t use_cudnn(std::shared_ptr> cudnn_handles_t, int device_id, cudaStream_t stream, T const &f) { auto handle = cudnn_handles_t->pop(); if (!handle) { cudaSetDevice(device_id); cudnnCreate(&(*handle)); } + cudnnSetStream(*handle, stream); cudnnStatus_t status = f(*handle); cudnn_handles_t->push(std::move(*handle)); return status; diff --git a/src/ops/expand/cuda/expand.cc b/src/ops/expand/cuda/expand.cc index a32be90a..b93e78af 100644 --- a/src/ops/expand/cuda/expand.cc +++ b/src/ops/expand/cuda/expand.cc @@ -22,9 +22,6 @@ infiniopStatus_t cudaCreateExpandDescriptor(CudaHandle_t handle, x_strides[i] = (i < ndim - x->ndim || y->shape[i] != x->shape[i + x->ndim - ndim]) ? 
0 : x->strides[i + x->ndim - ndim]; } - cudaDeviceProp prop; - cudaGetDeviceProperties(&prop, handle->device_id); - int64_t *x_strides_d, *y_strides_d; char *strides_and_shape_d; checkCudaErrorWithCode(cudaMalloc(&strides_and_shape_d, ndim * (2 * sizeof(int64_t) + sizeof(uint64_t))), STATUS_MEMORY_NOT_ALLOCATED); @@ -38,7 +35,7 @@ infiniopStatus_t cudaCreateExpandDescriptor(CudaHandle_t handle, handle->device_id, ndim, y_data_size, - static_cast<uint64_t>(prop.maxGridSize[0]), + static_cast<uint64_t>(handle->prop.maxGridSize[0]), strides_and_shape_d, }; From 9e976bddf689987e4b89d006bcd4f988ac86cd70 Mon Sep 17 00:00:00 2001 From: Zimin Li Date: Wed, 6 Nov 2024 16:24:30 +0800 Subject: [PATCH 10/10] Add checkCudaErrorWithCode to cudaDestroyDescriptor() for add and expand --- src/ops/add/cuda/add.cc | 6 +++--- src/ops/expand/cuda/expand.cc | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/ops/add/cuda/add.cc b/src/ops/add/cuda/add.cc index bfb885c1..b010894f 100644 --- a/src/ops/add/cuda/add.cc +++ b/src/ops/add/cuda/add.cc @@ -73,9 +73,9 @@ infiniopStatus_t cudaCreateAddDescriptor(CudaHandle_t handle, } infiniopStatus_t cudaDestroyAddDescriptor(AddCudaDescriptor_t desc) { - cudaFree((void *) desc->a_strides); - cudaFree((void *) desc->b_strides); - cudaFree((void *) desc->c_strides); + checkCudaErrorWithCode(cudaFree((void *) desc->a_strides), STATUS_EXECUTION_FAILED); + checkCudaErrorWithCode(cudaFree((void *) desc->b_strides), STATUS_EXECUTION_FAILED); + checkCudaErrorWithCode(cudaFree((void *) desc->c_strides), STATUS_EXECUTION_FAILED); delete desc; return STATUS_SUCCESS; } diff --git a/src/ops/expand/cuda/expand.cc b/src/ops/expand/cuda/expand.cc index b93e78af..cf43b326 100644 --- a/src/ops/expand/cuda/expand.cc +++ b/src/ops/expand/cuda/expand.cc @@ -45,7 +45,7 @@ infiniopStatus_t cudaCreateExpandDescriptor(CudaHandle_t handle, } infiniopStatus_t cudaDestroyExpandDescriptor(ExpandCudaDescriptor_t desc) { - cudaFree((void *) desc->strides_and_shape_d); + checkCudaErrorWithCode(cudaFree((void *) desc->strides_and_shape_d), STATUS_EXECUTION_FAILED); delete desc; return STATUS_SUCCESS; }
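For reference, the index helpers renamed in PATCH 07 (getOffset / getDstOffset) are what both the add and expand kernels rely on for strided and broadcast addressing. Below is a small standalone sketch, not part of the patch series, that exercises them on a made-up case: y of shape (2, 3) with contiguous strides (3, 1), and x broadcast along the first dimension with strides (0, 1). The loop bodies are completed with the usual flat-index walk where the hunks above cut them off, so treat the snippet as an approximation of the library code rather than a verbatim copy:

    #include <cstdint>
    #include <cstdio>

    // Memory offset of the element with the given flat index (last dimension varies fastest).
    uint64_t getOffset(uint64_t flat_index, uint64_t ndim, uint64_t const *shape, int64_t const *strides) {
        uint64_t res = 0;
        for (long i = ndim - 1; i >= 0; --i) {
            res += (flat_index % shape[i]) * strides[i];
            flat_index /= shape[i];   // assumed continuation of the truncated hunk
        }
        return res;
    }

    // Map an offset taken under src_strides onto the corresponding offset under dst_strides
    // (the element mapping used for shape broadcast).
    uint64_t getDstOffset(uint64_t flat_index, uint64_t ndim, int64_t const *src_strides, int64_t const *dst_strides) {
        uint64_t res = 0;
        for (uint64_t i = 0; i < ndim; ++i) {
            res += flat_index / src_strides[i] * dst_strides[i];
            flat_index %= src_strides[i];   // assumed continuation of the truncated hunk
        }
        return res;
    }

    int main() {
        uint64_t y_shape[]   = {2, 3};
        int64_t  y_strides[] = {3, 1};   // contiguous destination
        int64_t  x_strides[] = {0, 1};   // source broadcast along the expanded dimension
        for (uint64_t i = 0; i < 6; ++i) {
            uint64_t y_off = getOffset(i, 2, y_shape, y_strides);
            printf("y[%llu] <- x[%llu]\n",
                   (unsigned long long) y_off,
                   (unsigned long long) getDstOffset(y_off, 2, y_strides, x_strides));
        }
        return 0;   // prints x offsets 0,1,2,0,1,2: each row of y reads the same row of x
    }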