diff --git a/include/infini_operators.h b/include/infini_operators.h index fee31719..2d316317 100644 --- a/include/infini_operators.h +++ b/include/infini_operators.h @@ -2,6 +2,8 @@ #include "ops/add/add.h" #include "ops/attention/attention.h" #include "ops/causal_softmax/causal_softmax.h" +#include "ops/expand/expand.h" +#include "ops/gemm/gemm.h" #include "ops/conv/conv.h" #include "ops/matmul/matmul.h" #include "ops/mlp/mlp.h" diff --git a/include/ops/expand/expand.h b/include/ops/expand/expand.h new file mode 100644 index 00000000..ee28b70c --- /dev/null +++ b/include/ops/expand/expand.h @@ -0,0 +1,25 @@ +#ifndef EXPAND_H +#define EXPAND_H + +#include "../../export.h" +#include "../../operators.h" + +typedef struct ExpandDescriptor { + Device device; +} ExpandDescriptor; + +typedef ExpandDescriptor *infiniopExpandDescriptor_t; + +__C __export infiniopStatus_t infiniopCreateExpandDescriptor(infiniopHandle_t handle, + infiniopExpandDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y, + infiniopTensorDescriptor_t x); + +__C __export infiniopStatus_t infiniopExpand(infiniopExpandDescriptor_t desc, + void *y, + void const *x, + void *stream); + +__C __export infiniopStatus_t infiniopDestroyExpandDescriptor(infiniopExpandDescriptor_t desc); + +#endif diff --git a/include/ops/gemm/gemm.h b/include/ops/gemm/gemm.h new file mode 100644 index 00000000..4a39da39 --- /dev/null +++ b/include/ops/gemm/gemm.h @@ -0,0 +1,36 @@ +#ifndef GEMM_H +#define GEMM_H + +#include "../../export.h" +#include "../../operators.h" + +typedef struct GEMMDescriptor { + Device device; +} GEMMDescriptor; + +typedef GEMMDescriptor *infiniopGEMMDescriptor_t; + +__C __export infiniopStatus_t infiniopCreateGEMMDescriptor(infiniopHandle_t handle, + infiniopGEMMDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t b_desc, + infiniopTensorDescriptor_t c_desc, + float alpha, + float beta, + bool transA, + bool transB); + +__C __export infiniopStatus_t infiniopGetGEMMWorkspaceSize(infiniopGEMMDescriptor_t desc, uint64_t *size); + +__C __export infiniopStatus_t infiniopGEMM(infiniopGEMMDescriptor_t desc, + void *workspace, + uint64_t workspace_size, + void *y, + void const *a, + void const *b, + void const *c, + void *stream); + +__C __export infiniopStatus_t infiniopDestroyGEMMDescriptor(infiniopGEMMDescriptor_t desc); +#endif diff --git a/operatorspy/tests/expand.py b/operatorspy/tests/expand.py new file mode 100644 index 00000000..15b3909d --- /dev/null +++ b/operatorspy/tests/expand.py @@ -0,0 +1,177 @@ +from ctypes import POINTER, Structure, c_int32, c_void_p +import ctypes +import sys +import os +import time + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))) +from operatorspy import ( + open_lib, + to_tensor, + DeviceEnum, + infiniopHandle_t, + infiniopTensorDescriptor_t, + create_handle, + destroy_handle, + check_error, + rearrange_tensor, +) + +from operatorspy.tests.test_utils import get_args +import torch + +# constant for control whether profile the pytorch and lib functions +# NOTE: need to manually add synchronization function to the lib function, +# e.g., cudaDeviceSynchronize() for CUDA +PROFILE = False +NUM_PRERUN = 10 +NUM_ITERATIONS = 1000 + + +class ExpandDescriptor(Structure): + _fields_ = [("device", c_int32)] + + +infiniopExpandDescriptor_t = POINTER(ExpandDescriptor) + + +def expand(x, y): + if PROFILE: + ans = x.expand_as(y).clone() + torch.cuda.synchronize() + return ans + return 
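    # expand_as broadcasts x to y's shape as a view (no data copy); the PROFILE branch
    # clones and synchronizes so the measured time includes materializing the result.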
x.expand_as(y) + + +def test( + lib, + handle, + torch_device, + y_shape, + x_shape, + y_stride=None, + x_stride=None, + tensor_dtype=torch.float16, +): + print( + f"Testing Expand on {torch_device} with x_shape:{x_shape} y_shape:{y_shape} x_stride:{x_stride} y_stride:{y_stride} dtype:{tensor_dtype}" + ) + + x = torch.rand(x_shape, dtype=tensor_dtype).to(torch_device) + y = torch.rand(y_shape, dtype=tensor_dtype).to(torch_device) + + if x_stride is not None: + x = rearrange_tensor(x, x_stride) + if y_stride is not None: + y = rearrange_tensor(y, y_stride) + + for i in range(NUM_PRERUN if PROFILE else 1): + ans = expand(x, y) + if PROFILE: + start_time = time.time() + for i in range(NUM_ITERATIONS): + _ = expand(x, y) + elapsed = (time.time() - start_time) / NUM_ITERATIONS + print(f"pytorch time: {elapsed :6f}") + + x_tensor = to_tensor(x, lib) + y_tensor = to_tensor(y, lib) + descriptor = infiniopExpandDescriptor_t() + + check_error( + lib.infiniopCreateExpandDescriptor( + handle, + ctypes.byref(descriptor), + y_tensor.descriptor, + x_tensor.descriptor, + ) + ) + + for i in range(NUM_PRERUN if PROFILE else 1): + lib.infiniopExpand( + descriptor, y_tensor.data, x_tensor.data, None + ) + if PROFILE: + start_time = time.time() + for i in range(NUM_ITERATIONS): + lib.infiniopExpand( + descriptor, y_tensor.data, x_tensor.data, None + ) + elapsed = (time.time() - start_time) / NUM_ITERATIONS + print(f" lib time: {elapsed :6f}") + assert torch.allclose(y, ans, atol=0, rtol=1e-3) + check_error(lib.infiniopDestroyExpandDescriptor(descriptor)) + + +def test_cpu(lib, test_cases): + device = DeviceEnum.DEVICE_CPU + handle = create_handle(lib, device) + for y_shape, x_shape, y_stride, x_stride in test_cases: + test(lib, handle, "cpu", y_shape, x_shape, y_stride, x_stride, tensor_dtype=torch.float16) + test(lib, handle, "cpu", y_shape, x_shape, y_stride, x_stride, tensor_dtype=torch.float32) + destroy_handle(lib, handle) + + +def test_cuda(lib, test_cases): + device = DeviceEnum.DEVICE_CUDA + handle = create_handle(lib, device) + for y_shape, x_shape, y_stride, x_stride in test_cases: + test(lib, handle, "cuda", y_shape, x_shape, y_stride, x_stride, tensor_dtype=torch.float16) + test(lib, handle, "cuda", y_shape, x_shape, y_stride, x_stride, tensor_dtype=torch.float32) + destroy_handle(lib, handle) + + +def test_bang(lib, test_cases): + import torch_mlu + + device = DeviceEnum.DEVICE_BANG + handle = create_handle(lib, device) + for y_shape, x_shape, y_stride, x_stride in test_cases: + test(lib, handle, "mlu", y_shape, x_shape, y_stride, x_stride, tensor_dtype=torch.float16) + test(lib, handle, "mlu", y_shape, x_shape, y_stride, x_stride, tensor_dtype=torch.float32) + destroy_handle(lib, handle) + + +if __name__ == "__main__": + test_cases = [ + # y_shape, x_shape, y_stride, x_stride + ((), (), None, None), + ((3, 3), (1,), None, None), + ((5, 4, 3), (4, 3,), None, (6, 1)), + ((99, 111), (111,), None, None), + ((2, 4, 3), (1, 3), None, None), + ((2, 20, 3), (2, 1, 3), None, None), + ((2, 3, 4, 5), (5,), None, None), + ((3, 2, 4, 5), (3, 2, 1, 1), None, None), + ((32, 256, 112, 112), (32, 256, 112, 1), None, None), + ] + args = get_args() + lib = open_lib() + lib.infiniopCreateExpandDescriptor.restype = c_int32 + lib.infiniopCreateExpandDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopExpandDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + ] + lib.infiniopExpand.restype = c_int32 + lib.infiniopExpand.argtypes = [ + infiniopExpandDescriptor_t, + c_void_p, + c_void_p, 
+ c_void_p, + ] + lib.infiniopDestroyExpandDescriptor.restype = c_int32 + lib.infiniopDestroyExpandDescriptor.argtypes = [ + infiniopExpandDescriptor_t, + ] + + if args.cpu: + test_cpu(lib, test_cases) + if args.cuda: + test_cuda(lib, test_cases) + if args.bang: + test_bang(lib, test_cases) + if not (args.cpu or args.cuda or args.bang): + test_cpu(lib, test_cases) + print("\033[92mTest passed!\033[0m") diff --git a/operatorspy/tests/gemm.py b/operatorspy/tests/gemm.py new file mode 100644 index 00000000..3fce2394 --- /dev/null +++ b/operatorspy/tests/gemm.py @@ -0,0 +1,366 @@ +from ctypes import POINTER, Structure, c_int32, c_uint64, c_void_p, c_float, c_bool +import ctypes +import sys +import os +import time + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))) +from operatorspy import ( + open_lib, + to_tensor, + DeviceEnum, + infiniopHandle_t, + infiniopTensorDescriptor_t, + create_handle, + destroy_handle, + check_error, + rearrange_tensor, +) + +from operatorspy.tests.test_utils import get_args +import torch + +# constant for control whether profile the pytorch and lib functions +# NOTE: need to manually add synchronization function to the lib function, +# e.g., cudaDeviceSynchronize() for CUDA +PROFILE = False +NUM_PRERUN = 10 +NUM_ITERATIONS = 1000 + +class GEMMDescriptor(Structure): + _fields_ = [("device", c_int32)] + + +infiniopGEMMDescriptor_t = POINTER(GEMMDescriptor) + + +def gemm(A, B, C=None, transA=False, transB=False, alpha=1.0, beta=0.0, dtype=torch.float32): + A = A.T if transA else A + B = B.T if transB else B + result = alpha * torch.matmul(A if dtype != torch.float16 else A.to(torch.float32), B if dtype != torch.float16 else B.to(torch.float32)).to(dtype) + if C is not None: + result += beta * C if dtype != torch.float16 else C.to(torch.float32) + if PROFILE: + torch.cuda.synchronize() + return result + + +def test( + lib, + handle, + torch_device, + alpha, + beta, + transA, + transB, + a_shape, + b_shape, + c_shape, + y_shape, + a_stride=None, + b_stride=None, + c_stride=None, + y_stride=None, + dtype=torch.float16, +): + print( + f"Testing GEMM on {torch_device} with transA: {transA} transB: {transB} " + f"a_shape:{a_shape} b_shape:{b_shape} c_shape:{c_shape} y_shape:{y_shape} " + f"a_stride:{a_stride} b_stride:{b_stride} c_stride:{c_stride} y_stride:{y_stride} dtype:{dtype}" + ) + + a = torch.rand(a_shape, dtype=dtype).to(torch_device) + b = torch.rand(b_shape, dtype=dtype).to(torch_device) + c = torch.rand(c_shape, dtype=dtype).to(torch_device) if c_shape else None + y = torch.rand(y_shape, dtype=dtype).to(torch_device) + + if a_stride is not None: + a = rearrange_tensor(a, a_stride) + if b_stride is not None: + b = rearrange_tensor(b, b_stride) + if c_stride is not None and c is not None: + c = rearrange_tensor(c, c_stride) + if y_stride is not None: + y = rearrange_tensor(y, y_stride) + + for i in range(NUM_PRERUN if PROFILE else 1): + ans = gemm(a, b, c, transA, transB, alpha, beta, dtype) + if PROFILE: + start_time = time.time() + for i in range(NUM_ITERATIONS): + _ = gemm(a, b, c, transA, transB, alpha, beta, dtype) + elapsed = (time.time() - start_time) / NUM_ITERATIONS + print(f"pytorch time: {elapsed :6f}") + + + a_tensor = to_tensor(a, lib) + b_tensor = to_tensor(b, lib) + c_tensor = to_tensor(c, lib) if c is not None else None + y_tensor = to_tensor(y, lib) + descriptor = infiniopGEMMDescriptor_t() + check_error( + lib.infiniopCreateGEMMDescriptor( + handle, + ctypes.byref(descriptor), + y_tensor.descriptor, + 
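            # Argument order follows infiniopCreateGEMMDescriptor in include/ops/gemm/gemm.h:
            # y, a, b, optional c, alpha, beta, transA, transB.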
a_tensor.descriptor, + b_tensor.descriptor, + c_tensor.descriptor if c_tensor else None, + alpha, + beta, + transA, + transB, + ) + ) + + workspace_size = ctypes.c_uint64(0) + check_error( + lib.infiniopGetGEMMWorkspaceSize( + descriptor, ctypes.byref(workspace_size) + ) + ) + workspace = torch.zeros(int(workspace_size.value), dtype=torch.uint8).to( + torch_device + ) + workspace_ptr = ctypes.cast(workspace.data_ptr(), ctypes.POINTER(ctypes.c_uint8)) + + for i in range(NUM_PRERUN if PROFILE else 1): + check_error( + lib.infiniopGEMM( + descriptor, + workspace_ptr, + workspace_size, + y_tensor.data, + a_tensor.data, + b_tensor.data, + c_tensor.data if c_tensor else None, + None, + ) + ) + if PROFILE: + start_time = time.time() + for i in range(NUM_ITERATIONS): + lib.infiniopGEMM( + descriptor, + workspace_ptr, + workspace_size, + y_tensor.data, + a_tensor.data, + b_tensor.data, + c_tensor.data if c_tensor else None, + None, + ) + elapsed = (time.time() - start_time) / NUM_ITERATIONS + print(f" lib time: {elapsed :6f}") + + assert torch.allclose(y, ans, atol=0, rtol=1e-2) + check_error(lib.infiniopDestroyGEMMDescriptor(descriptor)) + + +def test_cpu(lib, test_cases): + device = DeviceEnum.DEVICE_CPU + handle = create_handle(lib, device) + for ( + alpha, + beta, + transA, + transB, + a_shape, + b_shape, + c_shape, + y_shape, + a_stride, + b_stride, + c_stride, + y_stride, + ) in test_cases: + test(lib, handle, "cpu", alpha, beta, transA, transB, a_shape, b_shape, c_shape, y_shape, a_stride, b_stride, c_stride, y_stride, dtype=torch.float16) + test(lib, handle, "cpu", alpha, beta, transA, transB, a_shape, b_shape, c_shape, y_shape, a_stride, b_stride, c_stride, y_stride, dtype=torch.float32) + destroy_handle(lib, handle) + + +def test_cuda(lib, test_cases): + device = DeviceEnum.DEVICE_CUDA + handle = create_handle(lib, device) + for ( + alpha, + beta, + transA, + transB, + a_shape, + b_shape, + c_shape, + y_shape, + a_stride, + b_stride, + c_stride, + y_stride, + ) in test_cases: + test(lib, handle, "cuda", alpha, beta, transA, transB, a_shape, b_shape, c_shape, y_shape, a_stride, b_stride, c_stride, y_stride, dtype=torch.float16) + test(lib, handle, "cuda", alpha, beta, transA, transB, a_shape, b_shape, c_shape, y_shape, a_stride, b_stride, c_stride, y_stride, dtype=torch.float32) + destroy_handle(lib, handle) + + +def test_bang(lib, test_cases): + import torch_mlu + + device = DeviceEnum.DEVICE_BANG + handle = create_handle(lib, device) + + for ( + alpha, + beta, + transA, + transB, + a_shape, + b_shape, + c_shape, + y_shape, + a_stride, + b_stride, + c_stride, + y_stride, + ) in test_cases: + test(lib, handle, "mlu", alpha, beta, transA, transB, a_shape, b_shape, c_shape, y_shape, a_stride, b_stride, c_stride, y_stride, dtype=torch.float16) + test(lib, handle, "mlu", alpha, beta, transA, transB, a_shape, b_shape, c_shape, y_shape, a_stride, b_stride, c_stride, y_stride, dtype=torch.float32) + + destroy_handle(lib, handle) + + +if __name__ == "__main__": + test_cases = [ + # alpha, beta, transA, transB, a_shape, b_shape, c_shape, y_shape, a_stride, b_stride, c_stride, y_stride + ( + 1.0, + 1.0, + False, + False, + (1, 2048), + (2048, 2048), + (1, 2048), + (1, 2048), + None, + None, + None, + None, + ), + ( + 1.0, + 1.0, + True, + True, + (2048, 4), + (2048, 2048), + (4, 2048), + (4, 2048), + None, + None, + None, + None, + ), + ( + 1.0, + 1.0, + False, + True, + (1, 2048), + (1000, 2048), + (1000), + (1, 1000), + None, + None, + None, + None, + ), + ( + 1.0, + 1.0, + True, + False, + (2048, 
4), + (2048, 2048), + (2048), + (4, 2048), + (4096, 1), + (4096, 1), + (2,), + (4096, 1), + ), + ( + 1.0, + 1.0, + False, + False, + (3, 1, 2048), + (3, 2048, 2048), + (1,), + (3, 1, 2048), + None, + None, + None, + None, + ), + ( + 1.0, + 1.0, + True, + False, + (2048, 4), + (2048, 2048), + None, + (4, 2048), + (4096, 1), + (4096, 1), + (2,), + (4096, 1), + ), + ] + args = get_args() + lib = open_lib() + + lib.infiniopCreateGEMMDescriptor.restype = c_int32 + lib.infiniopCreateGEMMDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopGEMMDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + c_float, + c_float, + c_bool, + c_bool, + ] + + lib.infiniopGetGEMMWorkspaceSize.restype = c_int32 + lib.infiniopGetGEMMWorkspaceSize.argtypes = [ + infiniopGEMMDescriptor_t, + POINTER(c_uint64), + ] + + lib.infiniopGEMM.restype = c_int32 + lib.infiniopGEMM.argtypes = [ + infiniopGEMMDescriptor_t, + c_void_p, + c_uint64, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + ] + + lib.infiniopDestroyGEMMDescriptor.restype = c_int32 + lib.infiniopDestroyGEMMDescriptor.argtypes = [ + infiniopGEMMDescriptor_t, + ] + + if args.cpu: + test_cpu(lib, test_cases) + if args.cuda: + test_cuda(lib, test_cases) + if args.bang: + test_bang(lib, test_cases) + if not (args.cpu or args.cuda or args.bang): + test_cpu(lib, test_cases) + print("\033[92mTest passed!\033[0m") diff --git a/operatorspy/tests/matmul.py b/operatorspy/tests/matmul.py index 3dc2a9ce..67daf48c 100644 --- a/operatorspy/tests/matmul.py +++ b/operatorspy/tests/matmul.py @@ -2,6 +2,7 @@ import ctypes import sys import os +import time sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))) from operatorspy import ( @@ -21,6 +22,13 @@ from operatorspy.tests.test_utils import get_args import torch +# constant for control whether profile the pytorch and lib functions +# NOTE: need to manually add synchronization function to the lib function, +# e.g., cudaDeviceSynchronize() for CUDA +PROFILE = False +NUM_PRERUN = 10 +NUM_ITERATIONS = 1000 + class MatmulDescriptor(Structure): _fields_ = [("device", c_int32)] @@ -30,10 +38,13 @@ class MatmulDescriptor(Structure): def matmul(c, beta, a, b, alpha): input_dtype = c.dtype - return ( + ans = ( alpha * torch.matmul(a.to(torch.float32), b.to(torch.float32)).to(input_dtype) + beta * c ) + if PROFILE: + torch.cuda.synchronize() + return ans def test( @@ -66,7 +77,15 @@ def test( if c_stride is not None: c = rearrange_tensor(c, c_stride) - ans = matmul(c, beta, a, b, alpha) + for i in range(NUM_PRERUN if PROFILE else 1): + ans = matmul(c, beta, a, b, alpha) + if PROFILE: + start_time = time.time() + for i in range(NUM_ITERATIONS): + _ = matmul(c, beta, a, b, alpha) + elapsed = (time.time() - start_time) / NUM_ITERATIONS + print(f"pytorch time: {elapsed :6f}") + a_tensor = to_tensor(a, lib) b_tensor = to_tensor(b, lib) @@ -90,7 +109,8 @@ def test( ) workspace = create_workspace(workspace_size.value, a.device) - check_error( + for i in range(NUM_PRERUN if PROFILE else 1): + check_error( lib.infiniopMatmul( descriptor, workspace.data_ptr() if workspace is not None else None, @@ -101,6 +121,20 @@ def test( None, ) ) + if PROFILE: + start_time = time.time() + for i in range(NUM_ITERATIONS): + lib.infiniopMatmul( + descriptor, + workspace.data_ptr() if workspace is not None else None, + workspace_size.value, + c_tensor.data, + a_tensor.data, + b_tensor.data, + None, + ) + elapsed = (time.time() 
- start_time) / NUM_ITERATIONS + print(f" lib time: {elapsed :6f}") assert torch.allclose(c, ans, atol=0, rtol=1e-2) @@ -244,28 +278,11 @@ def test_ascend(lib, test_cases): test_cases = [ # alpha, beta, a_shape, b_shape, c_shape, a_stride, b_stride, c_stride, dtype (1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), None, None, None, torch.float16), - ( - 1.0, - 0.0, - (1, 2048), - (2048, 2048), - (1, 2048), - (4096, 1), - (4096, 1), - (4096, 1), - torch.float16, - ), - ( - 1.0, - 0.0, - (2, 1, 2048), - (2, 2048, 2048), - (2, 1, 2048), - None, - None, - None, - torch.float16, - ), + (1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), None, None, None, torch.float32), + (1.0, 0.0, (2, 4, 2048), (2, 2048, 2048), (2, 4, 2048), None, None, None, torch.float16), + (1.0, 0.0, (2, 4, 2048), (2, 2048, 2048), (2, 4, 2048), None, None, None, torch.float32), + (1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), (4096, 1), (4096, 1), (4096, 1), torch.float16), + (1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), (4096, 1), (4096, 1), (4096, 1), torch.float32), ] args = get_args() lib = open_lib() @@ -313,4 +330,4 @@ def test_ascend(lib, test_cases): test_ascend(lib, test_cases) if not (args.cpu or args.cuda or args.bang or args.ascend): test_cpu(lib, test_cases) - print("Test passed!") + print("\033[92mTest passed!\033[0m") diff --git a/src/devices/cpu/common_cpu.cc b/src/devices/cpu/common_cpu.cc index c89c7491..b5b5f0fd 100644 --- a/src/devices/cpu/common_cpu.cc +++ b/src/devices/cpu/common_cpu.cc @@ -65,3 +65,21 @@ uint16_t f32_to_f16(float val) { return sign; } } + +uint64_t getDstOffset(uint64_t flat_index, uint64_t ndim, int64_t const *src_strides, int64_t const *dst_strides) { + uint64_t res = 0; + for (uint64_t i = 0; i < ndim; ++i) { + res += flat_index / src_strides[i] * dst_strides[i]; + flat_index %= src_strides[i]; + } + return res; +} + +uint64_t getOffset(uint64_t flat_index, uint64_t ndim, uint64_t const *shape, int64_t const *strides) { + uint64_t res = 0; + for (long i = ndim - 1; i >= 0; --i) { + res += (flat_index % shape[i]) * strides[i]; + flat_index /= shape[i]; + } + return res; +} diff --git a/src/devices/cpu/common_cpu.h b/src/devices/cpu/common_cpu.h index 20f1a2d8..caf3dd73 100644 --- a/src/devices/cpu/common_cpu.h +++ b/src/devices/cpu/common_cpu.h @@ -15,4 +15,10 @@ float f16_to_f32(uint16_t code); // convert single-precision float to half-precision float uint16_t f32_to_f16(float val); -#endif // __COMMON_CPU_H__ +// get the corresponding offset in the destination given the flat index of the source (for element mapping in shape broadcast) +uint64_t getDstOffset(uint64_t flat_index, uint64_t ndim, int64_t const *src_strides, int64_t const *dst_strides); + +// get the memory offset of the given element in a tensor given its flat index +uint64_t getOffset(uint64_t flat_index, uint64_t ndim, uint64_t const *shape, int64_t const *strides); + +#endif// __COMMON_CPU_H__ diff --git a/src/devices/cuda/common_cuda.h b/src/devices/cuda/common_cuda.h index fa89e6c6..3bd7e856 100644 --- a/src/devices/cuda/common_cuda.h +++ b/src/devices/cuda/common_cuda.h @@ -54,4 +54,24 @@ typedef struct DataLayoutMap { constexpr DTMap dataTypeMap; +// get the corresponding offset in the destination given the flat index of the source (for element mapping in shape broadcast) +inline __device__ uint64_t getDstOffset(uint64_t flat_index, uint64_t ndim, int64_t const *src_strides, int64_t const *dst_strides) { + uint64_t res = 0; + for (uint64_t i = 0; i < ndim; ++i) { + res += flat_index / src_strides[i] * 
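        // Decomposes flat_index into per-dimension coordinates by dividing through the
        // source strides (valid for a contiguous-style, descending stride layout), then
        // re-projects them onto the destination strides; a destination stride of 0 maps
        // the whole dimension to one element, which is what implements broadcasting.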
dst_strides[i]; + flat_index %= src_strides[i]; + } + return res; +} + +// get the memory offset of the given element in a tensor given its flat index +inline __device__ uint64_t getOffset(uint64_t flat_index, uint64_t ndim, uint64_t const *shape, int64_t const *strides) { + uint64_t res = 0; + for (long i = ndim - 1; i >= 0; --i) { + res += (flat_index % shape[i]) * strides[i]; + flat_index /= shape[i]; + } + return res; +} + #endif// __COMMON_CUDA_H__ diff --git a/src/ops/add/cuda/add.cc b/src/ops/add/cuda/add.cc index bfb885c1..b010894f 100644 --- a/src/ops/add/cuda/add.cc +++ b/src/ops/add/cuda/add.cc @@ -73,9 +73,9 @@ infiniopStatus_t cudaCreateAddDescriptor(CudaHandle_t handle, } infiniopStatus_t cudaDestroyAddDescriptor(AddCudaDescriptor_t desc) { - cudaFree((void *) desc->a_strides); - cudaFree((void *) desc->b_strides); - cudaFree((void *) desc->c_strides); + checkCudaErrorWithCode(cudaFree((void *) desc->a_strides), STATUS_EXECUTION_FAILED); + checkCudaErrorWithCode(cudaFree((void *) desc->b_strides), STATUS_EXECUTION_FAILED); + checkCudaErrorWithCode(cudaFree((void *) desc->c_strides), STATUS_EXECUTION_FAILED); delete desc; return STATUS_SUCCESS; } diff --git a/src/ops/add/cuda/add.cu b/src/ops/add/cuda/add.cu index 6c1dfec4..9d9aefcb 100644 --- a/src/ops/add/cuda/add.cu +++ b/src/ops/add/cuda/add.cu @@ -35,16 +35,6 @@ struct vecN { } }; -// get the corresponding index in the destination given the flat index of the source -__device__ uint64_t getDstIndex(uint64_t flat_index, uint64_t ndim, int64_t const *src_strides, int64_t const *dst_strides) { - uint64_t res = 0; - for (uint64_t i = 0; i < ndim; ++i) { - res += flat_index / src_strides[i] * dst_strides[i]; - flat_index %= src_strides[i]; - } - return res; -} - template __global__ void add( Tdata *c, @@ -68,8 +58,8 @@ __global__ void add( auto c_ = reinterpret_cast(c); #pragma unroll for (size_t i = 0; i < pack_size; ++i) { - auto a_idx = getDstIndex(idx + i, ndim, c_strides, a_strides); - auto b_idx = getDstIndex(idx + i, ndim, c_strides, b_strides); + auto a_idx = getDstOffset(idx + i, ndim, c_strides, a_strides); + auto b_idx = getDstOffset(idx + i, ndim, c_strides, b_strides); c_[idx + i] = a_[a_idx] + b_[b_idx]; } return; diff --git a/src/ops/expand/cpu/expand_cpu.cc b/src/ops/expand/cpu/expand_cpu.cc new file mode 100644 index 00000000..d3bcb866 --- /dev/null +++ b/src/ops/expand/cpu/expand_cpu.cc @@ -0,0 +1,69 @@ +#include "expand_cpu.h" +#include "../../../devices/cpu/common_cpu.h" +#include "../../utils.h" + +infiniopStatus_t cpuCreateExpandDescriptor(infiniopHandle_t, + ExpandCpuDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y, + infiniopTensorDescriptor_t x) { + uint64_t ndim = y->ndim; + if (!isValidBroadcastShape(y, x)) { + return STATUS_BAD_TENSOR_SHAPE; + } + if (y->dt != x->dt) { + return STATUS_BAD_TENSOR_DTYPE; + } + + uint64_t y_data_size = std::accumulate(y->shape, y->shape + y->ndim, 1ULL, std::multiplies()); + + // get the adjusted strides for x in terms of y + int64_t *x_strides = new int64_t[ndim]; + int64_t *y_strides = new int64_t[ndim]; +#pragma omp parallel for + for (size_t i = 0; i < ndim; ++i) { + x_strides[i] = (i < ndim - x->ndim || y->shape[i] != x->shape[i + x->ndim - ndim]) ? 
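        // Dimensions that x lacks (leading padding) or that are broadcast (size 1 in x,
        // larger in y) get stride 0, so every y element along that dimension maps back
        // to the same x element.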
0 : x->strides[i + x->ndim - ndim]; + } + memcpy(y_strides, y->strides, ndim * sizeof(int64_t)); + + *desc_ptr = new ExpandCpuDescriptor{ + DevCpu, + y->dt, + ndim, + y_data_size, + x_strides, + y_strides, + }; + + return STATUS_SUCCESS; +} + +infiniopStatus_t cpuDestroyExpandDescriptor(ExpandCpuDescriptor_t desc) { + delete[] desc->x_strides; + delete[] desc->y_strides; + delete desc; + return STATUS_SUCCESS; +} + +template +infiniopStatus_t expand_cpu(ExpandCpuDescriptor_t desc, void *y, void const *x) { + auto x_ = reinterpret_cast(x); + auto y_ = reinterpret_cast(y); + +#pragma omp parallel for + for (uint64_t i = 0; i < desc->y_data_size; ++i) { + y_[i] = x_[getDstOffset(i, desc->ndim, desc->y_strides, desc->x_strides)]; + } + return STATUS_SUCCESS; +} + +infiniopStatus_t cpuExpand(ExpandCpuDescriptor_t desc, + void *y, void const *x, + void *stream) { + if (desc->dtype == F16) { + return expand_cpu(desc, y, x); + } + if (desc->dtype == F32) { + return expand_cpu(desc, y, x); + } + return STATUS_BAD_TENSOR_DTYPE; +} diff --git a/src/ops/expand/cpu/expand_cpu.h b/src/ops/expand/cpu/expand_cpu.h new file mode 100644 index 00000000..868fefe8 --- /dev/null +++ b/src/ops/expand/cpu/expand_cpu.h @@ -0,0 +1,29 @@ +#ifndef __CPU_EXPAND_H__ +#define __CPU_EXPAND_H__ + +#include "operators.h" +#include +#include + +struct ExpandCpuDescriptor { + Device device; + DT dtype; + uint64_t ndim; + uint64_t y_data_size; + int64_t const *x_strides; + int64_t const *y_strides; +}; + +typedef struct ExpandCpuDescriptor *ExpandCpuDescriptor_t; + +infiniopStatus_t cpuCreateExpandDescriptor(infiniopHandle_t, + ExpandCpuDescriptor_t *, + infiniopTensorDescriptor_t y, + infiniopTensorDescriptor_t x); + +infiniopStatus_t cpuExpand(ExpandCpuDescriptor_t desc, + void *y, void const *x, void *stream); + +infiniopStatus_t cpuDestroyExpandDescriptor(ExpandCpuDescriptor_t desc); + +#endif diff --git a/src/ops/expand/cuda/expand.cc b/src/ops/expand/cuda/expand.cc new file mode 100644 index 00000000..cf43b326 --- /dev/null +++ b/src/ops/expand/cuda/expand.cc @@ -0,0 +1,51 @@ +#include "expand.cuh" +#include "../../../devices/cuda/common_cuda.h" +#include "../../utils.h" + +infiniopStatus_t cudaCreateExpandDescriptor(CudaHandle_t handle, + ExpandCudaDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y, + infiniopTensorDescriptor_t x) { + uint64_t ndim = y->ndim; + if (!isValidBroadcastShape(y, x)) { + return STATUS_BAD_TENSOR_SHAPE; + } + if (y->dt != x->dt) { + return STATUS_BAD_TENSOR_DTYPE; + } + + uint64_t y_data_size = std::accumulate(y->shape, y->shape + y->ndim, 1ULL, std::multiplies()); + + // get the adjusted strides for x in terms of y + int64_t *x_strides = new int64_t[ndim]; + for (size_t i = 0; i < ndim; ++i) { + x_strides[i] = (i < ndim - x->ndim || y->shape[i] != x->shape[i + x->ndim - ndim]) ? 
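        // (Same stride adjustment as the CPU path. The single device allocation that
        // follows packs [x_strides | y_strides | y_shape], each ndim entries long.)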
0 : x->strides[i + x->ndim - ndim]; + } + + int64_t *x_strides_d, *y_strides_d; + char *strides_and_shape_d; + checkCudaErrorWithCode(cudaMalloc(&strides_and_shape_d, ndim * (2 * sizeof(int64_t) + sizeof(uint64_t))), STATUS_MEMORY_NOT_ALLOCATED); + checkCudaErrorWithCode(cudaMemcpy(strides_and_shape_d, x_strides, ndim * sizeof(int64_t), cudaMemcpyHostToDevice), STATUS_EXECUTION_FAILED); + checkCudaErrorWithCode(cudaMemcpy(strides_and_shape_d + ndim * sizeof(int64_t), y->strides, ndim * sizeof(int64_t), cudaMemcpyHostToDevice), STATUS_EXECUTION_FAILED); + checkCudaErrorWithCode(cudaMemcpy(strides_and_shape_d + 2 * ndim * sizeof(int64_t), y->shape, ndim * sizeof(uint64_t), cudaMemcpyHostToDevice), STATUS_EXECUTION_FAILED); + + *desc_ptr = new ExpandCudaDescriptor{ + DevNvGpu, + y->dt, + handle->device_id, + ndim, + y_data_size, + static_cast(handle->prop.maxGridSize[0]), + strides_and_shape_d, + }; + + delete[] x_strides; + + return STATUS_SUCCESS; +} + +infiniopStatus_t cudaDestroyExpandDescriptor(ExpandCudaDescriptor_t desc) { + checkCudaErrorWithCode(cudaFree((void *) desc->strides_and_shape_d), STATUS_EXECUTION_FAILED); + delete desc; + return STATUS_SUCCESS; +} diff --git a/src/ops/expand/cuda/expand.cu b/src/ops/expand/cuda/expand.cu new file mode 100644 index 00000000..6d75e651 --- /dev/null +++ b/src/ops/expand/cuda/expand.cu @@ -0,0 +1,58 @@ +#include "../../../devices/cuda/common_cuda.h" +#include "../../utils.h" +#include "expand.cuh" + +template +__global__ void expand( + Tdata *y, + const Tdata *x, + const int64_t *y_strides, + const int64_t *x_strides, + const uint64_t *y_shape, + uint64_t y_data_size, + uint64_t ndim, + uint64_t offset) { + uint64_t idx = blockIdx.x * blockDim.x + threadIdx.x + offset; + + if (idx < y_data_size) { + uint64_t y_idx = getOffset(idx, ndim, y_shape, y_strides); + y[y_idx] = x[getDstOffset(y_idx, ndim, y_strides, x_strides)]; + } +} + +template +infiniopStatus_t expand_nv_gpu(ExpandCudaDescriptor_t desc, void *y, void const *x, void *stream) { + if (desc->y_data_size == 0) { + return STATUS_SUCCESS; + } + dim3 blockDims = dim3(std::min(static_cast(256), desc->y_data_size)); + dim3 gridDims = dim3(std::min(ROUND_UP_DIV(desc->y_data_size, blockDims.x), desc->max_grid_size)); + uint64_t step = gridDims.x * blockDims.x; + + const auto x_ = reinterpret_cast(x); + const auto y_ = reinterpret_cast(y); + const auto x_strides = reinterpret_cast(desc->strides_and_shape_d); + const auto y_strides = reinterpret_cast(desc->strides_and_shape_d + desc->ndim * sizeof(int64_t)); + const auto y_shape = reinterpret_cast(desc->strides_and_shape_d + 2 * desc->ndim * sizeof(int64_t)); + cudaStream_t cuda_stream = reinterpret_cast(stream); + +#pragma unroll + for (uint64_t i = 0; i < desc->y_data_size; i += step) { + expand<<>>( + y_, x_, y_strides, x_strides, y_shape, i + desc->y_data_size, desc->ndim, i); + } + return STATUS_SUCCESS; +} + +infiniopStatus_t cudaExpand(ExpandCudaDescriptor_t desc, + void *y, void const *x, + void *stream) { + checkCudaError(cudaSetDevice(desc->device_id)); + if (desc->dtype == F16) { + return expand_nv_gpu(desc, y, x, stream); + } + if (desc->dtype == F32) { + return expand_nv_gpu(desc, y, x, stream); + } + return STATUS_BAD_TENSOR_DTYPE; +} diff --git a/src/ops/expand/cuda/expand.cuh b/src/ops/expand/cuda/expand.cuh new file mode 100644 index 00000000..17cc1337 --- /dev/null +++ b/src/ops/expand/cuda/expand.cuh @@ -0,0 +1,33 @@ +#ifndef __CUDA_EXPAND_H__ +#define __CUDA_EXPAND_H__ + +#include "../../../devices/cuda/common_cuda.h" 
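// common_cuda.h provides the shared getOffset/getDstOffset device helpers used by the kernel in expand.cu.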
+#include "../../../devices/cuda/cuda_handle.h" +#include "operators.h" +#include +#include + +struct ExpandCudaDescriptor { + Device device; + DT dtype; + int device_id; + uint64_t ndim; + uint64_t y_data_size; + uint64_t max_grid_size; + char const *strides_and_shape_d; +}; + +typedef struct ExpandCudaDescriptor *ExpandCudaDescriptor_t; + +infiniopStatus_t cudaCreateExpandDescriptor(CudaHandle_t, + ExpandCudaDescriptor_t *, + infiniopTensorDescriptor_t y, + infiniopTensorDescriptor_t x); + +infiniopStatus_t cudaExpand(ExpandCudaDescriptor_t desc, + void *y, void const *x, + void *stream); + +infiniopStatus_t cudaDestroyExpandDescriptor(ExpandCudaDescriptor_t desc); + +#endif diff --git a/src/ops/expand/operator.cc b/src/ops/expand/operator.cc new file mode 100644 index 00000000..0572acd0 --- /dev/null +++ b/src/ops/expand/operator.cc @@ -0,0 +1,72 @@ +#include "../utils.h" +#include "operators.h" +#include "ops/expand/expand.h" + +#ifdef ENABLE_CPU +#include "cpu/expand_cpu.h" +#endif +#ifdef ENABLE_NV_GPU +#include "../../devices/cuda/cuda_handle.h" +#include "cuda/expand.cuh" +#endif + +__C infiniopStatus_t infiniopCreateExpandDescriptor( + infiniopHandle_t handle, + infiniopExpandDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y, + infiniopTensorDescriptor_t x) { + switch (handle->device) { +#ifdef ENABLE_CPU + case DevCpu: + return cpuCreateExpandDescriptor(handle, (ExpandCpuDescriptor_t *) desc_ptr, y, x); +#endif +#ifdef ENABLE_NV_GPU + case DevNvGpu: { + return cudaCreateExpandDescriptor((CudaHandle_t) handle, (ExpandCudaDescriptor_t *) desc_ptr, y, x); + } + +#endif +#ifdef ENABLE_CAMBRICON_MLU + // TODO +#endif + } + return STATUS_BAD_DEVICE; +} + +__C infiniopStatus_t infiniopExpand(infiniopExpandDescriptor_t desc, void *y, void const *x, void *stream) { + switch (desc->device) { +#ifdef ENABLE_CPU + case DevCpu: + return cpuExpand((ExpandCpuDescriptor_t) desc, y, x, stream); +#endif +#ifdef ENABLE_NV_GPU + case DevNvGpu: { + return cudaExpand((ExpandCudaDescriptor_t) desc, y, x, stream); + } + +#endif +#ifdef ENABLE_CAMBRICON_MLU + // TODO +#endif + } + return STATUS_BAD_DEVICE; +} + +__C infiniopStatus_t infiniopDestroyExpandDescriptor(infiniopExpandDescriptor_t desc) { + switch (desc->device) { +#ifdef ENABLE_CPU + case DevCpu: + return cpuDestroyExpandDescriptor((ExpandCpuDescriptor_t) desc); +#endif +#ifdef ENABLE_NV_GPU + case DevNvGpu: { + return cudaDestroyExpandDescriptor((ExpandCudaDescriptor_t) desc); + } + +#endif +#ifdef ENABLE_CAMBRICON_MLU + // TODO +#endif + } + return STATUS_BAD_DEVICE; +} diff --git a/src/ops/gemm/operator.cc b/src/ops/gemm/operator.cc new file mode 100644 index 00000000..071c2870 --- /dev/null +++ b/src/ops/gemm/operator.cc @@ -0,0 +1,96 @@ +#include "../utils.h" +#include "ops/expand/expand.h" +#include "ops/gemm/gemm.h" +#include "ops/matmul/matmul.h" +#include "tensor/tensor_descriptor.h" + +struct _GEMMDescriptor { + Device device; + infiniopMatmulDescriptor_t matmul_desc; + infiniopExpandDescriptor_t expand_desc; + uint64_t workspace_size; +}; + +typedef struct _GEMMDescriptor *_GEMMDescriptor_t; + +__C __export infiniopStatus_t infiniopCreateGEMMDescriptor(infiniopHandle_t handle, + infiniopGEMMDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t b_desc, + infiniopTensorDescriptor_t c_desc, + float alpha, + float beta, + bool transA, + bool transB) { + // transpose a and b if needed + a_desc = transA ? permute(a_desc, {1, 0}) : a_desc; + b_desc = transB ? 
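    // GEMM is composed from existing operators: an optional Expand broadcasts c into y,
    // then Matmul computes y = alpha * op(a) @ op(b) + beta * y. transA/transB are
    // handled by swapping the two matrix dimensions of the descriptors via permute.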
permute(b_desc, {1, 0}) : b_desc; + + // expand desc + infiniopExpandDescriptor_t expand_desc = nullptr; + + // c is optional, set beta to 0 when c is not provided + if (!c_desc || c_desc->ndim == 0 || c_desc->shape == nullptr || c_desc->shape[0] == 0) { + beta = 0; + } else { + expand_desc = new ExpandDescriptor{handle->device}; + CHECK_STATUS(infiniopCreateExpandDescriptor(handle, &expand_desc, y_desc, c_desc), STATUS_SUCCESS); + } + + // matmul desc + infiniopMatmulDescriptor_t matmul_desc = new MatmulDescriptor{handle->device}; + CHECK_STATUS(infiniopCreateMatmulDescriptor(handle, &matmul_desc, y_desc, alpha, a_desc, b_desc, beta), STATUS_SUCCESS); + uint64_t workspace_size = 0; + CHECK_STATUS(infiniopGetMatmulWorkspaceSize(matmul_desc, &workspace_size), STATUS_SUCCESS); + + *(_GEMMDescriptor_t *) desc_ptr = new _GEMMDescriptor{ + handle->device, + matmul_desc, + expand_desc, + workspace_size, + }; + + return STATUS_SUCCESS; +} + +__C __export infiniopStatus_t infiniopGetGEMMWorkspaceSize(infiniopGEMMDescriptor_t desc, uint64_t *size) { + *size = ((_GEMMDescriptor_t) desc)->workspace_size; + return STATUS_SUCCESS; +} + +__C __export infiniopStatus_t infiniopGEMM(infiniopGEMMDescriptor_t desc, + void *workspace, + uint64_t workspace_size, + void *y, + void const *a, + void const *b, + void const *c, + void *stream) { + auto _desc = (_GEMMDescriptor_t) desc; + if (workspace_size < _desc->workspace_size) { + return STATUS_MEMORY_NOT_ALLOCATED; + } + + if (_desc->expand_desc != nullptr) { + CHECK_STATUS(infiniopExpand(_desc->expand_desc, + y, c, stream), + STATUS_SUCCESS); + } + + CHECK_STATUS(infiniopMatmul(_desc->matmul_desc, + workspace, + workspace_size, + y, a, b, stream), + STATUS_SUCCESS); + + return STATUS_SUCCESS; +} + +__C __export infiniopStatus_t infiniopDestroyGEMMDescriptor(infiniopGEMMDescriptor_t desc) { + if (((_GEMMDescriptor_t) desc)->expand_desc) { + CHECK_STATUS(infiniopDestroyExpandDescriptor(((_GEMMDescriptor_t) desc)->expand_desc), STATUS_SUCCESS); + } + CHECK_STATUS(infiniopDestroyMatmulDescriptor(((_GEMMDescriptor_t) desc)->matmul_desc), STATUS_SUCCESS); + return STATUS_SUCCESS; +} diff --git a/src/ops/matmul/cpu/matmul_cpu.cc b/src/ops/matmul/cpu/matmul_cpu.cc index 88ced7a1..b6148852 100644 --- a/src/ops/matmul/cpu/matmul_cpu.cc +++ b/src/ops/matmul/cpu/matmul_cpu.cc @@ -12,7 +12,7 @@ infiniopStatus_t cpuCreateMatmulDescriptor(CpuHandle_t handle, float beta) { DT dtype = c_desc->dt; - if (!dtype_eq(dtype, F16)) { + if (dtype != F16 && dtype != F32) { return STATUS_BAD_TENSOR_DTYPE; } @@ -31,20 +31,6 @@ infiniopStatus_t cpuCreateMatmulDescriptor(CpuHandle_t handle, return STATUS_SUCCESS; } -infiniopStatus_t cpuMatmul(MatmulCpuDescriptor_t desc, - void *workspace, - uint64_t workspace_size, - void *c, - void const *a, - void const *b) { - if (dtype_eq(desc->dtype, F16)) { - matmul_cpu_f16(desc, c, desc->beta, a, b, desc->alpha); - return STATUS_SUCCESS; - } - - return STATUS_BAD_TENSOR_DTYPE; -} - infiniopStatus_t cpuGetMatmulWorkspaceSize(MatmulCpuDescriptor_t desc, uint64_t *size) { *size = 0; return STATUS_SUCCESS; @@ -55,7 +41,8 @@ infiniopStatus_t cpuDestroyMatmulDescriptor(MatmulCpuDescriptor_t desc) { return STATUS_SUCCESS; } -void matmul_cpu_f16(MatmulCpuDescriptor_t desc, void *c, float beta, void const *a, void const *b, float alpha) { +template +infiniopStatus_t matmul_cpu(MatmulCpuDescriptor_t desc, void *c, float beta, void const *a, void const *b, float alpha) { auto info = desc->info; if (info.is_transed) { @@ -65,15 +52,39 @@ void 
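// The f16-only CPU kernel is generalized below into a matmul_cpu<Tdata> template:
// products are accumulated in a float, with f16/f32 conversion applied only when
// Tdata is uint16_t.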
matmul_cpu_f16(MatmulCpuDescriptor_t desc, void *c, float beta, void const for (int i = 0; i < info.batch; ++i) { for (int m_ = 0; m_ < info.m; ++m_) { for (int n_ = 0; n_ < info.n; ++n_) { - auto c_ = reinterpret_cast(c) + i * info.c_matrix.stride + m_ * info.c_matrix.row_stride + n_ * info.c_matrix.col_stride; + auto c_ = reinterpret_cast(c) + i * info.c_matrix.stride + m_ * info.c_matrix.row_stride + n_ * info.c_matrix.col_stride; float sum = 0; for (int k_ = 0; k_ < info.k; ++k_) { - auto a_ = reinterpret_cast(a) + i * info.a_matrix.stride + m_ * info.a_matrix.row_stride + k_ * info.a_matrix.col_stride; - auto b_ = reinterpret_cast(b) + i * info.b_matrix.stride + n_ * info.b_matrix.col_stride + k_ * info.b_matrix.row_stride; - sum += f16_to_f32(*a_) * f16_to_f32(*b_); + auto a_ = reinterpret_cast(a) + i * info.a_matrix.stride + m_ * info.a_matrix.row_stride + k_ * info.a_matrix.col_stride; + auto b_ = reinterpret_cast(b) + i * info.b_matrix.stride + n_ * info.b_matrix.col_stride + k_ * info.b_matrix.row_stride; + if constexpr (std::is_same::value) { + sum += f16_to_f32(*a_) * f16_to_f32(*b_); + } else { + sum += *a_ * (*b_); + } + } + if constexpr (std::is_same::value) { + *c_ = f32_to_f16(beta * f16_to_f32(*c_) + alpha * sum); + } else { + *c_ = beta * (*c_) + alpha * sum; } - *c_ = f32_to_f16(beta * f16_to_f32(*c_) + alpha * sum); } } } + return STATUS_SUCCESS; +} + +infiniopStatus_t cpuMatmul(MatmulCpuDescriptor_t desc, + void *workspace, + uint64_t workspace_size, + void *c, + void const *a, + void const *b) { + if (desc->dtype == F16) { + return matmul_cpu(desc, c, desc->beta, a, b, desc->alpha); + } + if (desc->dtype == F32) { + return matmul_cpu(desc, c, desc->beta, a, b, desc->alpha); + } + return STATUS_BAD_TENSOR_DTYPE; } diff --git a/src/ops/matmul/cpu/matmul_cpu.h b/src/ops/matmul/cpu/matmul_cpu.h index fcbd4c50..3a5970e8 100644 --- a/src/ops/matmul/cpu/matmul_cpu.h +++ b/src/ops/matmul/cpu/matmul_cpu.h @@ -34,6 +34,4 @@ infiniopStatus_t cpuMatmul(MatmulCpuDescriptor_t desc, infiniopStatus_t cpuDestroyMatmulDescriptor(MatmulCpuDescriptor_t desc); -void matmul_cpu_f16(MatmulCpuDescriptor_t desc, void *c, float beta, void const *a, void const *b, float alpha); - #endif// __CPU_MATMUL_H__ diff --git a/src/ops/matmul/cuda/matmul_cuda.cc b/src/ops/matmul/cuda/matmul_cuda.cc index 71f66cf6..8bac48d4 100644 --- a/src/ops/matmul/cuda/matmul_cuda.cc +++ b/src/ops/matmul/cuda/matmul_cuda.cc @@ -11,7 +11,7 @@ infiniopStatus_t cudaCreateMatmulDescriptor(CudaHandle_t handle, float beta) { DT dtype = c_desc->dt; - if (!dtype_eq(dtype, F16)) { + if (dtype != F16 && dtype != F32) { return STATUS_BAD_TENSOR_DTYPE; } @@ -32,21 +32,6 @@ infiniopStatus_t cudaCreateMatmulDescriptor(CudaHandle_t handle, return STATUS_SUCCESS; } -infiniopStatus_t cudaMatmul(MatmulCudaDescriptor_t desc, - void *workspace, - uint64_t workspace_size, - void *c, - void const *a, - void const *b, - void *stream) { - if (dtype_eq(desc->dtype, F16)) { - matmul_cuda_f16(desc, c, desc->beta, a, b, desc->alpha, stream); - return STATUS_SUCCESS; - } - - return STATUS_BAD_TENSOR_DTYPE; -} - infiniopStatus_t cudaGetMatmulWorkspaceSize(MatmulCudaDescriptor_t desc, uint64_t *size) { *size = 0; return STATUS_SUCCESS; diff --git a/src/ops/matmul/cuda/matmul_cuda.cu b/src/ops/matmul/cuda/matmul_cuda.cu index 32d0cf74..a75b164e 100644 --- a/src/ops/matmul/cuda/matmul_cuda.cu +++ b/src/ops/matmul/cuda/matmul_cuda.cu @@ -5,15 +5,29 @@ #include #include -void matmul_cuda_f16(MatmulCudaDescriptor_t desc, void *c, float beta, void const 
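// The templated replacement below picks CUDA_R_16F operands with CUBLAS_COMPUTE_16F for
// F16, and CUDA_R_32F with CUBLAS_COMPUTE_32F_FAST_TF32 for F32 (TF32 tensor-core math
// at reduced mantissa precision).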
*a, void const *b, float alpha, void *stream) { +template +infiniopStatus_t matmul_cuda(MatmulCudaDescriptor_t desc, void *c, float beta, void const *a, void const *b, float alpha, void *stream) { auto info = desc->info; if (info.is_transed) { std::swap(a, b); } - auto alpha_f16 = __float2half(alpha); - auto beta_f16 = __float2half(beta); + Tdata alpha_, beta_; + cudaDataType a_type, b_type, c_type; + cublasComputeType_t compute_type; + + if constexpr (std::is_same::value) { + alpha_ = __float2half(alpha); + beta_ = __float2half(beta); + a_type = b_type = c_type = CUDA_R_16F; + compute_type = CUBLAS_COMPUTE_16F; + } else { + alpha_ = alpha; + beta_ = beta; + a_type = b_type = c_type = CUDA_R_32F; + compute_type = CUBLAS_COMPUTE_32F_FAST_TF32; + } auto op_a = info.a_matrix.row_stride == 1 ? CUBLAS_OP_N : CUBLAS_OP_T; auto op_b = info.b_matrix.row_stride == 1 ? CUBLAS_OP_N : CUBLAS_OP_T; @@ -26,21 +40,38 @@ void matmul_cuda_f16(MatmulCudaDescriptor_t desc, void *c, float beta, void cons info.m, info.n, info.k, - &alpha_f16, + &alpha_, a, - CUDA_R_16F, + a_type, info.a_matrix.ld(), info.a_matrix.stride, b, - CUDA_R_16F, + b_type, info.b_matrix.ld(), info.b_matrix.stride, - &beta_f16, + &beta_, c, - CUDA_R_16F, + c_type, info.c_matrix.ld(), info.c_matrix.stride, info.batch, - CUBLAS_COMPUTE_16F, + compute_type, CUBLAS_GEMM_DEFAULT_TENSOR_OP); }); + return STATUS_SUCCESS; +} + +infiniopStatus_t cudaMatmul(MatmulCudaDescriptor_t desc, + void *workspace, + uint64_t workspace_size, + void *c, + void const *a, + void const *b, + void *stream) { + if (desc->dtype == F16) { + return matmul_cuda(desc, c, desc->beta, a, b, desc->alpha, stream); + } + if (desc->dtype == F32) { + return matmul_cuda(desc, c, desc->beta, a, b, desc->alpha, stream); + } + return STATUS_BAD_TENSOR_DTYPE; } diff --git a/src/ops/matmul/cuda/matmul_cuda.h b/src/ops/matmul/cuda/matmul_cuda.h index 671ac14c..3e82c1ed 100644 --- a/src/ops/matmul/cuda/matmul_cuda.h +++ b/src/ops/matmul/cuda/matmul_cuda.h @@ -1,8 +1,8 @@ #ifndef __CUDA_MATMUL_H__ #define __CUDA_MATMUL_H__ -#include "../blas.h" #include "../../../devices/cuda/cuda_handle.h" +#include "../blas.h" #include "operators.h" #include @@ -38,6 +38,4 @@ infiniopStatus_t cudaMatmul(MatmulCudaDescriptor_t desc, infiniopStatus_t cudaDestroyMatmulDescriptor(MatmulCudaDescriptor_t desc); -void matmul_cuda_f16(MatmulCudaDescriptor_t desc, void *c, float beta, void const *a, void const *b, float alpha, void *stream); - #endif// __CUDA_MATMUL_H__ diff --git a/src/ops/utils.h b/src/ops/utils.h index fd2afcf0..ad2b65cc 100644 --- a/src/ops/utils.h +++ b/src/ops/utils.h @@ -101,6 +101,22 @@ inline bool isValidBroadcastShape(infiniopTensorDescriptor_t a, infiniopTensorDe return std::equal(broadcast_shape, broadcast_shape + broadcast_ndim, c->shape); } +// check if the shape of tensor src can be validly broadcasted to that of the tensor dst +inline bool isValidBroadcastShape(infiniopTensorDescriptor_t dst, infiniopTensorDescriptor_t src) { + if (dst->ndim < src->ndim) { + return false; + } + uint64_t padded_shape[dst->ndim]; + std::fill(padded_shape, padded_shape + dst->ndim, 1); + std::copy(src->shape, src->shape + src->ndim, padded_shape + dst->ndim - src->ndim); + for (size_t i = 0; i < dst->ndim; ++i) { + if (padded_shape[i] != dst->shape[i] && padded_shape[i] != 1) { + return false; + } + } + return true; +} + // check if the shape of tensor c is valid after broadcasting tensors a and b inline bool isValidBroadcastShape(infiniopTensorDescriptor_t a, infiniopTensorDescriptor_t b, 
infiniopTensorDescriptor_t c) { uint64_t broadcast_ndim = std::max(a->ndim, b->ndim);
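The following is a minimal, standalone sketch (not part of the patch) illustrating the index math behind the new helpers: it copies getOffset/getDstOffset from src/devices/cpu/common_cpu.cc and walks a contiguous (2, 3) output whose (3,)-shaped input has been given a zeroed stride in the broadcast dimension, the same stride adjustment cpuCreateExpandDescriptor performs. The shapes and the main() driver are illustrative only.

#include <cstdint>
#include <cstdio>

// Copies of the helpers added in src/devices/cpu/common_cpu.cc.
uint64_t getDstOffset(uint64_t flat_index, uint64_t ndim, int64_t const *src_strides, int64_t const *dst_strides) {
    uint64_t res = 0;
    for (uint64_t i = 0; i < ndim; ++i) {
        res += flat_index / src_strides[i] * dst_strides[i];
        flat_index %= src_strides[i];
    }
    return res;
}

uint64_t getOffset(uint64_t flat_index, uint64_t ndim, uint64_t const *shape, int64_t const *strides) {
    uint64_t res = 0;
    for (long i = ndim - 1; i >= 0; --i) {
        res += (flat_index % shape[i]) * strides[i];
        flat_index /= shape[i];
    }
    return res;
}

int main() {
    // y: shape (2, 3), contiguous strides (3, 1).
    // x: shape (3,), broadcast to (2, 3) by padding a leading dimension with stride 0,
    //    mirroring the stride adjustment in cpuCreateExpandDescriptor.
    uint64_t y_shape[] = {2, 3};
    int64_t y_strides[] = {3, 1};
    int64_t x_strides[] = {0, 1};
    for (uint64_t i = 0; i < 6; ++i) {
        uint64_t y_off = getOffset(i, 2, y_shape, y_strides);          // where element i of y lives
        uint64_t x_off = getDstOffset(y_off, 2, y_strides, x_strides); // which x element feeds it
        printf("flat %llu -> y offset %llu reads x offset %llu\n",
               (unsigned long long) i, (unsigned long long) y_off, (unsigned long long) x_off);
    }
    return 0;
}

Compiling and running this prints x offsets 0, 1, 2, 0, 1, 2 for the six output elements, i.e. both rows of y read the same three x elements, which is exactly the mapping the CPU and CUDA Expand kernels rely on.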